diff --git a/.gitignore b/.gitignore index f9ace1d..f584c1a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ Cargo.lock **/*.ort tools/train-data/**/checkpoint !tests/data/**/* +!src/value/impl_tensor/identity.ort .venv* docs/.next diff --git a/src/io_binding.rs b/src/io_binding.rs index 42ddb51..e442cda 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -121,7 +121,7 @@ impl IoBinding { /// /// The data is only copied upon invocation of this function. Any changes to the given value afterwards will not /// affect the data seen by the session until the value is re-bound. - pub fn bind_input>(&mut self, name: S, ort_value: &Value) -> Result<()> { + pub fn bind_input>(&mut self, name: S, ort_value: &Value) -> Result<()> { let name: String = name.into(); let ptr = self.ptr_mut(); with_cstr(name.as_bytes(), &|name| { @@ -142,7 +142,7 @@ impl IoBinding { /// `name`. /// /// [`Tensor::new`]: crate::value::Tensor::new - pub fn bind_output>(&mut self, name: S, ort_value: Value) -> Result<()> { + pub fn bind_output>(&mut self, name: S, ort_value: Value) -> Result<()> { let name: String = name.into(); let ptr = self.ptr_mut(); with_cstr(name.as_bytes(), &|name| { diff --git a/src/memory.rs b/src/memory.rs index e638eba..bdec220 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -244,7 +244,7 @@ impl Drop for AllocatedBlock<'_> { } /// Represents possible devices that have their own device allocator. 
-#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] // &'static str should be valid here since they're only ever defined in C++ with `const char *` literals pub struct AllocationDevice(&'static str); diff --git a/src/session/builder/impl_commit.rs b/src/session/builder/impl_commit.rs index d24341e..6ba328e 100644 --- a/src/session/builder/impl_commit.rs +++ b/src/session/builder/impl_commit.rs @@ -112,7 +112,9 @@ impl SessionBuilder { let model_path = crate::util::path_to_os_char(model_filepath); let env = get_environment()?; - apply_execution_providers(&mut self, &env.execution_providers, "environment")?; + if !self.no_env_eps { + apply_execution_providers(&mut self, &env.execution_providers, "environment")?; + } if env.has_global_threadpool && !self.no_global_thread_pool { ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?]; @@ -185,7 +187,9 @@ impl SessionBuilder { let mut session_ptr: *mut ort_sys::OrtSession = ptr::null_mut(); let env = get_environment()?; - apply_execution_providers(&mut self, &env.execution_providers, "environment")?; + if !self.no_env_eps { + apply_execution_providers(&mut self, &env.execution_providers, "environment")?; + } if env.has_global_threadpool && !self.no_global_thread_pool { ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?]; diff --git a/src/session/builder/impl_options.rs b/src/session/builder/impl_options.rs index 00d7c18..7107913 100644 --- a/src/session/builder/impl_options.rs +++ b/src/session/builder/impl_options.rs @@ -215,6 +215,11 @@ impl SessionBuilder { Ok(self) } + pub fn with_no_environment_execution_providers(mut self) -> Result { + self.no_env_eps = true; + Ok(self) + } + pub fn with_thread_manager(mut self, manager: T) -> Result { let manager = Rc::new(manager); ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?]; diff --git a/src/session/builder/mod.rs b/src/session/builder/mod.rs index 
199ce6e..5fa5944 100644 --- a/src/session/builder/mod.rs +++ b/src/session/builder/mod.rs @@ -38,7 +38,8 @@ pub struct SessionBuilder { external_initializer_buffers: Vec>, prepacked_weights: Option, thread_manager: Option>, - no_global_thread_pool: bool + no_global_thread_pool: bool, + no_env_eps: bool } impl Clone for SessionBuilder { @@ -57,7 +58,8 @@ impl Clone for SessionBuilder { external_initializer_buffers: self.external_initializer_buffers.clone(), prepacked_weights: self.prepacked_weights.clone(), thread_manager: self.thread_manager.clone(), - no_global_thread_pool: self.no_global_thread_pool + no_global_thread_pool: self.no_global_thread_pool, + no_env_eps: self.no_env_eps } } } @@ -93,7 +95,8 @@ impl SessionBuilder { external_initializer_buffers: Vec::new(), prepacked_weights: None, thread_manager: None, - no_global_thread_pool: false + no_global_thread_pool: false, + no_env_eps: false }) } diff --git a/src/util.rs b/src/util.rs index bde35fa..0d5fc2b 100644 --- a/src/util.rs +++ b/src/util.rs @@ -298,6 +298,102 @@ impl Drop for OnceLock { } } +pub use self::mutex::{Mutex, MutexGuard}; +#[cfg(feature = "std")] +mod mutex { + use std::sync::Mutex as StdMutex; + pub use std::sync::MutexGuard; + + #[repr(transparent)] + pub struct Mutex(StdMutex); + + impl Mutex { + pub fn new(data: T) -> Self { + Self(StdMutex::new(data)) + } + + pub fn lock(&self) -> MutexGuard<'_, T> { + match self.0.lock() { + Ok(guard) => guard, + // ignore poison error + Err(p) => p.into_inner() + } + } + } +} +#[cfg(not(feature = "std"))] +mod mutex { + use core::{ + cell::UnsafeCell, + ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, Ordering} + }; + + pub struct Mutex { + is_locked: AtomicBool, + data: UnsafeCell + } + + unsafe impl Send for Mutex {} + unsafe impl Sync for Mutex {} + + impl Mutex { + pub fn new(data: T) -> Self { + Mutex { + is_locked: AtomicBool::new(false), + data: UnsafeCell::new(data) + } + } + + pub fn lock(&self) -> MutexGuard<'_, T> { + loop { + if self 
+ .is_locked + .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + return MutexGuard { + is_locked: &self.is_locked, + data: unsafe { &mut *self.data.get() } + }; + } + + while self.is_locked.load(Ordering::Relaxed) { + core::hint::spin_loop(); + } + } + } + } + + pub struct MutexGuard<'a, T: 'a> { + is_locked: &'a AtomicBool, + data: *mut T + } + + unsafe impl Send for MutexGuard<'_, T> {} + unsafe impl Sync for MutexGuard<'_, T> {} + + impl Deref for MutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.data } + } + } + + impl DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.data } + } + } + + impl Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + self.is_locked.store(false, Ordering::Release); + } + } +} + #[cold] #[inline] #[doc(hidden)] diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index d275b4d..7af3870 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -246,6 +246,13 @@ impl Value { _ => unreachable!() } } + + pub fn data_type(&self) -> &TensorElementType { + match self.dtype() { + ValueType::Tensor { ty, .. 
} => ty, + _ => unreachable!() + } + } } fn extract_tensor<'t>( diff --git a/src/value/impl_tensor/identity.ort b/src/value/impl_tensor/identity.ort new file mode 100644 index 0000000..ebd5bca Binary files /dev/null and b/src/value/impl_tensor/identity.ort differ diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index 795bf2d..d6397f5 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -1,7 +1,7 @@ mod create; mod extract; -use alloc::sync::Arc; +use alloc::{format, string::ToString, sync::Arc}; use core::{ fmt::{self, Debug}, marker::PhantomData, @@ -14,7 +14,7 @@ use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefM use crate::{ AsPointer, error::Result, - memory::{Allocator, MemoryInfo}, + memory::{AllocationDevice, Allocator, MemoryInfo}, ortsys, tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType} }; @@ -216,6 +216,141 @@ impl Value { pub fn memory_info(&self) -> &MemoryInfo { unsafe { self.inner.memory_info.as_ref().unwrap_unchecked() } } + + /// Copies the contents of this tensor to another device, returning the newly created tensor value. + /// + /// ``` + /// # use ort::{memory::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType}, session::Session, value::Tensor}; + /// # fn main() -> ort::Result<()> { + /// # if false { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let cuda_allocator = Allocator::new( + /// &session, + /// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? 
+ /// )?; + /// let cuda_tensor = Tensor::::new(&cuda_allocator, [1_usize, 3, 224, 224])?; + /// # } + /// # let cuda_tensor = Tensor::::new(&Allocator::default(), [1_usize, 3, 224, 224])?; + /// + /// let cpu_tensor = cuda_tensor.to(AllocationDevice::CPU, 0)?; + /// assert_eq!(cpu_tensor.memory_info().allocation_device(), AllocationDevice::CPU); + /// assert_eq!(**cpu_tensor.shape(), [1, 3, 224, 224]); + /// # Ok(()) + /// # } + /// ``` + pub fn to(&self, device: AllocationDevice, device_id: i32) -> Result> { + use crate::{ + OnceLock, execution_providers as ep, + io_binding::IoBinding, + memory::{AllocatorType, MemoryType}, + session::{Session, builder::GraphOptimizationLevel}, + util::{MiniMap, Mutex} + }; + + type IdentitySessionKey = (AllocationDevice, i32, ort_sys::ONNXTensorElementDataType); + type IdentitySession = (Session, IoBinding); + + static SESSIONS: OnceLock>> = OnceLock::new(); + static IDENTITY_MODEL: &[u8] = include_bytes!("./identity.ort"); + + let target_memory_info = MemoryInfo::new(device, device_id, AllocatorType::Device, MemoryType::Default)?; + let tensor_type = ort_sys::ONNXTensorElementDataType::from(*self.data_type()); + + let mut sessions = SESSIONS.get_or_init(|| Mutex::new(MiniMap::new())).lock(); + let (session, binding) = match sessions.get_mut(&(device, device_id, tensor_type)) { + Some(entry) => entry, + None => { + let mut model_bytes = IDENTITY_MODEL.to_vec(); + // Override the expected element type of the input & output nodes, respectively. + model_bytes[544] = tensor_type as u8; + model_bytes[604] = tensor_type as u8; + + let session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Disable)? + .with_intra_threads(1)? + .with_inter_threads(1)? + .with_inter_op_spinning(false)? + .with_intra_op_spinning(false)? + .with_memory_pattern(false)? + .with_no_environment_execution_providers()? 
+ .with_execution_providers([match device { + AllocationDevice::CPU => ep::CPUExecutionProvider::default().with_arena_allocator(false).build(), + AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDAExecutionProvider::default() + .with_device_id(device_id) + .with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested) + .with_conv_max_workspace(false) + .with_conv_algorithm_search(ep::cuda::CuDNNConvAlgorithmSearch::Default) + .build(), + AllocationDevice::DIRECTML | AllocationDevice::DIRECTML_CPU => { + ep::DirectMLExecutionProvider::default().with_device_id(device_id).build() + } + AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANNExecutionProvider::default() + .with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested) + .with_cann_graph(false) + .with_device_id(device_id) + .build(), + AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => ep::OpenVINOExecutionProvider::default() + .with_num_threads(1) + .with_device_type(if device == AllocationDevice::OPENVINO_CPU { + "CPU".to_string() + } else { + format!("GPU.{device_id}") + }) + .build(), + AllocationDevice::HIP | AllocationDevice::HIP_PINNED => ep::ROCmExecutionProvider::default() + .with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested) + .with_hip_graph(false) + .with_exhaustive_conv_search(false) + .with_device_id(device_id) + .build(), + AllocationDevice::TVM => ep::TVMExecutionProvider::default().build(), + AllocationDevice::XNNPACK => ep::XNNPACKExecutionProvider::default().build(), + _ => return Err(crate::Error::new(format!("Unsupported allocation device {device} for tensor copy target"))) + }])? + .with_allocator(MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?)? 
+ .commit_from_memory(&model_bytes)?; + let binding = session.create_binding()?; + sessions.insert((device, device_id, tensor_type), (session, binding)); + sessions.get_mut(&(device, device_id, tensor_type)).expect("insert should have worked") + } + }; + + binding.bind_input("input", self)?; + binding.bind_output_to_device("output", &target_memory_info)?; + + let output = session + .run_binding(binding)? + .remove("output") + .expect("identity model should have single output"); + Ok(unsafe { output.transmute_type() }) + } +} + +impl Clone for Value { + /// Creates a copy of this tensor and its data on the same device it resides on. + /// + /// ``` + /// # use ort::{value::Tensor, AsPointer}; + /// # fn main() -> ort::Result<()> { + /// let array = vec![1_i64, 2, 3, 4, 5]; + /// let tensor = Tensor::from_array(([array.len()], array.into_boxed_slice()))?; + /// + /// let new_tensor = tensor.clone(); + /// + /// // same data + /// assert_eq!(tensor.extract_tensor(), new_tensor.extract_tensor()); + /// + /// // different allocations + /// assert_ne!(tensor.ptr(), new_tensor.ptr()); + /// assert_ne!(tensor.data_ptr()?, new_tensor.data_ptr()?); + /// # Ok(()) + /// # } + /// ``` + fn clone(&self) -> Self { + let memory_info = self.memory_info(); + self.to(memory_info.allocation_device(), memory_info.device_id()) + .expect("Failed to clone tensor") + } } impl Tensor {