diff --git a/examples/gpt2/examples/gpt2.rs b/examples/gpt2/examples/gpt2.rs index d21a848..46af77a 100644 --- a/examples/gpt2/examples/gpt2.rs +++ b/examples/gpt2/examples/gpt2.rs @@ -49,7 +49,7 @@ fn main() -> ort::Result<()> { stdout.flush().unwrap(); for _ in 0..GEN_TOKENS { - // Raw tensor construction takes a tuple of (dimensions, data). + // Raw tensor construction takes a tuple of (shape, data). // The model expects our input to have shape [B, _, S] let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?; let outputs = session.run(inputs![input])?; diff --git a/src/lib.rs b/src/lib.rs index 62edd4c..de9aa43 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,6 +38,7 @@ pub mod tensor; #[cfg(feature = "training")] #[cfg_attr(docsrs, doc(cfg(feature = "training")))] pub mod training; +#[doc(hidden)] pub mod util; pub mod value; diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 8177c11..d9617be 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -6,14 +6,13 @@ use core::{ slice }; -use smallvec::SmallVec; - use crate::{ AsPointer, error::{Error, Result}, memory::{Allocator, MemoryInfo, MemoryType}, ortsys, session::{Input, Output}, + tensor::Shape, util::with_cstr, value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType} }; @@ -350,9 +349,9 @@ impl KernelContext { Ok(NonNull::new(value_ptr.cast_mut()).map(|c| ValueRef::new(unsafe { Value::from_ptr_nodrop(c, None) }))) } - pub fn output(&self, idx: usize, shape: impl IntoIterator) -> Result>> { + pub fn output(&self, idx: usize, shape: impl Into) -> Result>> { let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); - let shape = shape.into_iter().collect::>(); + let shape = shape.into(); ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx, shape.as_ptr(), shape.len(), &mut value_ptr)?]; Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) }))) } diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index e179f51..f484763 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -1,9 +1,126 @@ -//! Traits related to [`Tensor`](crate::value::Tensor)s. +//! Traits and types related to [`Tensor`](crate::value::Tensor)s. #[cfg(feature = "ndarray")] mod ndarray; mod types; +use alloc::{string::String, vec::Vec}; +use core::{ + fmt, + ops::{Deref, DerefMut} +}; + +use smallvec::{SmallVec, smallvec}; + #[cfg(feature = "ndarray")] pub use self::ndarray::ArrayExtensions; pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; + +#[derive(Default, Clone, PartialEq, Eq)] +pub struct Shape { + inner: SmallVec +} + +impl Shape { + pub fn new(dims: impl IntoIterator) -> Self { + Self { inner: dims.into_iter().collect() } + } + + pub fn empty(rank: usize) -> Self { + Self { inner: smallvec![0; rank] } + } + + #[doc(alias = "numel")] + pub fn num_elements(&self) -> usize { + let mut size = 1usize; + for dim in &self.inner { + if *dim < 0 { + return 0; + } + size *= *dim as usize; + } + size + } + + #[cfg(feature = "ndarray")] + #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] + pub fn to_ixdyn(&self) -> ::ndarray::IxDyn { + use ::ndarray::IntoDimension; + self.inner.iter().map(|d| *d as usize).collect::>().into_dimension() + } +} + +impl fmt::Debug for Shape { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.inner.iter()).finish() + } +} + +impl fmt::Display for Shape { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.inner.iter()).finish() + } +} + +impl From> for Shape { + fn from(value: Vec) -> Self { + Self { inner: SmallVec::from(value) } + } +} + +impl From<&[i64]> for Shape { + fn from(value: &[i64]) -> Self { + Self { inner: SmallVec::from(value) } + } +} + +impl From<[i64; N]> for Shape { + fn from(value: [i64; N]) -> Self { + Self { inner: SmallVec::from(value) } + } +} + +impl FromIterator for Shape { + fn from_iter>(iter: T) -> Self { + Self { inner: iter.into_iter().collect() } + } +} + +impl Deref for Shape { + type Target = [i64]; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for Shape { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SymbolicDimensions(SmallVec); + +impl SymbolicDimensions { + pub fn new(dims: impl IntoIterator) -> Self { + Self(dims.into_iter().collect()) + } + + pub fn empty(rank: usize) -> Self { + Self(smallvec![String::default(); rank]) + } +} + +impl FromIterator for SymbolicDimensions { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl Deref for SymbolicDimensions { + type Target = [String]; + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/util.rs b/src/util.rs index 6fd90cf..bde35fa 100644 --- a/src/util.rs +++ b/src/util.rs @@ -303,17 +303,6 @@ impl Drop for OnceLock { #[doc(hidden)] pub fn cold() {} -pub fn element_count(shape: &[i64]) -> usize { - let mut size = 1usize; - for dim in shape { - if *dim < 0 { - return 0; - } - size *= *dim as usize; - } - size -} - #[inline] pub(crate) fn with_cstr(bytes: &[u8], f: &dyn Fn(&CStr) -> Result) -> Result { fn run_with_heap_cstr(bytes: &[u8], f: &dyn Fn(&CStr) -> Result) -> Result { diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index 1dccab2..3c7dc32 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -20,8 +20,7 @@ use crate::{ error::{Error, Result}, memory::Allocator, ortsys, - tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType}, - util::element_count + tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType} }; pub trait MapValueTypeMarker: ValueTypeMarker { @@ -109,7 +108,7 @@ impl Value { if K::into_tensor_element_type() != TensorElementType::String { let dtype = key_value.dtype(); let (key_tensor_shape, key_tensor) = match dtype { - ValueType::Tensor { ty, dimensions, .. } => { + ValueType::Tensor { ty, shape, .. } => { let mem = key_value.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!( @@ -124,8 +123,7 @@ impl Value { let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast(); ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - let len = element_count(dimensions); - (dimensions, unsafe { slice::from_raw_parts(output_array_ptr, len) }) + (shape, unsafe { slice::from_raw_parts(output_array_ptr, shape.num_elements()) }) } else { return Err(Error::new_with_code( ErrorCode::InvalidArgument, diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index 464cc25..5c94b1e 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -1,4 +1,4 @@ -use alloc::{boxed::Box, ffi::CString, format, string::String, sync::Arc, vec, vec::Vec}; +use alloc::{boxed::Box, ffi::CString, format, string::String, sync::Arc, vec::Vec}; use core::{ any::Any, ffi::c_void, @@ -17,8 +17,7 @@ use crate::{ error::{Error, ErrorCode, Result}, memory::{Allocator, MemoryInfo}, ortsys, - tensor::{PrimitiveTensorElementType, TensorElementType, Utf8Data}, - util::element_count, + tensor::{PrimitiveTensorElementType, Shape, SymbolicDimensions, TensorElementType, Utf8Data}, value::{Value, ValueInner, ValueType} }; @@ -78,8 +77,8 @@ impl Tensor { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, dtype: ValueType::Tensor { ty: TensorElementType::String, - dimensions: shape, - dimension_symbols: vec![String::default(); shape_len] + shape, + dimension_symbols: SymbolicDimensions::empty(shape_len) }, memory_info: MemoryInfo::from_value(value_ptr), drop: true, @@ -109,7 +108,7 @@ impl Tensor { /// # Ok(()) /// # } /// ``` - pub fn new(allocator: &Allocator, shape: impl ToDimensions) -> Result> { + pub fn new(allocator: &Allocator, shape: impl Into) -> Result> { let tensor = DynTensor::new(allocator, T::into_tensor_element_type(), shape)?; Ok(unsafe { tensor.transmute_type() }) } @@ -141,17 +140,16 @@ impl Tensor { /// /// Creating string tensors requires a separate method; see [`Tensor::from_string_array`]. pub fn from_array(input: impl OwnedTensorArrayData) -> Result> { - let TensorArrayDataParts { shape, ptr, num_elements, guard } = input.into_parts()?; - tensor_from_array(MemoryInfo::default(), shape, ptr.as_ptr().cast(), num_elements, size_of::(), T::into_tensor_element_type(), guard) + let TensorArrayDataParts { shape, ptr, guard } = input.into_parts()?; + tensor_from_array(MemoryInfo::default(), shape, ptr.as_ptr().cast(), size_of::(), T::into_tensor_element_type(), guard) .map(|tensor| unsafe { tensor.transmute_type() }) } } fn tensor_from_array( memory_info: MemoryInfo, - shape: Vec, + shape: Shape, data: *mut c_void, - num_elements: usize, element_size: usize, element_type: TensorElementType, guard: Option> @@ -162,7 +160,7 @@ fn tensor_from_array( unsafe CreateTensorWithDataAsOrtValue( memory_info.ptr(), data, - num_elements * element_size, + shape.num_elements() * element_size, shape.as_ptr(), shape.len(), element_type.into(), @@ -176,8 +174,8 @@ fn tensor_from_array( ptr: unsafe { NonNull::new_unchecked(value_ptr) }, dtype: ValueType::Tensor { ty: element_type, - dimension_symbols: vec![String::default(); shape.len()], - dimensions: shape + dimension_symbols: SymbolicDimensions::empty(shape.len()), + shape }, drop: true, memory_info: Some(memory_info), @@ -219,15 +217,11 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> { /// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout. pub fn from_array_view(input: impl TensorArrayData + 'a) -> Result> { let (shape, data, guard) = input.ref_parts()?; - let num_elements = element_count(&shape); - - tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, num_elements, size_of::(), T::into_tensor_element_type(), guard).map( - |tensor| { - let mut tensor: TensorRef<'_, T> = TensorRef::new(unsafe { tensor.transmute_type() }); - tensor.upgradable = false; - tensor - } - ) + tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::(), T::into_tensor_element_type(), guard).map(|tensor| { + let mut tensor: TensorRef<'_, T> = TensorRef::new(unsafe { tensor.transmute_type() }); + tensor.upgradable = false; + tensor + }) } } @@ -262,21 +256,17 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { /// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout. pub fn from_array_view_mut(mut input: impl TensorArrayDataMut) -> Result> { let (shape, data, guard) = input.ref_parts_mut()?; - let num_elements = element_count(&shape); - - tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, num_elements, size_of::(), T::into_tensor_element_type(), guard).map( - |tensor| { - let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() }); - tensor.upgradable = false; - tensor - } - ) + tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::(), T::into_tensor_element_type(), guard).map(|tensor| { + let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() }); + tensor.upgradable = false; + tensor + }) } /// Create a mutable tensor view from a raw pointer and shape. /// /// The length of data is determined by `T` and the given shape, so the given buffer must be at least - /// `shape.iter().product() * size_of::()` bytes. + /// `shape.num_elements() * size_of::()` bytes. /// /// This function can be used to create data from raw device memory, e.g. to directly provide data to an execution /// provider. For instance, to create a tensor from a raw CUDA buffer using [`cudarc`](https://docs.rs/cudarc): @@ -288,7 +278,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { /// TensorRefMut::from_raw( /// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?, /// (*device_data.device_ptr() as usize as *mut ()).cast(), - /// vec![1, 3, 512, 512] + /// Shape::new([1, 3, 512, 512]) /// )? /// }; /// ``` @@ -296,9 +286,8 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { /// # Safety /// - The pointer must be valid for the device description provided by `MemoryInfo`. /// - The returned tensor must outlive the data described by the data pointer. - pub unsafe fn from_raw(info: MemoryInfo, data: *mut ort_sys::c_void, shape: Vec) -> Result> { - let num_elements = element_count(&shape); - tensor_from_array(info, shape, data, num_elements, size_of::(), T::into_tensor_element_type(), None).map(|tensor| { + pub unsafe fn from_raw(info: MemoryInfo, data: *mut ort_sys::c_void, shape: Shape) -> Result> { + tensor_from_array(info, shape, data, size_of::(), T::into_tensor_element_type(), None).map(|tensor| { let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() }); tensor.upgradable = false; tensor @@ -308,14 +297,14 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { pub trait TensorArrayData { #[allow(clippy::type_complexity)] - fn ref_parts(&self) -> Result<(Vec, &[I], Option>)>; + fn ref_parts(&self) -> Result<(Shape, &[I], Option>)>; private_trait!(); } pub trait TensorArrayDataMut: TensorArrayData { #[allow(clippy::type_complexity)] - fn ref_parts_mut(&mut self) -> Result<(Vec, &mut [I], Option>)>; + fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [I], Option>)>; private_trait!(); } @@ -327,20 +316,19 @@ pub trait OwnedTensorArrayData { } pub struct TensorArrayDataParts { - pub shape: Vec, + pub shape: Shape, pub ptr: NonNull, - pub num_elements: usize, pub guard: Option> } pub trait ToDimensions { - fn to_dimensions(&self, expected_size: Option) -> Result>; + fn to_dimensions(&self, expected_size: Option) -> Result; } macro_rules! impl_to_dimensions { (@inner) => { - fn to_dimensions(&self, expected_size: Option) -> Result> { - let v: Vec = self + fn to_dimensions(&self, expected_size: Option) -> Result { + let v = self .iter() .enumerate() .map(|(i, c)| { @@ -353,13 +341,17 @@ macro_rules! impl_to_dimensions { )) } }) - .collect::>()?; - let sum = element_count(&v); + .collect::>()?; if let Some(expected_size) = expected_size { - if sum != expected_size { + if v.num_elements() != expected_size { Err(Error::new_with_code( ErrorCode::InvalidArgument, - format!("Cannot create a tensor from raw data; shape {:?} ({} elements) is larger than the length of the data provided ({} elements)", v, sum, expected_size) + format!( + "Cannot create a tensor from raw data; shape {:?} ({} elements) is larger than the length of the data provided ({} elements)", + v, + v.num_elements(), + expected_size + ) )) } else { Ok(v) @@ -382,21 +374,21 @@ macro_rules! impl_to_dimensions { } impl ToDimensions for () { - fn to_dimensions(&self, expected_size: Option) -> Result> { + fn to_dimensions(&self, expected_size: Option) -> Result { match expected_size { - Some(1) | None => Ok(vec![]), + Some(1) | None => Ok(Shape::default()), Some(_) => Err(Error::new_with_code(ErrorCode::InvalidArgument, "Expected data to have a length of exactly 1 for scalar shape")) } } } -impl_to_dimensions!(for &[usize], for &[i32], for &[i64], for Vec, for Vec, for Vec); +impl_to_dimensions!(for Shape, for &[usize], for &[i32], for &[i64], for Vec, for Vec, for Vec); impl_to_dimensions!( for [usize; N], for [i32; N], for [i64; N]); #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] impl TensorArrayData for &CowArray<'_, T, D> { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { - let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { + let shape = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; @@ -409,8 +401,8 @@ impl TensorArrayData for &CowArra #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] impl TensorArrayData for ArcArray { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { - let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { + let shape = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; @@ -423,8 +415,8 @@ impl TensorArrayData for ArcArray #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] impl TensorArrayData for &Array { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { - let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { + let shape = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; @@ -437,8 +429,8 @@ impl TensorArrayData for &Array TensorArrayData for &mut Array { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { - let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { + let shape = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; @@ -455,27 +447,21 @@ impl OwnedTensorArrayData for Arr if self.is_standard_layout() { // We can avoid the copy here and use the data as is let mut this = Box::new(self); - let shape: Vec = this.shape().iter().map(|d| *d as i64).collect(); + let shape: Shape = this.shape().iter().map(|d| *d as i64).collect(); // SAFETY: ndarrays internally store their pointer as NonNull let ptr = unsafe { NonNull::new_unchecked(this.as_mut_ptr()) }; - let num_elements = this.len(); - Ok(TensorArrayDataParts { - shape, - ptr, - num_elements, - guard: Some(this) - }) + assert_eq!(this.len(), shape.num_elements()); + Ok(TensorArrayDataParts { shape, ptr, guard: Some(this) }) } else { // Need to do a copy here to get data in to standard layout let mut contiguous_array = self.as_standard_layout().into_owned(); - let shape: Vec = contiguous_array.shape().iter().map(|d| *d as i64).collect(); + let shape: Shape = contiguous_array.shape().iter().map(|d| *d as i64).collect(); // SAFETY: ndarrays internally store their pointer as NonNull let ptr = unsafe { NonNull::new_unchecked(contiguous_array.as_mut_ptr()) }; - let num_elements: usize = contiguous_array.len(); + assert_eq!(contiguous_array.len(), shape.num_elements()); Ok(TensorArrayDataParts { shape, ptr, - num_elements, guard: Some(Box::new(contiguous_array)) }) } @@ -487,8 +473,8 @@ impl OwnedTensorArrayData for Arr #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] impl TensorArrayData for ArrayView<'_, T, D> { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { - let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { + let shape = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; @@ -501,8 +487,8 @@ impl TensorArrayData for ArrayVie #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] impl TensorArrayData for ArrayViewMut<'_, T, D> { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { - let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { + let shape = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; @@ -515,8 +501,8 @@ impl TensorArrayData for ArrayVie #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] impl TensorArrayDataMut for ArrayViewMut<'_, T, D> { - fn ref_parts_mut(&mut self) -> Result<(Vec, &mut [T], Option>)> { - let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [T], Option>)> { + let shape = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice_mut() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; @@ -529,8 +515,8 @@ impl TensorArrayDataMut for Array #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] impl TensorArrayDataMut for &mut Array { - fn ref_parts_mut(&mut self) -> Result<(Vec, &mut [T], Option>)> { - let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [T], Option>)> { + let shape = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice_mut() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; @@ -541,7 +527,7 @@ impl TensorArrayDataMut for &mut } impl TensorArrayData for (D, &[T]) { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { let shape = self.0.to_dimensions(Some(self.1.len()))?; Ok((shape, self.1, None)) } @@ -550,7 +536,7 @@ impl TensorArrayData for (D, &[T]) { } impl TensorArrayData for (D, &mut [T]) { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { let shape = self.0.to_dimensions(Some(self.1.len()))?; Ok((shape, self.1, None)) } @@ -559,7 +545,7 @@ impl TensorArrayData for (D, &mut [T]) { } impl TensorArrayDataMut for (D, &mut [T]) { - fn ref_parts_mut(&mut self) -> Result<(Vec, &mut [T], Option>)> { + fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [T], Option>)> { let shape = self.0.to_dimensions(Some(self.1.len()))?; Ok((shape, self.1, None)) } @@ -572,11 +558,10 @@ impl OwnedTensorArrayData for (D, Vec let shape = self.0.to_dimensions(Some(self.1.len()))?; // SAFETY: A `Vec` always has a non-null pointer. let ptr = unsafe { NonNull::new_unchecked(self.1.as_mut_ptr()) }; - let num_elements: usize = self.1.len(); + assert_eq!(shape.num_elements(), self.1.len()); Ok(TensorArrayDataParts { shape, ptr, - num_elements, guard: Some(Box::new(self.1)) }) } @@ -589,11 +574,10 @@ impl OwnedTensorArrayData for (D, Box<[T let shape = self.0.to_dimensions(Some(self.1.len()))?; // SAFETY: A `Box` always has a non-null pointer. let ptr = unsafe { NonNull::new_unchecked(self.1.as_mut_ptr()) }; - let num_elements: usize = self.1.len(); + assert_eq!(shape.num_elements(), self.1.len()); Ok(TensorArrayDataParts { shape, ptr, - num_elements, guard: Some(Box::new(self.1)) }) } @@ -602,7 +586,7 @@ impl OwnedTensorArrayData for (D, Box<[T } impl TensorArrayData for (D, Arc<[T]>) { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { let shape = self.0.to_dimensions(Some(self.1.len()))?; let data = &*self.1; Ok((shape, data, Some(Box::new(self.1.clone())))) @@ -612,7 +596,7 @@ impl TensorArrayData for (D, Arc<[T]>) { } impl TensorArrayData for (D, Arc>) { - fn ref_parts(&self) -> Result<(Vec, &[T], Option>)> { + fn ref_parts(&self) -> Result<(Shape, &[T], Option>)> { let shape = self.0.to_dimensions(Some(self.1.len()))?; let data = &*self.1; Ok((shape, data, Some(Box::new(self.1.clone())))) diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index 6d30788..8a34dea 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -12,8 +12,7 @@ use crate::{ error::{Error, ErrorCode, Result}, memory::MemoryInfo, ortsys, - tensor::{PrimitiveTensorElementType, TensorElementType}, - util::element_count, + tensor::{PrimitiveTensorElementType, Shape, TensorElementType}, value::{Value, ValueType} }; @@ -49,11 +48,8 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_tensor(&self) -> Result> { - use ndarray::IntoDimension; - extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| { - let shape = dimensions.iter().map(|&n| n as usize).collect::>().into_dimension(); - Ok(unsafe { ndarray::ArrayView::from_shape_ptr(shape, data_ptr(ptr)?.cast::()) }) - }) + extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()) + .and_then(|(ptr, shape)| Ok(unsafe { ndarray::ArrayView::from_shape_ptr(shape.to_ixdyn(), data_ptr(ptr)?.cast::()) })) } /// Attempt to extract the scalar from a tensor of type `T`. @@ -80,11 +76,11 @@ impl Value { /// /// [`DynValue`]: crate::value::DynValue pub fn try_extract_scalar(&self) -> Result { - extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| { - if !dimensions.is_empty() { + extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, shape)| { + if !shape.is_empty() { return Err(Error::new_with_code( ErrorCode::InvalidArgument, - format!("Cannot extract scalar {} from a tensor of dimensionality {}", T::into_tensor_element_type(), dimensions.len()) + format!("Cannot extract scalar {} from a tensor of dimensionality {}", T::into_tensor_element_type(), shape.len()) )); } @@ -122,14 +118,11 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_tensor_mut(&mut self) -> Result> { - use ndarray::IntoDimension; - extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| { - let shape = dimensions.iter().map(|&n| n as usize).collect::>().into_dimension(); - Ok(unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape, data_ptr(ptr)?.cast::()) }) - }) + extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()) + .and_then(|(ptr, shape)| Ok(unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape.to_ixdyn(), data_ptr(ptr)?.cast::()) })) } - /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an + /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's shape and an /// immutable view into its data. /// /// See also: @@ -145,7 +138,7 @@ impl Value { /// /// let (extracted_shape, extracted_data) = value.try_extract_raw_tensor::()?; /// assert_eq!(extracted_data, &array); - /// assert_eq!(extracted_shape, [5]); + /// assert_eq!(**extracted_shape, [5]); /// # Ok(()) /// # } /// ``` @@ -157,12 +150,12 @@ impl Value { /// - The provided type `T` does not match the tensor's element type. /// /// [`DynValue`]: crate::value::DynValue - pub fn try_extract_raw_tensor(&self) -> Result<(&[i64], &[T])> { + pub fn try_extract_raw_tensor(&self) -> Result<(&Shape, &[T])> { extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()) - .and_then(|(ptr, dimensions)| Ok((dimensions, unsafe { slice::from_raw_parts(data_ptr(ptr)?.cast::(), element_count(dimensions)) }))) + .and_then(|(ptr, shape)| Ok((shape, unsafe { slice::from_raw_parts(data_ptr(ptr)?.cast::(), shape.num_elements()) }))) } - /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a + /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's shape and a /// mutable view into its data. /// /// See also the infallible counterpart, [`Tensor::extract_raw_tensor_mut`], for typed [`Tensor`]s. @@ -175,7 +168,7 @@ impl Value { /// /// let (extracted_shape, extracted_data) = value.try_extract_raw_tensor_mut::()?; /// assert_eq!(extracted_data, &array); - /// assert_eq!(extracted_shape, [5]); + /// assert_eq!(**extracted_shape, [5]); /// # Ok(()) /// # } /// ``` @@ -187,9 +180,9 @@ impl Value { /// - The provided type `T` does not match the tensor's element type. /// /// [`DynValue`]: crate::value::DynValue - pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(&[i64], &mut [T])> { + pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(&Shape, &mut [T])> { extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()) - .and_then(|(ptr, dimensions)| Ok((dimensions, unsafe { slice::from_raw_parts_mut(data_ptr(ptr)?.cast::(), element_count(dimensions)) }))) + .and_then(|(ptr, shape)| Ok((shape, unsafe { slice::from_raw_parts_mut(data_ptr(ptr)?.cast::(), shape.num_elements()) }))) } /// Attempt to extract the underlying data into a Rust `ndarray`. @@ -208,15 +201,13 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_string_tensor(&self) -> Result> { - use ndarray::IntoDimension; - extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, dimensions)| { - let strings = extract_strings(ptr, dimensions)?; - Ok(ndarray::Array::from_shape_vec(dimensions.iter().map(|&n| n as usize).collect::>().into_dimension(), strings) - .expect("Shape extracted from tensor didn't match tensor contents")) + extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, shape)| { + let strings = extract_strings(ptr, shape)?; + Ok(ndarray::Array::from_shape_vec(shape.to_ixdyn(), strings).expect("Shape extracted from tensor didn't match tensor contents")) }) } - /// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's dimensions and + /// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's shape and /// an owned `Vec` of its data. /// /// ``` @@ -227,14 +218,14 @@ impl Value { /// /// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?; /// assert_eq!(extracted_data, array); - /// assert_eq!(extracted_shape, [2]); + /// assert_eq!(**extracted_shape, [2]); /// # Ok(()) /// # } /// ``` - pub fn try_extract_raw_string_tensor(&self) -> Result<(&[i64], Vec)> { - extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, dimensions)| { - let strings = extract_strings(ptr, dimensions)?; - Ok((dimensions, strings)) + pub fn try_extract_raw_string_tensor(&self) -> Result<(&Shape, Vec)> { + extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, shape)| { + let strings = extract_strings(ptr, shape)?; + Ok((shape, strings)) }) } @@ -246,13 +237,13 @@ impl Value { /// # let allocator = Allocator::default(); /// let tensor = Tensor::::new(&allocator, [1, 128, 128, 3])?; /// - /// assert_eq!(tensor.shape(), [1, 128, 128, 3]); + /// assert_eq!(**tensor.shape(), [1, 128, 128, 3]); /// # Ok(()) /// # } /// ``` - pub fn shape(&self) -> &[i64] { + pub fn shape(&self) -> &Shape { match self.dtype() { - ValueType::Tensor { dimensions, .. } => dimensions, + ValueType::Tensor { shape, .. } => shape, _ => unreachable!() } } @@ -263,9 +254,9 @@ fn extract_tensor<'t>( dtype: &'t ValueType, memory_info: &MemoryInfo, expected_ty: TensorElementType -) -> Result<(*mut ort_sys::OrtValue, &'t [i64])> { +) -> Result<(*mut ort_sys::OrtValue, &'t Shape)> { match dtype { - ValueType::Tensor { ty, dimensions, .. } => { + ValueType::Tensor { ty, shape, .. } => { if !memory_info.is_cpu_accessible() { return Err(Error::new(format!( "Cannot extract from value on device `{}`, which is not CPU accessible", @@ -274,7 +265,7 @@ fn extract_tensor<'t>( } if *ty == expected_ty { - Ok((ptr, dimensions)) + Ok((ptr, shape)) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<{}> from Tensor<{}>", expected_ty, ty))) } @@ -289,9 +280,8 @@ unsafe fn data_ptr(ptr: *mut ort_sys::OrtValue) -> Result<*mut c_void> { Ok(output_array_ptr) } -fn extract_strings(ptr: *mut ort_sys::OrtValue, dimensions: &[i64]) -> Result> { - let len = element_count(dimensions); - +fn extract_strings(ptr: *mut ort_sys::OrtValue, shape: &Shape) -> Result> { + let len = shape.num_elements(); // Total length of string data, not including \0 suffix let mut total_length = 0; ortsys![unsafe GetStringTensorDataLength(ptr, &mut total_length)?]; @@ -368,7 +358,7 @@ impl Tensor { self.try_extract_tensor_mut().expect("Failed to extract tensor") } - /// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an immutable + /// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's shapes and an immutable /// view into its data. /// /// ``` @@ -379,15 +369,15 @@ impl Tensor { /// /// let (extracted_shape, extracted_data) = tensor.extract_raw_tensor(); /// assert_eq!(extracted_data, &array); - /// assert_eq!(extracted_shape, [5]); + /// assert_eq!(**extracted_shape, [5]); /// # Ok(()) /// # } /// ``` - pub fn extract_raw_tensor(&self) -> (&[i64], &[T]) { + pub fn extract_raw_tensor(&self) -> (&Shape, &[T]) { self.try_extract_raw_tensor().expect("Failed to extract tensor") } - /// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a mutable view + /// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's shapes and a mutable view /// into its data. /// /// ``` @@ -403,7 +393,7 @@ impl Tensor { /// # Ok(()) /// # } /// ``` - pub fn extract_raw_tensor_mut(&mut self) -> (&[i64], &mut [T]) { + pub fn extract_raw_tensor_mut(&mut self) -> (&Shape, &mut [T]) { self.try_extract_raw_tensor_mut().expect("Failed to extract tensor") } } diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index 31337e1..3971796 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -1,7 +1,7 @@ mod create; mod extract; -use alloc::{string::String, sync::Arc, vec}; +use alloc::sync::Arc; use core::{ fmt::{self, Debug}, marker::PhantomData, @@ -16,8 +16,7 @@ use crate::{ error::Result, memory::{Allocator, MemoryInfo}, ortsys, - tensor::{IntoTensorElementType, TensorElementType}, - util::element_count + tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType} }; pub trait TensorValueTypeMarker: ValueTypeMarker { @@ -93,10 +92,13 @@ impl DynTensor { /// # Ok(()) /// # } /// ``` - pub fn new(allocator: &Allocator, data_type: TensorElementType, shape: impl ToDimensions) -> Result { + pub fn new(allocator: &Allocator, data_type: TensorElementType, shape: impl Into) -> Result { + Self::new_inner(allocator, data_type, shape.into()) + } + + fn new_inner(allocator: &Allocator, data_type: TensorElementType, shape: Shape) -> Result { let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); - let shape = shape.to_dimensions(None)?; let shape_ptr: *const i64 = shape.as_ptr(); let shape_len = shape.len(); @@ -118,7 +120,7 @@ impl DynTensor { let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut(); ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)]; - unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(element_count(&shape))) }; + unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(shape.num_elements())) }; } Ok(Value { @@ -126,8 +128,8 @@ impl DynTensor { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, dtype: ValueType::Tensor { ty: data_type, - dimensions: shape, - dimension_symbols: vec![String::default(); shape_len] + shape, + dimension_symbols: SymbolicDimensions::empty(shape_len) }, drop: true, memory_info: MemoryInfo::from_value(value_ptr), @@ -338,7 +340,7 @@ mod tests { use super::Tensor; use crate::{ memory::Allocator, - tensor::TensorElementType, + tensor::{Shape, SymbolicDimensions, TensorElementType}, value::{TensorRef, ValueType} }; @@ -350,12 +352,12 @@ mod tests { assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Float32)); assert_eq!(value.dtype(), &ValueType::Tensor { ty: TensorElementType::Float32, - dimensions: vec![v.len() as i64], - dimension_symbols: vec![String::default()] + shape: Shape::new([v.len() as i64]), + dimension_symbols: SymbolicDimensions::empty(1) }); let (shape, data) = value.extract_raw_tensor(); - assert_eq!(shape, vec![v.len() as i64]); + assert_eq!(&**shape, [v.len() as i64]); assert_eq!(data, &v); Ok(()) @@ -411,7 +413,7 @@ mod tests { let value = Tensor::from_string_array((vec![v.len() as i64], &*v))?; let (extracted_shape, extracted_view) = value.try_extract_raw_string_tensor()?; - assert_eq!(extracted_shape, [v.len() as i64]); + assert_eq!(&**extracted_shape, [v.len() as i64]); assert_eq!(extracted_view, v); Ok(()) @@ -437,7 +439,7 @@ mod tests { #[test] fn test_tensor_index() -> crate::Result<()> { - let mut tensor = Tensor::new(&Allocator::default(), [1, 3, 224, 224])?; + let mut tensor = Tensor::new(&Allocator::default(), Shape::new([1, 3, 224, 224]))?; tensor[[0, 2, 42, 42]] = 1.0; assert_eq!(tensor[[0, 2, 42, 42]], 1.0); diff --git a/src/value/type.rs b/src/value/type.rs index eb068a0..439a6e8 100644 --- a/src/value/type.rs +++ b/src/value/type.rs @@ -1,39 +1,45 @@ use alloc::{ boxed::Box, - string::{String, ToString}, - vec, - vec::Vec -}; -use core::{ - ffi::{CStr, c_char}, - fmt, ptr + string::{String, ToString} }; +use core::{ffi::CStr, fmt, ptr}; -use crate::{ortsys, tensor::TensorElementType, util::with_cstr_ptr_array}; +use smallvec::{SmallVec, smallvec}; + +use crate::{ + ortsys, + tensor::{Shape, SymbolicDimensions, TensorElementType}, + util::with_cstr_ptr_array +}; /// The type of a [`Value`][super::Value], or a session input/output. /// /// ``` /// # use std::sync::Arc; -/// # use ort::{session::Session, value::{ValueType, Tensor}, tensor::TensorElementType}; +/// # use ort::{session::Session, tensor::{Shape, SymbolicDimensions}, value::{ValueType, Tensor}, tensor::TensorElementType}; /// # fn main() -> ort::Result<()> { /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// // `ValueType`s can be obtained from session inputs/outputs: /// let input = &session.inputs[0]; /// assert_eq!(input.input_type, ValueType::Tensor { /// ty: TensorElementType::Float32, -/// // Our model has 3 dynamic dimensions, represented by -1 -/// dimensions: vec![-1, -1, -1, 3], +/// // Our model's input has 3 dynamic dimensions, represented by -1 +/// shape: Shape::new([-1, -1, -1, 3]), /// // Dynamic dimensions may also have names. -/// dimension_symbols: vec!["unk__31".to_string(), "unk__32".to_string(), "unk__33".to_string(), String::default()] +/// dimension_symbols: SymbolicDimensions::new([ +/// "unk__31".to_string(), +/// "unk__32".to_string(), +/// "unk__33".to_string(), +/// String::default() +/// ]) /// }); /// /// // ...or by `Value`s created in Rust or output by a session. /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; /// assert_eq!(value.dtype(), &ValueType::Tensor { /// ty: TensorElementType::Int64, -/// dimensions: vec![5], -/// dimension_symbols: vec![String::default()] +/// shape: Shape::new([5]), +/// dimension_symbols: SymbolicDimensions::new([String::default()]) /// }); /// # Ok(()) /// # } @@ -44,15 +50,16 @@ pub enum ValueType { Tensor { /// Element type of the tensor. ty: TensorElementType, - /// Dimensions of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an + /// Shape of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an /// [`Input`]/[`Output`]), the dimension will be `-1`. /// - /// Actual tensor values, which have a known dimension, will always have positive (>1) dimensions. + /// Actual tensor values (i.e. not [`Input`] or [`Output`] definitions), which have a known dimension, will + /// always have non-negative dimensions. /// /// [`Input`]: crate::session::Input /// [`Output`]: crate::session::Output - dimensions: Vec, - dimension_symbols: Vec + shape: Shape, + dimension_symbols: SymbolicDimensions }, /// A sequence (vector) of other `Value`s. /// @@ -136,11 +143,11 @@ impl ValueType { pub(crate) fn to_tensor_type_info(&self) -> Option<*mut ort_sys::OrtTensorTypeAndShapeInfo> { match self { - Self::Tensor { ty, dimensions, dimension_symbols } => { + Self::Tensor { ty, shape, dimension_symbols } => { let mut info_ptr = ptr::null_mut(); ortsys![unsafe CreateTensorTypeAndShapeInfo(&mut info_ptr).expect("infallible")]; ortsys![unsafe SetTensorElementType(info_ptr, (*ty).into()).expect("infallible")]; - ortsys![unsafe SetDimensions(info_ptr, dimensions.as_ptr(), dimensions.len()).expect("infallible")]; + ortsys![unsafe SetDimensions(info_ptr, shape.as_ptr(), shape.len()).expect("infallible")]; with_cstr_ptr_array(dimension_symbols, &|ptrs| { ortsys![unsafe SetSymbolicDimensions(info_ptr, ptrs.as_ptr().cast_mut(), dimension_symbols.len()).expect("infallible")]; Ok(()) @@ -152,20 +159,22 @@ impl ValueType { } } - /// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map. + /// Returns the shape of this value type if it is a tensor, or `None` if it is a sequence or map. /// /// ``` - /// # use ort::value::Tensor; + /// # use ort::value::{Tensor, DynValue}; /// # fn main() -> ort::Result<()> { - /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; - /// assert_eq!(value.dtype().tensor_dimensions(), Some(&vec![5])); + /// let value: DynValue = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?.into_dyn(); + /// + /// let shape = value.dtype().tensor_shape().unwrap(); + /// assert_eq!(**shape, [5]); /// # Ok(()) /// # } /// ``` #[must_use] - pub fn tensor_dimensions(&self) -> Option<&Vec> { + pub fn tensor_shape(&self) -> Option<&Shape> { match self { - ValueType::Tensor { dimensions, .. } => Some(dimensions), + ValueType::Tensor { shape, .. } => Some(shape), _ => None } } @@ -213,9 +222,9 @@ impl ValueType { impl fmt::Display for ValueType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ValueType::Tensor { ty, dimensions, dimension_symbols } => { + ValueType::Tensor { ty, shape, dimension_symbols } => { write!(f, "Tensor<{ty}>(")?; - for (i, dimension) in dimensions.iter().copied().enumerate() { + for (i, dimension) in shape.iter().copied().enumerate() { if dimension == -1 { let sym = &dimension_symbols[i]; if sym.is_empty() { @@ -226,7 +235,7 @@ impl fmt::Display for ValueType { } else { dimension.fmt(f)?; } - if i != dimensions.len() - 1 { + if i != shape.len() - 1 { f.write_str(", ")?; } } @@ -248,10 +257,10 @@ pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys let mut num_dims = 0; ortsys![unsafe GetDimensionsCount(info_ptr, &mut num_dims).expect("infallible")]; - let mut node_dims: Vec = vec![0; num_dims]; + let mut node_dims = Shape::empty(num_dims); ortsys![unsafe GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims).expect("infallible")]; - let mut symbolic_dims: Vec<*const c_char> = vec![ptr::null(); num_dims]; + let mut symbolic_dims: SmallVec<_, 4> = smallvec![ptr::null(); num_dims]; ortsys![unsafe GetSymbolicDimensions(info_ptr, symbolic_dims.as_mut_ptr(), num_dims).expect("infallible")]; let dimension_symbols = symbolic_dims @@ -261,7 +270,7 @@ pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys ValueType::Tensor { ty: type_sys.into(), - dimensions: node_dims, + shape: node_dims, dimension_symbols } } @@ -288,14 +297,17 @@ unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeIn #[cfg(test)] mod tests { use super::ValueType; - use crate::{ortsys, tensor::TensorElementType}; + use crate::{ + ortsys, + tensor::{Shape, SymbolicDimensions, TensorElementType} + }; #[test] fn test_to_from_tensor_info() -> crate::Result<()> { let ty = ValueType::Tensor { ty: TensorElementType::Float32, - dimensions: vec![-1, 32, 4, 32], - dimension_symbols: vec!["d1".to_string(), String::default(), String::default(), String::default()] + shape: Shape::new([-1, 32, 4, 32]), + dimension_symbols: SymbolicDimensions::new(["d1".to_string(), String::default(), String::default(), String::default()]) }; let ty_ptr = ty.to_tensor_type_info().expect(""); let ty_d = unsafe { super::extract_data_type_from_tensor_info(ty_ptr) }; diff --git a/tests/mnist.rs b/tests/mnist.rs index d2a55bf..c97a266 100644 --- a/tests/mnist.rs +++ b/tests/mnist.rs @@ -26,11 +26,11 @@ fn mnist_5() -> ort::Result<()> { assert_eq!(metadata.name()?, "CNTKGraph"); assert_eq!(metadata.producer()?, "CNTK"); - let input0_shape: &Vec = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"); - let output0_shape: &Vec = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"); + let input0_shape = session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type"); + let output0_shape = session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type"); - assert_eq!(input0_shape, &[1, 1, 28, 28]); - assert_eq!(output0_shape, &[1, 10]); + assert_eq!(&**input0_shape, &[1, 1, 28, 28]); + assert_eq!(&**output0_shape, &[1, 10]); input0_shape }; diff --git a/tests/squeezenet.rs b/tests/squeezenet.rs index 3150715..7e90bc0 100644 --- a/tests/squeezenet.rs +++ b/tests/squeezenet.rs @@ -33,11 +33,11 @@ fn squeezenet_mushroom() -> ort::Result<()> { assert_eq!(metadata.name()?, "main_graph"); assert_eq!(metadata.producer()?, "pytorch"); - let input0_shape: &Vec = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"); - let output0_shape: &Vec = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"); + let input0_shape = session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type"); + let output0_shape = session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type"); - assert_eq!(input0_shape, &[1, 3, 224, 224]); - assert_eq!(output0_shape, &[1, 1000]); + assert_eq!(&**input0_shape, [1, 3, 224, 224]); + assert_eq!(&**output0_shape, [1, 1000]); input0_shape }; diff --git a/tests/upsample.rs b/tests/upsample.rs index f81a9c1..b656163 100644 --- a/tests/upsample.rs +++ b/tests/upsample.rs @@ -63,8 +63,8 @@ fn upsample() -> ort::Result<()> { assert_eq!(metadata.name()?, "tf2onnx"); assert_eq!(metadata.producer()?, "tf2onnx"); - assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]); - assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]); + assert_eq!(&**session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type"), [-1, -1, -1, 3]); + assert_eq!(&**session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type"), [-1, -1, -1, 3]); } // Load image, converting to RGB format @@ -101,8 +101,8 @@ fn upsample_with_ort_model() -> ort::Result<()> { .commit_from_memory_directly(&session_data) // Zero-copy. .expect("Could not read model from memory"); - assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]); - assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]); + assert_eq!(&**session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type"), [-1, -1, -1, 3]); + assert_eq!(&**session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type"), [-1, -1, -1, 3]); // Load image, converting to RGB format let image_buffer = load_input_image(IMAGE_TO_LOAD);