mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: copy tensors (#382)
Adds `Tensor::to` to allow copying tensors to another device, and impls `Clone` for tensors.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,6 +6,7 @@ Cargo.lock
|
|||||||
**/*.ort
|
**/*.ort
|
||||||
tools/train-data/**/checkpoint
|
tools/train-data/**/checkpoint
|
||||||
!tests/data/**/*
|
!tests/data/**/*
|
||||||
|
!src/value/impl_tensor/identity.ort
|
||||||
|
|
||||||
.venv*
|
.venv*
|
||||||
docs/.next
|
docs/.next
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ impl IoBinding {
|
|||||||
///
|
///
|
||||||
/// The data is only copied upon invocation of this function. Any changes to the given value afterwards will not
|
/// The data is only copied upon invocation of this function. Any changes to the given value afterwards will not
|
||||||
/// affect the data seen by the session until the value is re-bound.
|
/// affect the data seen by the session until the value is re-bound.
|
||||||
pub fn bind_input<T: ValueTypeMarker, S: Into<String>>(&mut self, name: S, ort_value: &Value<T>) -> Result<()> {
|
pub fn bind_input<T: ValueTypeMarker + ?Sized, S: Into<String>>(&mut self, name: S, ort_value: &Value<T>) -> Result<()> {
|
||||||
let name: String = name.into();
|
let name: String = name.into();
|
||||||
let ptr = self.ptr_mut();
|
let ptr = self.ptr_mut();
|
||||||
with_cstr(name.as_bytes(), &|name| {
|
with_cstr(name.as_bytes(), &|name| {
|
||||||
@@ -142,7 +142,7 @@ impl IoBinding {
|
|||||||
/// `name`.
|
/// `name`.
|
||||||
///
|
///
|
||||||
/// [`Tensor::new`]: crate::value::Tensor::new
|
/// [`Tensor::new`]: crate::value::Tensor::new
|
||||||
pub fn bind_output<T: ValueTypeMarker, S: Into<String>>(&mut self, name: S, ort_value: Value<T>) -> Result<()> {
|
pub fn bind_output<T: ValueTypeMarker + ?Sized, S: Into<String>>(&mut self, name: S, ort_value: Value<T>) -> Result<()> {
|
||||||
let name: String = name.into();
|
let name: String = name.into();
|
||||||
let ptr = self.ptr_mut();
|
let ptr = self.ptr_mut();
|
||||||
with_cstr(name.as_bytes(), &|name| {
|
with_cstr(name.as_bytes(), &|name| {
|
||||||
|
|||||||
@@ -244,7 +244,7 @@ impl Drop for AllocatedBlock<'_> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Represents possible devices that have their own device allocator.
|
/// Represents possible devices that have their own device allocator.
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
// &'static str should be valid here since they're only ever defined in C++ with `const char *` literals
|
// &'static str should be valid here since they're only ever defined in C++ with `const char *` literals
|
||||||
pub struct AllocationDevice(&'static str);
|
pub struct AllocationDevice(&'static str);
|
||||||
|
|
||||||
|
|||||||
@@ -112,7 +112,9 @@ impl SessionBuilder {
|
|||||||
let model_path = crate::util::path_to_os_char(model_filepath);
|
let model_path = crate::util::path_to_os_char(model_filepath);
|
||||||
|
|
||||||
let env = get_environment()?;
|
let env = get_environment()?;
|
||||||
|
if !self.no_env_eps {
|
||||||
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
|
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
|
||||||
|
}
|
||||||
|
|
||||||
if env.has_global_threadpool && !self.no_global_thread_pool {
|
if env.has_global_threadpool && !self.no_global_thread_pool {
|
||||||
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
|
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
|
||||||
@@ -185,7 +187,9 @@ impl SessionBuilder {
|
|||||||
let mut session_ptr: *mut ort_sys::OrtSession = ptr::null_mut();
|
let mut session_ptr: *mut ort_sys::OrtSession = ptr::null_mut();
|
||||||
|
|
||||||
let env = get_environment()?;
|
let env = get_environment()?;
|
||||||
|
if !self.no_env_eps {
|
||||||
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
|
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
|
||||||
|
}
|
||||||
|
|
||||||
if env.has_global_threadpool && !self.no_global_thread_pool {
|
if env.has_global_threadpool && !self.no_global_thread_pool {
|
||||||
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
|
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
|
||||||
|
|||||||
@@ -215,6 +215,11 @@ impl SessionBuilder {
|
|||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_no_environment_execution_providers(mut self) -> Result<Self> {
|
||||||
|
self.no_env_eps = true;
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
|
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
|
||||||
let manager = Rc::new(manager);
|
let manager = Rc::new(manager);
|
||||||
ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?];
|
ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?];
|
||||||
|
|||||||
@@ -38,7 +38,8 @@ pub struct SessionBuilder {
|
|||||||
external_initializer_buffers: Vec<Cow<'static, [u8]>>,
|
external_initializer_buffers: Vec<Cow<'static, [u8]>>,
|
||||||
prepacked_weights: Option<PrepackedWeights>,
|
prepacked_weights: Option<PrepackedWeights>,
|
||||||
thread_manager: Option<Rc<dyn Any>>,
|
thread_manager: Option<Rc<dyn Any>>,
|
||||||
no_global_thread_pool: bool
|
no_global_thread_pool: bool,
|
||||||
|
no_env_eps: bool
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Clone for SessionBuilder {
|
impl Clone for SessionBuilder {
|
||||||
@@ -57,7 +58,8 @@ impl Clone for SessionBuilder {
|
|||||||
external_initializer_buffers: self.external_initializer_buffers.clone(),
|
external_initializer_buffers: self.external_initializer_buffers.clone(),
|
||||||
prepacked_weights: self.prepacked_weights.clone(),
|
prepacked_weights: self.prepacked_weights.clone(),
|
||||||
thread_manager: self.thread_manager.clone(),
|
thread_manager: self.thread_manager.clone(),
|
||||||
no_global_thread_pool: self.no_global_thread_pool
|
no_global_thread_pool: self.no_global_thread_pool,
|
||||||
|
no_env_eps: self.no_env_eps
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,7 +95,8 @@ impl SessionBuilder {
|
|||||||
external_initializer_buffers: Vec::new(),
|
external_initializer_buffers: Vec::new(),
|
||||||
prepacked_weights: None,
|
prepacked_weights: None,
|
||||||
thread_manager: None,
|
thread_manager: None,
|
||||||
no_global_thread_pool: false
|
no_global_thread_pool: false,
|
||||||
|
no_env_eps: false
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
96
src/util.rs
96
src/util.rs
@@ -298,6 +298,102 @@ impl<T> Drop for OnceLock<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub use self::mutex::{Mutex, MutexGuard};
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
mod mutex {
|
||||||
|
use std::sync::Mutex as StdMutex;
|
||||||
|
pub use std::sync::MutexGuard;
|
||||||
|
|
||||||
|
#[repr(transparent)]
|
||||||
|
pub struct Mutex<T>(StdMutex<T>);
|
||||||
|
|
||||||
|
impl<T> Mutex<T> {
|
||||||
|
pub fn new(data: T) -> Self {
|
||||||
|
Self(StdMutex::new(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn lock(&self) -> MutexGuard<'_, T> {
|
||||||
|
match self.0.lock() {
|
||||||
|
Ok(guard) => guard,
|
||||||
|
// ignore poison error
|
||||||
|
Err(p) => p.into_inner()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "std"))]
|
||||||
|
mod mutex {
|
||||||
|
use core::{
|
||||||
|
cell::UnsafeCell,
|
||||||
|
ops::{Deref, DerefMut},
|
||||||
|
sync::atomic::{AtomicBool, Ordering}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct Mutex<T> {
|
||||||
|
is_locked: AtomicBool,
|
||||||
|
data: UnsafeCell<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl<T: Send> Send for Mutex<T> {}
|
||||||
|
unsafe impl<T: Send> Sync for Mutex<T> {}
|
||||||
|
|
||||||
|
impl<T> Mutex<T> {
|
||||||
|
pub fn new(data: T) -> Self {
|
||||||
|
Mutex {
|
||||||
|
is_locked: AtomicBool::new(false),
|
||||||
|
data: UnsafeCell::new(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn lock(&self) -> MutexGuard<'_, T> {
|
||||||
|
loop {
|
||||||
|
if self
|
||||||
|
.is_locked
|
||||||
|
.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
|
||||||
|
.is_ok()
|
||||||
|
{
|
||||||
|
return MutexGuard {
|
||||||
|
is_locked: &self.is_locked,
|
||||||
|
data: unsafe { &mut *self.data.get() }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
while self.is_locked.load(Ordering::Relaxed) {
|
||||||
|
core::hint::spin_loop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MutexGuard<'a, T: 'a> {
|
||||||
|
is_locked: &'a AtomicBool,
|
||||||
|
data: *mut T
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl<T: Send> Send for MutexGuard<'_, T> {}
|
||||||
|
unsafe impl<T: Sync> Sync for MutexGuard<'_, T> {}
|
||||||
|
|
||||||
|
impl<T> Deref for MutexGuard<'_, T> {
|
||||||
|
type Target = T;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
unsafe { &*self.data }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> DerefMut for MutexGuard<'_, T> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
unsafe { &mut *self.data }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Drop for MutexGuard<'_, T> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.is_locked.store(false, Ordering::Release);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cold]
|
#[cold]
|
||||||
#[inline]
|
#[inline]
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
|
|||||||
@@ -246,6 +246,13 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
|||||||
_ => unreachable!()
|
_ => unreachable!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn data_type(&self) -> &TensorElementType {
|
||||||
|
match self.dtype() {
|
||||||
|
ValueType::Tensor { ty, .. } => ty,
|
||||||
|
_ => unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_tensor<'t>(
|
fn extract_tensor<'t>(
|
||||||
|
|||||||
BIN
src/value/impl_tensor/identity.ort
Normal file
BIN
src/value/impl_tensor/identity.ort
Normal file
Binary file not shown.
@@ -1,7 +1,7 @@
|
|||||||
mod create;
|
mod create;
|
||||||
mod extract;
|
mod extract;
|
||||||
|
|
||||||
use alloc::sync::Arc;
|
use alloc::{format, string::ToString, sync::Arc};
|
||||||
use core::{
|
use core::{
|
||||||
fmt::{self, Debug},
|
fmt::{self, Debug},
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
@@ -14,7 +14,7 @@ use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefM
|
|||||||
use crate::{
|
use crate::{
|
||||||
AsPointer,
|
AsPointer,
|
||||||
error::Result,
|
error::Result,
|
||||||
memory::{Allocator, MemoryInfo},
|
memory::{AllocationDevice, Allocator, MemoryInfo},
|
||||||
ortsys,
|
ortsys,
|
||||||
tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType}
|
tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType}
|
||||||
};
|
};
|
||||||
@@ -216,6 +216,141 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
|||||||
pub fn memory_info(&self) -> &MemoryInfo {
|
pub fn memory_info(&self) -> &MemoryInfo {
|
||||||
unsafe { self.inner.memory_info.as_ref().unwrap_unchecked() }
|
unsafe { self.inner.memory_info.as_ref().unwrap_unchecked() }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Copies the contents of this tensor to another device, returning the newly created tensor value.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use ort::{memory::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType}, session::Session, value::Tensor};
|
||||||
|
/// # fn main() -> ort::Result<()> {
|
||||||
|
/// # if false {
|
||||||
|
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||||
|
/// let cuda_allocator = Allocator::new(
|
||||||
|
/// &session,
|
||||||
|
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?
|
||||||
|
/// )?;
|
||||||
|
/// let cuda_tensor = Tensor::<f32>::new(&cuda_allocator, [1_usize, 3, 224, 224])?;
|
||||||
|
/// # }
|
||||||
|
/// # let cuda_tensor = Tensor::<f32>::new(&Allocator::default(), [1_usize, 3, 224, 224])?;
|
||||||
|
///
|
||||||
|
/// let cpu_tensor = cuda_tensor.to(AllocationDevice::CPU, 0)?;
|
||||||
|
/// assert_eq!(cpu_tensor.memory_info().allocation_device(), AllocationDevice::CPU);
|
||||||
|
/// assert_eq!(**cpu_tensor.shape(), [1, 3, 224, 224]);
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
pub fn to(&self, device: AllocationDevice, device_id: i32) -> Result<Value<Type>> {
|
||||||
|
use crate::{
|
||||||
|
OnceLock, execution_providers as ep,
|
||||||
|
io_binding::IoBinding,
|
||||||
|
memory::{AllocatorType, MemoryType},
|
||||||
|
session::{Session, builder::GraphOptimizationLevel},
|
||||||
|
util::{MiniMap, Mutex}
|
||||||
|
};
|
||||||
|
|
||||||
|
type IdentitySessionKey = (AllocationDevice, i32, ort_sys::ONNXTensorElementDataType);
|
||||||
|
type IdentitySession = (Session, IoBinding);
|
||||||
|
|
||||||
|
static SESSIONS: OnceLock<Mutex<MiniMap<IdentitySessionKey, IdentitySession>>> = OnceLock::new();
|
||||||
|
static IDENTITY_MODEL: &[u8] = include_bytes!("./identity.ort");
|
||||||
|
|
||||||
|
let target_memory_info = MemoryInfo::new(device, device_id, AllocatorType::Device, MemoryType::Default)?;
|
||||||
|
let tensor_type = ort_sys::ONNXTensorElementDataType::from(*self.data_type());
|
||||||
|
|
||||||
|
let mut sessions = SESSIONS.get_or_init(|| Mutex::new(MiniMap::new())).lock();
|
||||||
|
let (session, binding) = match sessions.get_mut(&(device, device_id, tensor_type)) {
|
||||||
|
Some(entry) => entry,
|
||||||
|
None => {
|
||||||
|
let mut model_bytes = IDENTITY_MODEL.to_vec();
|
||||||
|
// Override the expected element type of the input & output nodes, respectively.
|
||||||
|
model_bytes[544] = tensor_type as u8;
|
||||||
|
model_bytes[604] = tensor_type as u8;
|
||||||
|
|
||||||
|
let session = Session::builder()?
|
||||||
|
.with_optimization_level(GraphOptimizationLevel::Disable)?
|
||||||
|
.with_intra_threads(1)?
|
||||||
|
.with_inter_threads(1)?
|
||||||
|
.with_inter_op_spinning(false)?
|
||||||
|
.with_intra_op_spinning(false)?
|
||||||
|
.with_memory_pattern(false)?
|
||||||
|
.with_no_environment_execution_providers()?
|
||||||
|
.with_execution_providers([match device {
|
||||||
|
AllocationDevice::CPU => ep::CPUExecutionProvider::default().with_arena_allocator(false).build(),
|
||||||
|
AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDAExecutionProvider::default()
|
||||||
|
.with_device_id(device_id)
|
||||||
|
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
|
||||||
|
.with_conv_max_workspace(false)
|
||||||
|
.with_conv_algorithm_search(ep::cuda::CuDNNConvAlgorithmSearch::Default)
|
||||||
|
.build(),
|
||||||
|
AllocationDevice::DIRECTML | AllocationDevice::DIRECTML_CPU => {
|
||||||
|
ep::DirectMLExecutionProvider::default().with_device_id(device_id).build()
|
||||||
|
}
|
||||||
|
AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANNExecutionProvider::default()
|
||||||
|
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
|
||||||
|
.with_cann_graph(false)
|
||||||
|
.with_device_id(device_id)
|
||||||
|
.build(),
|
||||||
|
AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => ep::OpenVINOExecutionProvider::default()
|
||||||
|
.with_num_threads(1)
|
||||||
|
.with_device_type(if device == AllocationDevice::OPENVINO_CPU {
|
||||||
|
"CPU".to_string()
|
||||||
|
} else {
|
||||||
|
format!("GPU.{device_id}")
|
||||||
|
})
|
||||||
|
.build(),
|
||||||
|
AllocationDevice::HIP | AllocationDevice::HIP_PINNED => ep::ROCmExecutionProvider::default()
|
||||||
|
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
|
||||||
|
.with_hip_graph(false)
|
||||||
|
.with_exhaustive_conv_search(false)
|
||||||
|
.with_device_id(device_id)
|
||||||
|
.build(),
|
||||||
|
AllocationDevice::TVM => ep::TVMExecutionProvider::default().build(),
|
||||||
|
AllocationDevice::XNNPACK => ep::XNNPACKExecutionProvider::default().build(),
|
||||||
|
_ => return Err(crate::Error::new("Unsupported allocation device {device} for tensor copy target"))
|
||||||
|
}])?
|
||||||
|
.with_allocator(MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?)?
|
||||||
|
.commit_from_memory(&model_bytes)?;
|
||||||
|
let binding = session.create_binding()?;
|
||||||
|
sessions.insert((device, device_id, tensor_type), (session, binding));
|
||||||
|
sessions.get_mut(&(device, device_id, tensor_type)).expect("insert should have worked")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
binding.bind_input("input", self)?;
|
||||||
|
binding.bind_output_to_device("output", &target_memory_info)?;
|
||||||
|
|
||||||
|
let output = session
|
||||||
|
.run_binding(binding)?
|
||||||
|
.remove("output")
|
||||||
|
.expect("identity model should have single output");
|
||||||
|
Ok(unsafe { output.transmute_type() })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Type: TensorValueTypeMarker + ?Sized> Clone for Value<Type> {
|
||||||
|
/// Creates a copy of this tensor and its data on the same device it resides on.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use ort::{value::Tensor, AsPointer};
|
||||||
|
/// # fn main() -> ort::Result<()> {
|
||||||
|
/// let array = vec![1_i64, 2, 3, 4, 5];
|
||||||
|
/// let tensor = Tensor::from_array(([array.len()], array.into_boxed_slice()))?;
|
||||||
|
///
|
||||||
|
/// let new_tensor = tensor.clone();
|
||||||
|
///
|
||||||
|
/// // same data
|
||||||
|
/// assert_eq!(tensor.extract_tensor(), new_tensor.extract_tensor());
|
||||||
|
///
|
||||||
|
/// // different allocations
|
||||||
|
/// assert_ne!(tensor.ptr(), new_tensor.ptr());
|
||||||
|
/// assert_ne!(tensor.data_ptr()?, new_tensor.data_ptr()?);
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
let memory_info = self.memory_info();
|
||||||
|
self.to(memory_info.allocation_device(), memory_info.device_id())
|
||||||
|
.expect("Failed to clone tensor")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||||
|
|||||||
Reference in New Issue
Block a user