feat: copy tensors (#382)

Adds `Tensor::to` to allow copying tensors to another device, and impls
`Clone` for tensors.
This commit is contained in:
decahedron1
2025-04-13 15:16:37 -05:00
committed by GitHub
parent af2a03b9f3
commit 8c1c9baacb
10 changed files with 261 additions and 10 deletions

1
.gitignore vendored
View File

@@ -6,6 +6,7 @@ Cargo.lock
**/*.ort
tools/train-data/**/checkpoint
!tests/data/**/*
!src/value/impl_tensor/identity.ort
.venv*
docs/.next

View File

@@ -121,7 +121,7 @@ impl IoBinding {
///
/// The data is only copied upon invocation of this function. Any changes to the given value afterwards will not
/// affect the data seen by the session until the value is re-bound.
pub fn bind_input<T: ValueTypeMarker, S: Into<String>>(&mut self, name: S, ort_value: &Value<T>) -> Result<()> {
pub fn bind_input<T: ValueTypeMarker + ?Sized, S: Into<String>>(&mut self, name: S, ort_value: &Value<T>) -> Result<()> {
let name: String = name.into();
let ptr = self.ptr_mut();
with_cstr(name.as_bytes(), &|name| {
@@ -142,7 +142,7 @@ impl IoBinding {
/// `name`.
///
/// [`Tensor::new`]: crate::value::Tensor::new
pub fn bind_output<T: ValueTypeMarker, S: Into<String>>(&mut self, name: S, ort_value: Value<T>) -> Result<()> {
pub fn bind_output<T: ValueTypeMarker + ?Sized, S: Into<String>>(&mut self, name: S, ort_value: Value<T>) -> Result<()> {
let name: String = name.into();
let ptr = self.ptr_mut();
with_cstr(name.as_bytes(), &|name| {

View File

@@ -244,7 +244,7 @@ impl Drop for AllocatedBlock<'_> {
}
/// Represents possible devices that have their own device allocator.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
// &'static str should be valid here since they're only ever defined in C++ with `const char *` literals
pub struct AllocationDevice(&'static str);

View File

@@ -112,7 +112,9 @@ impl SessionBuilder {
let model_path = crate::util::path_to_os_char(model_filepath);
let env = get_environment()?;
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
if !self.no_env_eps {
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
}
if env.has_global_threadpool && !self.no_global_thread_pool {
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
@@ -185,7 +187,9 @@ impl SessionBuilder {
let mut session_ptr: *mut ort_sys::OrtSession = ptr::null_mut();
let env = get_environment()?;
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
if !self.no_env_eps {
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
}
if env.has_global_threadpool && !self.no_global_thread_pool {
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];

View File

@@ -215,6 +215,11 @@ impl SessionBuilder {
Ok(self)
}
/// Prevents execution providers registered on the global environment from being applied to this session.
///
/// By default, committing a session applies the environment's configured execution providers; setting this
/// flag makes the commit step skip that registration, so only execution providers configured directly on
/// this builder are used.
pub fn with_no_environment_execution_providers(mut self) -> Result<Self> {
	self.no_env_eps = true;
	Ok(self)
}
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
let manager = Rc::new(manager);
ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?];

View File

@@ -38,7 +38,8 @@ pub struct SessionBuilder {
external_initializer_buffers: Vec<Cow<'static, [u8]>>,
prepacked_weights: Option<PrepackedWeights>,
thread_manager: Option<Rc<dyn Any>>,
no_global_thread_pool: bool
no_global_thread_pool: bool,
no_env_eps: bool
}
impl Clone for SessionBuilder {
@@ -57,7 +58,8 @@ impl Clone for SessionBuilder {
external_initializer_buffers: self.external_initializer_buffers.clone(),
prepacked_weights: self.prepacked_weights.clone(),
thread_manager: self.thread_manager.clone(),
no_global_thread_pool: self.no_global_thread_pool
no_global_thread_pool: self.no_global_thread_pool,
no_env_eps: self.no_env_eps
}
}
}
@@ -93,7 +95,8 @@ impl SessionBuilder {
external_initializer_buffers: Vec::new(),
prepacked_weights: None,
thread_manager: None,
no_global_thread_pool: false
no_global_thread_pool: false,
no_env_eps: false
})
}

View File

@@ -298,6 +298,102 @@ impl<T> Drop for OnceLock<T> {
}
}
pub use self::mutex::{Mutex, MutexGuard};
#[cfg(feature = "std")]
mod mutex {
	pub use std::sync::MutexGuard;
	use std::sync::Mutex as StdMutex;

	/// Thin wrapper over [`std::sync::Mutex`] that treats lock poisoning as recoverable.
	#[repr(transparent)]
	pub struct Mutex<T>(StdMutex<T>);

	impl<T> Mutex<T> {
		/// Wraps `data` in a new, unlocked mutex.
		pub fn new(data: T) -> Self {
			Self(StdMutex::new(data))
		}

		/// Acquires the lock, blocking the current thread until it is available.
		///
		/// Unlike [`std::sync::Mutex::lock`], a poisoned lock (a previous holder
		/// panicked) is not treated as an error — the guard is recovered and
		/// returned as if the lock were healthy.
		pub fn lock(&self) -> MutexGuard<'_, T> {
			self.0.lock().unwrap_or_else(|poison| poison.into_inner())
		}
	}
}
#[cfg(not(feature = "std"))]
mod mutex {
	use core::{
		cell::UnsafeCell,
		ops::{Deref, DerefMut},
		sync::atomic::{AtomicBool, Ordering}
	};

	/// A minimal spinlock-based mutex for `no_std` targets.
	///
	/// Acquisition busy-waits (there is no OS to block on), so this is only suitable
	/// for short critical sections. Not reentrant: re-locking from the thread that
	/// already holds the lock spins forever.
	pub struct Mutex<T> {
		// `true` while a `MutexGuard` for this mutex is alive.
		is_locked: AtomicBool,
		data: UnsafeCell<T>
	}

	// SAFETY: the lock serializes all access to `data`, so sharing the mutex across
	// threads only requires the payload itself to be `Send`.
	unsafe impl<T: Send> Send for Mutex<T> {}
	unsafe impl<T: Send> Sync for Mutex<T> {}

	impl<T> Mutex<T> {
		/// Creates a new, unlocked mutex wrapping `data`.
		pub fn new(data: T) -> Self {
			Mutex {
				is_locked: AtomicBool::new(false),
				data: UnsafeCell::new(data)
			}
		}

		/// Acquires the lock, spinning until it becomes available.
		pub fn lock(&self) -> MutexGuard<'_, T> {
			loop {
				// Try to flip `false -> true`. `Acquire` on success orders this
				// acquisition after the previous holder's `Release` store in `Drop`,
				// so we observe all writes made under the lock. `_weak` may fail
				// spuriously, which is fine inside the retry loop.
				if self
					.is_locked
					.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
					.is_ok()
				{
					return MutexGuard {
						is_locked: &self.is_locked,
						// SAFETY: we just acquired the lock, so we have exclusive
						// access to `data` until the guard is dropped.
						data: unsafe { &mut *self.data.get() }
					};
				}

				// Spin on a plain (relaxed) load rather than hammering the CAS —
				// cheaper on contended cache lines — then retry the CAS above.
				while self.is_locked.load(Ordering::Relaxed) {
					core::hint::spin_loop();
				}
			}
		}
	}

	/// RAII guard granting access to the locked data; releases the lock on drop.
	pub struct MutexGuard<'a, T: 'a> {
		is_locked: &'a AtomicBool,
		data: *mut T
	}

	// SAFETY: the guard is the unique accessor of `data` for as long as it lives;
	// unlocking from another thread is sound here since release is a plain store.
	unsafe impl<T: Send> Send for MutexGuard<'_, T> {}
	unsafe impl<T: Sync> Sync for MutexGuard<'_, T> {}

	impl<T> Deref for MutexGuard<'_, T> {
		type Target = T;
		fn deref(&self) -> &Self::Target {
			// SAFETY: `data` is valid and exclusively held for the guard's lifetime.
			unsafe { &*self.data }
		}
	}

	impl<T> DerefMut for MutexGuard<'_, T> {
		fn deref_mut(&mut self) -> &mut Self::Target {
			// SAFETY: `data` is valid and exclusively held for the guard's lifetime.
			unsafe { &mut *self.data }
		}
	}

	impl<T> Drop for MutexGuard<'_, T> {
		fn drop(&mut self) {
			// `Release` publishes all writes made under the lock to the next
			// acquirer (paired with the `Acquire` CAS in `lock`).
			self.is_locked.store(false, Ordering::Release);
		}
	}
}
#[cold]
#[inline]
#[doc(hidden)]

View File

@@ -246,6 +246,13 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
_ => unreachable!()
}
}
/// Returns a reference to the element type of this tensor.
pub fn data_type(&self) -> &TensorElementType {
	// A `Value` bounded by `TensorValueTypeMarker` is always a tensor, so any
	// other `ValueType` variant cannot occur here.
	let ValueType::Tensor { ty, .. } = self.dtype() else { unreachable!() };
	ty
}
}
fn extract_tensor<'t>(

Binary file not shown.

View File

@@ -1,7 +1,7 @@
mod create;
mod extract;
use alloc::sync::Arc;
use alloc::{format, string::ToString, sync::Arc};
use core::{
fmt::{self, Debug},
marker::PhantomData,
@@ -14,7 +14,7 @@ use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefM
use crate::{
AsPointer,
error::Result,
memory::{Allocator, MemoryInfo},
memory::{AllocationDevice, Allocator, MemoryInfo},
ortsys,
tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType}
};
@@ -216,6 +216,141 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
pub fn memory_info(&self) -> &MemoryInfo {
unsafe { self.inner.memory_info.as_ref().unwrap_unchecked() }
}
/// Copies the contents of this tensor to another device, returning the newly created tensor value.
///
/// ```
/// # use ort::{memory::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType}, session::Session, value::Tensor};
/// # fn main() -> ort::Result<()> {
/// # if false {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let cuda_allocator = Allocator::new(
/// &session,
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?
/// )?;
/// let cuda_tensor = Tensor::<f32>::new(&cuda_allocator, [1_usize, 3, 224, 224])?;
/// # }
/// # let cuda_tensor = Tensor::<f32>::new(&Allocator::default(), [1_usize, 3, 224, 224])?;
///
/// let cpu_tensor = cuda_tensor.to(AllocationDevice::CPU, 0)?;
/// assert_eq!(cpu_tensor.memory_info().allocation_device(), AllocationDevice::CPU);
/// assert_eq!(**cpu_tensor.shape(), [1, 3, 224, 224]);
/// # Ok(())
/// # }
/// ```
pub fn to(&self, device: AllocationDevice, device_id: i32) -> Result<Value<Type>> {
use crate::{
OnceLock, execution_providers as ep,
io_binding::IoBinding,
memory::{AllocatorType, MemoryType},
session::{Session, builder::GraphOptimizationLevel},
util::{MiniMap, Mutex}
};
type IdentitySessionKey = (AllocationDevice, i32, ort_sys::ONNXTensorElementDataType);
type IdentitySession = (Session, IoBinding);
static SESSIONS: OnceLock<Mutex<MiniMap<IdentitySessionKey, IdentitySession>>> = OnceLock::new();
static IDENTITY_MODEL: &[u8] = include_bytes!("./identity.ort");
let target_memory_info = MemoryInfo::new(device, device_id, AllocatorType::Device, MemoryType::Default)?;
let tensor_type = ort_sys::ONNXTensorElementDataType::from(*self.data_type());
let mut sessions = SESSIONS.get_or_init(|| Mutex::new(MiniMap::new())).lock();
let (session, binding) = match sessions.get_mut(&(device, device_id, tensor_type)) {
Some(entry) => entry,
None => {
let mut model_bytes = IDENTITY_MODEL.to_vec();
// Override the expected element type of the input & output nodes, respectively.
model_bytes[544] = tensor_type as u8;
model_bytes[604] = tensor_type as u8;
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Disable)?
.with_intra_threads(1)?
.with_inter_threads(1)?
.with_inter_op_spinning(false)?
.with_intra_op_spinning(false)?
.with_memory_pattern(false)?
.with_no_environment_execution_providers()?
.with_execution_providers([match device {
AllocationDevice::CPU => ep::CPUExecutionProvider::default().with_arena_allocator(false).build(),
AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDAExecutionProvider::default()
.with_device_id(device_id)
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
.with_conv_max_workspace(false)
.with_conv_algorithm_search(ep::cuda::CuDNNConvAlgorithmSearch::Default)
.build(),
AllocationDevice::DIRECTML | AllocationDevice::DIRECTML_CPU => {
ep::DirectMLExecutionProvider::default().with_device_id(device_id).build()
}
AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANNExecutionProvider::default()
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
.with_cann_graph(false)
.with_device_id(device_id)
.build(),
AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => ep::OpenVINOExecutionProvider::default()
.with_num_threads(1)
.with_device_type(if device == AllocationDevice::OPENVINO_CPU {
"CPU".to_string()
} else {
format!("GPU.{device_id}")
})
.build(),
AllocationDevice::HIP | AllocationDevice::HIP_PINNED => ep::ROCmExecutionProvider::default()
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
.with_hip_graph(false)
.with_exhaustive_conv_search(false)
.with_device_id(device_id)
.build(),
AllocationDevice::TVM => ep::TVMExecutionProvider::default().build(),
AllocationDevice::XNNPACK => ep::XNNPACKExecutionProvider::default().build(),
_ => return Err(crate::Error::new("Unsupported allocation device {device} for tensor copy target"))
}])?
.with_allocator(MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?)?
.commit_from_memory(&model_bytes)?;
let binding = session.create_binding()?;
sessions.insert((device, device_id, tensor_type), (session, binding));
sessions.get_mut(&(device, device_id, tensor_type)).expect("insert should have worked")
}
};
binding.bind_input("input", self)?;
binding.bind_output_to_device("output", &target_memory_info)?;
let output = session
.run_binding(binding)?
.remove("output")
.expect("identity model should have single output");
Ok(unsafe { output.transmute_type() })
}
}
impl<Type: TensorValueTypeMarker + ?Sized> Clone for Value<Type> {
	/// Creates a copy of this tensor and its data on the same device it resides on.
	///
	/// ```
	/// # use ort::{value::Tensor, AsPointer};
	/// # fn main() -> ort::Result<()> {
	/// let array = vec![1_i64, 2, 3, 4, 5];
	/// let tensor = Tensor::from_array(([array.len()], array.into_boxed_slice()))?;
	///
	/// let new_tensor = tensor.clone();
	///
	/// // same data
	/// assert_eq!(tensor.extract_tensor(), new_tensor.extract_tensor());
	///
	/// // different allocations
	/// assert_ne!(tensor.ptr(), new_tensor.ptr());
	/// assert_ne!(tensor.data_ptr()?, new_tensor.data_ptr()?);
	/// # Ok(())
	/// # }
	/// ```
	fn clone(&self) -> Self {
		// Copy onto the very device the tensor currently lives on; `Clone` cannot
		// surface errors, so a failed device copy panics.
		let info = self.memory_info();
		let (device, device_id) = (info.allocation_device(), info.device_id());
		self.to(device, device_id).expect("Failed to clone tensor")
	}
}
impl<T: IntoTensorElementType + Debug> Tensor<T> {