mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: copy tensors (#382)
Adds `Tensor::to` to allow copying tensors to another device, and impls `Clone` for tensors.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,6 +6,7 @@ Cargo.lock
|
||||
**/*.ort
|
||||
tools/train-data/**/checkpoint
|
||||
!tests/data/**/*
|
||||
!src/value/impl_tensor/identity.ort
|
||||
|
||||
.venv*
|
||||
docs/.next
|
||||
|
||||
@@ -121,7 +121,7 @@ impl IoBinding {
|
||||
///
|
||||
/// The data is only copied upon invocation of this function. Any changes to the given value afterwards will not
|
||||
/// affect the data seen by the session until the value is re-bound.
|
||||
pub fn bind_input<T: ValueTypeMarker, S: Into<String>>(&mut self, name: S, ort_value: &Value<T>) -> Result<()> {
|
||||
pub fn bind_input<T: ValueTypeMarker + ?Sized, S: Into<String>>(&mut self, name: S, ort_value: &Value<T>) -> Result<()> {
|
||||
let name: String = name.into();
|
||||
let ptr = self.ptr_mut();
|
||||
with_cstr(name.as_bytes(), &|name| {
|
||||
@@ -142,7 +142,7 @@ impl IoBinding {
|
||||
/// `name`.
|
||||
///
|
||||
/// [`Tensor::new`]: crate::value::Tensor::new
|
||||
pub fn bind_output<T: ValueTypeMarker, S: Into<String>>(&mut self, name: S, ort_value: Value<T>) -> Result<()> {
|
||||
pub fn bind_output<T: ValueTypeMarker + ?Sized, S: Into<String>>(&mut self, name: S, ort_value: Value<T>) -> Result<()> {
|
||||
let name: String = name.into();
|
||||
let ptr = self.ptr_mut();
|
||||
with_cstr(name.as_bytes(), &|name| {
|
||||
|
||||
@@ -244,7 +244,7 @@ impl Drop for AllocatedBlock<'_> {
|
||||
}
|
||||
|
||||
/// Represents possible devices that have their own device allocator.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
// &'static str should be valid here since they're only ever defined in C++ with `const char *` literals
|
||||
pub struct AllocationDevice(&'static str);
|
||||
|
||||
|
||||
@@ -112,7 +112,9 @@ impl SessionBuilder {
|
||||
let model_path = crate::util::path_to_os_char(model_filepath);
|
||||
|
||||
let env = get_environment()?;
|
||||
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
|
||||
if !self.no_env_eps {
|
||||
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
|
||||
}
|
||||
|
||||
if env.has_global_threadpool && !self.no_global_thread_pool {
|
||||
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
|
||||
@@ -185,7 +187,9 @@ impl SessionBuilder {
|
||||
let mut session_ptr: *mut ort_sys::OrtSession = ptr::null_mut();
|
||||
|
||||
let env = get_environment()?;
|
||||
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
|
||||
if !self.no_env_eps {
|
||||
apply_execution_providers(&mut self, &env.execution_providers, "environment")?;
|
||||
}
|
||||
|
||||
if env.has_global_threadpool && !self.no_global_thread_pool {
|
||||
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
|
||||
|
||||
@@ -215,6 +215,11 @@ impl SessionBuilder {
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Opts this session out of the execution providers registered on the environment.
///
/// Sets the builder's `no_env_eps` flag. When the flag is set, the commit paths skip the step that applies the
/// environment's execution providers to this builder.
///
/// This never fails; the `Result` return type mirrors the other builder methods.
pub fn with_no_environment_execution_providers(mut self) -> Result<Self> {
    self.no_env_eps = true;
    Ok(self)
}
|
||||
|
||||
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
|
||||
let manager = Rc::new(manager);
|
||||
ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?];
|
||||
|
||||
@@ -38,7 +38,8 @@ pub struct SessionBuilder {
|
||||
external_initializer_buffers: Vec<Cow<'static, [u8]>>,
|
||||
prepacked_weights: Option<PrepackedWeights>,
|
||||
thread_manager: Option<Rc<dyn Any>>,
|
||||
no_global_thread_pool: bool
|
||||
no_global_thread_pool: bool,
|
||||
no_env_eps: bool
|
||||
}
|
||||
|
||||
impl Clone for SessionBuilder {
|
||||
@@ -57,7 +58,8 @@ impl Clone for SessionBuilder {
|
||||
external_initializer_buffers: self.external_initializer_buffers.clone(),
|
||||
prepacked_weights: self.prepacked_weights.clone(),
|
||||
thread_manager: self.thread_manager.clone(),
|
||||
no_global_thread_pool: self.no_global_thread_pool
|
||||
no_global_thread_pool: self.no_global_thread_pool,
|
||||
no_env_eps: self.no_env_eps
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,7 +95,8 @@ impl SessionBuilder {
|
||||
external_initializer_buffers: Vec::new(),
|
||||
prepacked_weights: None,
|
||||
thread_manager: None,
|
||||
no_global_thread_pool: false
|
||||
no_global_thread_pool: false,
|
||||
no_env_eps: false
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
96
src/util.rs
96
src/util.rs
@@ -298,6 +298,102 @@ impl<T> Drop for OnceLock<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub use self::mutex::{Mutex, MutexGuard};
|
||||
#[cfg(feature = "std")]
mod mutex {
    use std::sync::{Mutex as StdMutex, PoisonError};
    pub use std::sync::MutexGuard;

    /// Thin wrapper over [`std::sync::Mutex`] whose `lock` never reports poisoning.
    #[repr(transparent)]
    pub struct Mutex<T>(StdMutex<T>);

    impl<T> Mutex<T> {
        /// Wraps `data` in a new, unlocked mutex.
        pub fn new(data: T) -> Self {
            Self(StdMutex::new(data))
        }

        /// Blocks until the lock is acquired and returns its guard.
        ///
        /// If the underlying mutex is poisoned, the poison error is discarded and the guard it carries is
        /// returned anyway.
        pub fn lock(&self) -> MutexGuard<'_, T> {
            // ignore poison error
            self.0.lock().unwrap_or_else(PoisonError::into_inner)
        }
    }
}
|
||||
#[cfg(not(feature = "std"))]
mod mutex {
    use core::{
        cell::UnsafeCell,
        ops::{Deref, DerefMut},
        sync::atomic::{AtomicBool, Ordering}
    };

    /// Minimal spinning mutex used when the `std` feature is disabled.
    ///
    /// The lock state is a single [`AtomicBool`]; waiters busy-wait with [`core::hint::spin_loop`].
    pub struct Mutex<T> {
        is_locked: AtomicBool,
        data: UnsafeCell<T>
    }

    // SAFETY: the protected value is only reachable through a `MutexGuard`, and at most one guard exists at a
    // time, so sharing or sending the mutex only ever hands the `T` to one thread at once.
    unsafe impl<T: Send> Send for Mutex<T> {}
    unsafe impl<T: Send> Sync for Mutex<T> {}

    impl<T> Mutex<T> {
        /// Wraps `data` in a new, unlocked mutex.
        pub fn new(data: T) -> Self {
            Self {
                is_locked: AtomicBool::new(false),
                data: UnsafeCell::new(data)
            }
        }

        /// Spins until the lock is acquired and returns a guard that releases it on drop.
        pub fn lock(&self) -> MutexGuard<'_, T> {
            // Try to flip `false -> true`; on failure, wait with cheap relaxed loads until the flag looks
            // free again before retrying the (more expensive) compare-exchange.
            while self
                .is_locked
                .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
                .is_err()
            {
                while self.is_locked.load(Ordering::Relaxed) {
                    core::hint::spin_loop();
                }
            }

            MutexGuard {
                is_locked: &self.is_locked,
                data: self.data.get()
            }
        }
    }

    /// Guard granting access to the value protected by a [`Mutex`]; clears the lock flag when dropped.
    pub struct MutexGuard<'a, T: 'a> {
        is_locked: &'a AtomicBool,
        data: *mut T
    }

    // SAFETY: the guard behaves like an exclusive reference to the `T` for as long as it is alive.
    unsafe impl<T: Send> Send for MutexGuard<'_, T> {}
    unsafe impl<T: Sync> Sync for MutexGuard<'_, T> {}

    impl<T> Deref for MutexGuard<'_, T> {
        type Target = T;

        fn deref(&self) -> &Self::Target {
            // SAFETY: `data` comes from the mutex's `UnsafeCell`, and the lock flag is held while `self` lives.
            unsafe { &*self.data }
        }
    }

    impl<T> DerefMut for MutexGuard<'_, T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            // SAFETY: as in `deref`; `&mut self` additionally makes this the only access through the guard.
            unsafe { &mut *self.data }
        }
    }

    impl<T> Drop for MutexGuard<'_, T> {
        fn drop(&mut self) {
            // Release pairs with the Acquire in `Mutex::lock`.
            self.is_locked.store(false, Ordering::Release);
        }
    }
}
|
||||
|
||||
#[cold]
|
||||
#[inline]
|
||||
#[doc(hidden)]
|
||||
|
||||
@@ -246,6 +246,13 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
_ => unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn data_type(&self) -> &TensorElementType {
|
||||
match self.dtype() {
|
||||
ValueType::Tensor { ty, .. } => ty,
|
||||
_ => unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_tensor<'t>(
|
||||
|
||||
BIN
src/value/impl_tensor/identity.ort
Normal file
BIN
src/value/impl_tensor/identity.ort
Normal file
Binary file not shown.
@@ -1,7 +1,7 @@
|
||||
mod create;
|
||||
mod extract;
|
||||
|
||||
use alloc::sync::Arc;
|
||||
use alloc::{format, string::ToString, sync::Arc};
|
||||
use core::{
|
||||
fmt::{self, Debug},
|
||||
marker::PhantomData,
|
||||
@@ -14,7 +14,7 @@ use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefM
|
||||
use crate::{
|
||||
AsPointer,
|
||||
error::Result,
|
||||
memory::{Allocator, MemoryInfo},
|
||||
memory::{AllocationDevice, Allocator, MemoryInfo},
|
||||
ortsys,
|
||||
tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType}
|
||||
};
|
||||
@@ -216,6 +216,141 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
pub fn memory_info(&self) -> &MemoryInfo {
|
||||
unsafe { self.inner.memory_info.as_ref().unwrap_unchecked() }
|
||||
}
|
||||
|
||||
/// Copies the contents of this tensor to another device, returning the newly created tensor value.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{memory::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType}, session::Session, value::Tensor};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # if false {
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// let cuda_allocator = Allocator::new(
|
||||
/// &session,
|
||||
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?
|
||||
/// )?;
|
||||
/// let cuda_tensor = Tensor::<f32>::new(&cuda_allocator, [1_usize, 3, 224, 224])?;
|
||||
/// # }
|
||||
/// # let cuda_tensor = Tensor::<f32>::new(&Allocator::default(), [1_usize, 3, 224, 224])?;
|
||||
///
|
||||
/// let cpu_tensor = cuda_tensor.to(AllocationDevice::CPU, 0)?;
|
||||
/// assert_eq!(cpu_tensor.memory_info().allocation_device(), AllocationDevice::CPU);
|
||||
/// assert_eq!(**cpu_tensor.shape(), [1, 3, 224, 224]);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn to(&self, device: AllocationDevice, device_id: i32) -> Result<Value<Type>> {
|
||||
use crate::{
|
||||
OnceLock, execution_providers as ep,
|
||||
io_binding::IoBinding,
|
||||
memory::{AllocatorType, MemoryType},
|
||||
session::{Session, builder::GraphOptimizationLevel},
|
||||
util::{MiniMap, Mutex}
|
||||
};
|
||||
|
||||
type IdentitySessionKey = (AllocationDevice, i32, ort_sys::ONNXTensorElementDataType);
|
||||
type IdentitySession = (Session, IoBinding);
|
||||
|
||||
static SESSIONS: OnceLock<Mutex<MiniMap<IdentitySessionKey, IdentitySession>>> = OnceLock::new();
|
||||
static IDENTITY_MODEL: &[u8] = include_bytes!("./identity.ort");
|
||||
|
||||
let target_memory_info = MemoryInfo::new(device, device_id, AllocatorType::Device, MemoryType::Default)?;
|
||||
let tensor_type = ort_sys::ONNXTensorElementDataType::from(*self.data_type());
|
||||
|
||||
let mut sessions = SESSIONS.get_or_init(|| Mutex::new(MiniMap::new())).lock();
|
||||
let (session, binding) = match sessions.get_mut(&(device, device_id, tensor_type)) {
|
||||
Some(entry) => entry,
|
||||
None => {
|
||||
let mut model_bytes = IDENTITY_MODEL.to_vec();
|
||||
// Override the expected element type of the input & output nodes, respectively.
|
||||
model_bytes[544] = tensor_type as u8;
|
||||
model_bytes[604] = tensor_type as u8;
|
||||
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Disable)?
|
||||
.with_intra_threads(1)?
|
||||
.with_inter_threads(1)?
|
||||
.with_inter_op_spinning(false)?
|
||||
.with_intra_op_spinning(false)?
|
||||
.with_memory_pattern(false)?
|
||||
.with_no_environment_execution_providers()?
|
||||
.with_execution_providers([match device {
|
||||
AllocationDevice::CPU => ep::CPUExecutionProvider::default().with_arena_allocator(false).build(),
|
||||
AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDAExecutionProvider::default()
|
||||
.with_device_id(device_id)
|
||||
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
|
||||
.with_conv_max_workspace(false)
|
||||
.with_conv_algorithm_search(ep::cuda::CuDNNConvAlgorithmSearch::Default)
|
||||
.build(),
|
||||
AllocationDevice::DIRECTML | AllocationDevice::DIRECTML_CPU => {
|
||||
ep::DirectMLExecutionProvider::default().with_device_id(device_id).build()
|
||||
}
|
||||
AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANNExecutionProvider::default()
|
||||
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
|
||||
.with_cann_graph(false)
|
||||
.with_device_id(device_id)
|
||||
.build(),
|
||||
AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => ep::OpenVINOExecutionProvider::default()
|
||||
.with_num_threads(1)
|
||||
.with_device_type(if device == AllocationDevice::OPENVINO_CPU {
|
||||
"CPU".to_string()
|
||||
} else {
|
||||
format!("GPU.{device_id}")
|
||||
})
|
||||
.build(),
|
||||
AllocationDevice::HIP | AllocationDevice::HIP_PINNED => ep::ROCmExecutionProvider::default()
|
||||
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
|
||||
.with_hip_graph(false)
|
||||
.with_exhaustive_conv_search(false)
|
||||
.with_device_id(device_id)
|
||||
.build(),
|
||||
AllocationDevice::TVM => ep::TVMExecutionProvider::default().build(),
|
||||
AllocationDevice::XNNPACK => ep::XNNPACKExecutionProvider::default().build(),
|
||||
_ => return Err(crate::Error::new("Unsupported allocation device {device} for tensor copy target"))
|
||||
}])?
|
||||
.with_allocator(MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?)?
|
||||
.commit_from_memory(&model_bytes)?;
|
||||
let binding = session.create_binding()?;
|
||||
sessions.insert((device, device_id, tensor_type), (session, binding));
|
||||
sessions.get_mut(&(device, device_id, tensor_type)).expect("insert should have worked")
|
||||
}
|
||||
};
|
||||
|
||||
binding.bind_input("input", self)?;
|
||||
binding.bind_output_to_device("output", &target_memory_info)?;
|
||||
|
||||
let output = session
|
||||
.run_binding(binding)?
|
||||
.remove("output")
|
||||
.expect("identity model should have single output");
|
||||
Ok(unsafe { output.transmute_type() })
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type: TensorValueTypeMarker + ?Sized> Clone for Value<Type> {
    /// Creates a copy of this tensor and its data on the same device it resides on.
    ///
    /// The copy is produced by calling `to` with the allocation device and device ID reported by this tensor's
    /// own memory info.
    ///
    /// # Panics
    /// Panics with `"Failed to clone tensor"` if the underlying copy returns an error.
    ///
    /// ```
    /// # use ort::{value::Tensor, AsPointer};
    /// # fn main() -> ort::Result<()> {
    /// let array = vec![1_i64, 2, 3, 4, 5];
    /// let tensor = Tensor::from_array(([array.len()], array.into_boxed_slice()))?;
    ///
    /// let new_tensor = tensor.clone();
    ///
    /// // same data
    /// assert_eq!(tensor.extract_tensor(), new_tensor.extract_tensor());
    ///
    /// // different allocations
    /// assert_ne!(tensor.ptr(), new_tensor.ptr());
    /// assert_ne!(tensor.data_ptr()?, new_tensor.data_ptr()?);
    /// # Ok(())
    /// # }
    /// ```
    fn clone(&self) -> Self {
        let source = self.memory_info();
        let device = source.allocation_device();
        let device_id = source.device_id();
        self.to(device, device_id).expect("Failed to clone tensor")
    }
}
|
||||
|
||||
impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||
|
||||
Reference in New Issue
Block a user