refactor!: opaque tensor shape, rename dimensions to shape

This commit is contained in:
Carson M.
2025-03-12 22:28:18 -05:00
parent ee0e0c3c26
commit 909e41f26a
13 changed files with 309 additions and 217 deletions

View File

@@ -49,7 +49,7 @@ fn main() -> ort::Result<()> {
stdout.flush().unwrap();
for _ in 0..GEN_TOKENS {
// Raw tensor construction takes a tuple of (dimensions, data).
// Raw tensor construction takes a tuple of (shape, data).
// The model expects our input to have shape [B, _, S]
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
let outputs = session.run(inputs![input])?;

View File

@@ -38,6 +38,7 @@ pub mod tensor;
#[cfg(feature = "training")]
#[cfg_attr(docsrs, doc(cfg(feature = "training")))]
pub mod training;
#[doc(hidden)]
pub mod util;
pub mod value;

View File

@@ -6,14 +6,13 @@ use core::{
slice
};
use smallvec::SmallVec;
use crate::{
AsPointer,
error::{Error, Result},
memory::{Allocator, MemoryInfo, MemoryType},
ortsys,
session::{Input, Output},
tensor::Shape,
util::with_cstr,
value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType}
};
@@ -350,9 +349,9 @@ impl KernelContext {
Ok(NonNull::new(value_ptr.cast_mut()).map(|c| ValueRef::new(unsafe { Value::from_ptr_nodrop(c, None) })))
}
pub fn output(&self, idx: usize, shape: impl IntoIterator<Item = i64>) -> Result<Option<ValueRefMut<'_>>> {
pub fn output(&self, idx: usize, shape: impl Into<Shape>) -> Result<Option<ValueRefMut<'_>>> {
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
let shape = shape.into_iter().collect::<SmallVec<i64, 4>>();
let shape = shape.into();
ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx, shape.as_ptr(), shape.len(), &mut value_ptr)?];
Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) })))
}

View File

@@ -1,9 +1,126 @@
//! Traits related to [`Tensor`](crate::value::Tensor)s.
//! Traits and types related to [`Tensor`](crate::value::Tensor)s.
#[cfg(feature = "ndarray")]
mod ndarray;
mod types;
use alloc::{string::String, vec::Vec};
use core::{
fmt,
ops::{Deref, DerefMut}
};
use smallvec::{SmallVec, smallvec};
#[cfg(feature = "ndarray")]
pub use self::ndarray::ArrayExtensions;
pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data};
/// The shape of a [`Tensor`](crate::value::Tensor) — the extent of each of its dimensions.
///
/// Stored inline for ranks up to 4 (via `SmallVec<i64, 4>`), so typical tensor shapes do not
/// heap-allocate. Dereferences to `[i64]`, so all slice methods are available.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct Shape {
	// Dimension extents; a negative value marks a symbolic/unknown dimension
	// (see `num_elements`, which returns 0 in that case).
	inner: SmallVec<i64, 4>
}
impl Shape {
pub fn new(dims: impl IntoIterator<Item = i64>) -> Self {
Self { inner: dims.into_iter().collect() }
}
pub fn empty(rank: usize) -> Self {
Self { inner: smallvec![0; rank] }
}
#[doc(alias = "numel")]
pub fn num_elements(&self) -> usize {
let mut size = 1usize;
for dim in &self.inner {
if *dim < 0 {
return 0;
}
size *= *dim as usize;
}
size
}
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub fn to_ixdyn(&self) -> ::ndarray::IxDyn {
use ::ndarray::IntoDimension;
self.inner.iter().map(|d| *d as usize).collect::<Vec<usize>>().into_dimension()
}
}
impl fmt::Debug for Shape {
	/// Formats the shape as a bracketed list of dimensions, e.g. `[1, 3, 224, 224]`.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		// Slices format via the same list machinery, so this output is
		// identical to `f.debug_list().entries(...)`.
		fmt::Debug::fmt(self.inner.as_slice(), f)
	}
}
impl fmt::Display for Shape {
	/// Formats the shape exactly like its [`fmt::Debug`] representation — a bracketed list
	/// of dimensions.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		fmt::Debug::fmt(self, f)
	}
}
impl From<Vec<i64>> for Shape {
fn from(value: Vec<i64>) -> Self {
Self { inner: SmallVec::from(value) }
}
}
impl From<&[i64]> for Shape {
fn from(value: &[i64]) -> Self {
Self { inner: SmallVec::from(value) }
}
}
impl<const N: usize> From<[i64; N]> for Shape {
fn from(value: [i64; N]) -> Self {
Self { inner: SmallVec::from(value) }
}
}
impl FromIterator<i64> for Shape {
fn from_iter<T: IntoIterator<Item = i64>>(iter: T) -> Self {
Self { inner: iter.into_iter().collect() }
}
}
impl Deref for Shape {
	type Target = [i64];

	/// Exposes the dimension extents as a slice, making all `[i64]` methods available on
	/// [`Shape`] (e.g. `len`, `iter`, `as_ptr`).
	fn deref(&self) -> &Self::Target {
		self.inner.as_slice()
	}
}
impl DerefMut for Shape {
	/// Exposes the dimension extents as a mutable slice. Extents can be modified in place,
	/// but the rank (slice length) cannot change through this view.
	fn deref_mut(&mut self) -> &mut Self::Target {
		self.inner.as_mut_slice()
	}
}
/// The symbolic dimension names of a tensor, one entry per dimension.
///
/// A dimension without a symbolic name is represented by an empty string
/// (see [`SymbolicDimensions::empty`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SymbolicDimensions(SmallVec<String, 4>);
impl SymbolicDimensions {
pub fn new(dims: impl IntoIterator<Item = String>) -> Self {
Self(dims.into_iter().collect())
}
pub fn empty(rank: usize) -> Self {
Self(smallvec![String::default(); rank])
}
}
impl FromIterator<String> for SymbolicDimensions {
fn from_iter<T: IntoIterator<Item = String>>(iter: T) -> Self {
Self(iter.into_iter().collect())
}
}
impl Deref for SymbolicDimensions {
	type Target = [String];

	/// Exposes the dimension names as a slice of `String`s.
	fn deref(&self) -> &Self::Target {
		self.0.as_slice()
	}
}

View File

@@ -303,17 +303,6 @@ impl<T> Drop for OnceLock<T> {
#[doc(hidden)]
pub fn cold() {}
/// Returns the total number of elements implied by `shape` — the product of all dimensions.
///
/// A negative dimension denotes a symbolic/unknown extent, in which case `0` is returned
/// since no concrete element count exists.
// NOTE(review): this commit removes this helper in favor of `Shape::num_elements`,
// which has identical semantics.
pub fn element_count(shape: &[i64]) -> usize {
	let mut size = 1usize;
	for dim in shape {
		if *dim < 0 {
			return 0;
		}
		size *= *dim as usize;
	}
	size
}
#[inline]
pub(crate) fn with_cstr<T>(bytes: &[u8], f: &dyn Fn(&CStr) -> Result<T>) -> Result<T> {
fn run_with_heap_cstr<T>(bytes: &[u8], f: &dyn Fn(&CStr) -> Result<T>) -> Result<T> {

View File

@@ -20,8 +20,7 @@ use crate::{
error::{Error, Result},
memory::Allocator,
ortsys,
tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType},
util::element_count
tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType}
};
pub trait MapValueTypeMarker: ValueTypeMarker {
@@ -109,7 +108,7 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
if K::into_tensor_element_type() != TensorElementType::String {
let dtype = key_value.dtype();
let (key_tensor_shape, key_tensor) = match dtype {
ValueType::Tensor { ty, dimensions, .. } => {
ValueType::Tensor { ty, shape, .. } => {
let mem = key_value.memory_info();
if !mem.is_cpu_accessible() {
return Err(Error::new(format!(
@@ -124,8 +123,7 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
let len = element_count(dimensions);
(dimensions, unsafe { slice::from_raw_parts(output_array_ptr, len) })
(shape, unsafe { slice::from_raw_parts(output_array_ptr, shape.num_elements()) })
} else {
return Err(Error::new_with_code(
ErrorCode::InvalidArgument,

View File

@@ -1,4 +1,4 @@
use alloc::{boxed::Box, ffi::CString, format, string::String, sync::Arc, vec, vec::Vec};
use alloc::{boxed::Box, ffi::CString, format, string::String, sync::Arc, vec::Vec};
use core::{
any::Any,
ffi::c_void,
@@ -17,8 +17,7 @@ use crate::{
error::{Error, ErrorCode, Result},
memory::{Allocator, MemoryInfo},
ortsys,
tensor::{PrimitiveTensorElementType, TensorElementType, Utf8Data},
util::element_count,
tensor::{PrimitiveTensorElementType, Shape, SymbolicDimensions, TensorElementType, Utf8Data},
value::{Value, ValueInner, ValueType}
};
@@ -78,8 +77,8 @@ impl Tensor<String> {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
dtype: ValueType::Tensor {
ty: TensorElementType::String,
dimensions: shape,
dimension_symbols: vec![String::default(); shape_len]
shape,
dimension_symbols: SymbolicDimensions::empty(shape_len)
},
memory_info: MemoryInfo::from_value(value_ptr),
drop: true,
@@ -109,7 +108,7 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
/// # Ok(())
/// # }
/// ```
pub fn new(allocator: &Allocator, shape: impl ToDimensions) -> Result<Tensor<T>> {
pub fn new(allocator: &Allocator, shape: impl Into<Shape>) -> Result<Tensor<T>> {
let tensor = DynTensor::new(allocator, T::into_tensor_element_type(), shape)?;
Ok(unsafe { tensor.transmute_type() })
}
@@ -141,17 +140,16 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
///
/// Creating string tensors requires a separate method; see [`Tensor::from_string_array`].
pub fn from_array(input: impl OwnedTensorArrayData<T>) -> Result<Tensor<T>> {
let TensorArrayDataParts { shape, ptr, num_elements, guard } = input.into_parts()?;
tensor_from_array(MemoryInfo::default(), shape, ptr.as_ptr().cast(), num_elements, size_of::<T>(), T::into_tensor_element_type(), guard)
let TensorArrayDataParts { shape, ptr, guard } = input.into_parts()?;
tensor_from_array(MemoryInfo::default(), shape, ptr.as_ptr().cast(), size_of::<T>(), T::into_tensor_element_type(), guard)
.map(|tensor| unsafe { tensor.transmute_type() })
}
}
fn tensor_from_array(
memory_info: MemoryInfo,
shape: Vec<i64>,
shape: Shape,
data: *mut c_void,
num_elements: usize,
element_size: usize,
element_type: TensorElementType,
guard: Option<Box<dyn Any>>
@@ -162,7 +160,7 @@ fn tensor_from_array(
unsafe CreateTensorWithDataAsOrtValue(
memory_info.ptr(),
data,
num_elements * element_size,
shape.num_elements() * element_size,
shape.as_ptr(),
shape.len(),
element_type.into(),
@@ -176,8 +174,8 @@ fn tensor_from_array(
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
dtype: ValueType::Tensor {
ty: element_type,
dimension_symbols: vec![String::default(); shape.len()],
dimensions: shape
dimension_symbols: SymbolicDimensions::empty(shape.len()),
shape
},
drop: true,
memory_info: Some(memory_info),
@@ -219,15 +217,11 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> {
/// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout.
pub fn from_array_view(input: impl TensorArrayData<T> + 'a) -> Result<TensorRef<'a, T>> {
let (shape, data, guard) = input.ref_parts()?;
let num_elements = element_count(&shape);
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, num_elements, size_of::<T>(), T::into_tensor_element_type(), guard).map(
|tensor| {
let mut tensor: TensorRef<'_, T> = TensorRef::new(unsafe { tensor.transmute_type() });
tensor.upgradable = false;
tensor
}
)
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::<T>(), T::into_tensor_element_type(), guard).map(|tensor| {
let mut tensor: TensorRef<'_, T> = TensorRef::new(unsafe { tensor.transmute_type() });
tensor.upgradable = false;
tensor
})
}
}
@@ -262,21 +256,17 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
/// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout.
pub fn from_array_view_mut(mut input: impl TensorArrayDataMut<T>) -> Result<TensorRefMut<'a, T>> {
let (shape, data, guard) = input.ref_parts_mut()?;
let num_elements = element_count(&shape);
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, num_elements, size_of::<T>(), T::into_tensor_element_type(), guard).map(
|tensor| {
let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() });
tensor.upgradable = false;
tensor
}
)
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::<T>(), T::into_tensor_element_type(), guard).map(|tensor| {
let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() });
tensor.upgradable = false;
tensor
})
}
/// Create a mutable tensor view from a raw pointer and shape.
///
/// The length of data is determined by `T` and the given shape, so the given buffer must be at least
/// `shape.iter().product() * size_of::<T>()` bytes.
/// `shape.num_elements() * size_of::<T>()` bytes.
///
/// This function can be used to create data from raw device memory, e.g. to directly provide data to an execution
/// provider. For instance, to create a tensor from a raw CUDA buffer using [`cudarc`](https://docs.rs/cudarc):
@@ -288,7 +278,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
/// TensorRefMut::from_raw(
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?,
/// (*device_data.device_ptr() as usize as *mut ()).cast(),
/// vec![1, 3, 512, 512]
/// Shape::new([1, 3, 512, 512])
/// )?
/// };
/// ```
@@ -296,9 +286,8 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
/// # Safety
/// - The pointer must be valid for the device description provided by `MemoryInfo`.
/// - The returned tensor must outlive the data described by the data pointer.
pub unsafe fn from_raw(info: MemoryInfo, data: *mut ort_sys::c_void, shape: Vec<i64>) -> Result<TensorRefMut<'a, T>> {
let num_elements = element_count(&shape);
tensor_from_array(info, shape, data, num_elements, size_of::<T>(), T::into_tensor_element_type(), None).map(|tensor| {
pub unsafe fn from_raw(info: MemoryInfo, data: *mut ort_sys::c_void, shape: Shape) -> Result<TensorRefMut<'a, T>> {
tensor_from_array(info, shape, data, size_of::<T>(), T::into_tensor_element_type(), None).map(|tensor| {
let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() });
tensor.upgradable = false;
tensor
@@ -308,14 +297,14 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
pub trait TensorArrayData<I> {
#[allow(clippy::type_complexity)]
fn ref_parts(&self) -> Result<(Vec<i64>, &[I], Option<Box<dyn Any>>)>;
fn ref_parts(&self) -> Result<(Shape, &[I], Option<Box<dyn Any>>)>;
private_trait!();
}
pub trait TensorArrayDataMut<I>: TensorArrayData<I> {
#[allow(clippy::type_complexity)]
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [I], Option<Box<dyn Any>>)>;
fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [I], Option<Box<dyn Any>>)>;
private_trait!();
}
@@ -327,20 +316,19 @@ pub trait OwnedTensorArrayData<I> {
}
pub struct TensorArrayDataParts<I> {
pub shape: Vec<i64>,
pub shape: Shape,
pub ptr: NonNull<I>,
pub num_elements: usize,
pub guard: Option<Box<dyn Any>>
}
pub trait ToDimensions {
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Vec<i64>>;
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Shape>;
}
macro_rules! impl_to_dimensions {
(@inner) => {
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Vec<i64>> {
let v: Vec<i64> = self
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Shape> {
let v = self
.iter()
.enumerate()
.map(|(i, c)| {
@@ -353,13 +341,17 @@ macro_rules! impl_to_dimensions {
))
}
})
.collect::<Result<_>>()?;
let sum = element_count(&v);
.collect::<Result<Shape>>()?;
if let Some(expected_size) = expected_size {
if sum != expected_size {
if v.num_elements() != expected_size {
Err(Error::new_with_code(
ErrorCode::InvalidArgument,
format!("Cannot create a tensor from raw data; shape {:?} ({} elements) is larger than the length of the data provided ({} elements)", v, sum, expected_size)
format!(
"Cannot create a tensor from raw data; shape {:?} ({} elements) is larger than the length of the data provided ({} elements)",
v,
v.num_elements(),
expected_size
)
))
} else {
Ok(v)
@@ -382,21 +374,21 @@ macro_rules! impl_to_dimensions {
}
impl ToDimensions for () {
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Vec<i64>> {
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Shape> {
match expected_size {
Some(1) | None => Ok(vec![]),
Some(1) | None => Ok(Shape::default()),
Some(_) => Err(Error::new_with_code(ErrorCode::InvalidArgument, "Expected data to have a length of exactly 1 for scalar shape"))
}
}
}
impl_to_dimensions!(for &[usize], for &[i32], for &[i64], for Vec<usize>, for Vec<i32>, for Vec<i64>);
impl_to_dimensions!(for Shape, for &[usize], for &[i32], for &[i64], for Vec<usize>, for Vec<i32>, for Vec<i64>);
impl_to_dimensions!(<N> for [usize; N], for [i32; N], for [i64; N]);
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &CowArray<'_, T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
@@ -409,8 +401,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &CowArra
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArcArray<T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
@@ -423,8 +415,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArcArray
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &Array<T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
@@ -437,8 +429,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &Array<T
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &mut Array<T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
@@ -455,27 +447,21 @@ impl<T: Clone + 'static, D: Dimension + 'static> OwnedTensorArrayData<T> for Arr
if self.is_standard_layout() {
// We can avoid the copy here and use the data as is
let mut this = Box::new(self);
let shape: Vec<i64> = this.shape().iter().map(|d| *d as i64).collect();
let shape: Shape = this.shape().iter().map(|d| *d as i64).collect();
// SAFETY: ndarrays internally store their pointer as NonNull
let ptr = unsafe { NonNull::new_unchecked(this.as_mut_ptr()) };
let num_elements = this.len();
Ok(TensorArrayDataParts {
shape,
ptr,
num_elements,
guard: Some(this)
})
assert_eq!(this.len(), shape.num_elements());
Ok(TensorArrayDataParts { shape, ptr, guard: Some(this) })
} else {
// Need to do a copy here to get data in to standard layout
let mut contiguous_array = self.as_standard_layout().into_owned();
let shape: Vec<i64> = contiguous_array.shape().iter().map(|d| *d as i64).collect();
let shape: Shape = contiguous_array.shape().iter().map(|d| *d as i64).collect();
// SAFETY: ndarrays internally store their pointer as NonNull
let ptr = unsafe { NonNull::new_unchecked(contiguous_array.as_mut_ptr()) };
let num_elements: usize = contiguous_array.len();
assert_eq!(contiguous_array.len(), shape.num_elements());
Ok(TensorArrayDataParts {
shape,
ptr,
num_elements,
guard: Some(Box::new(contiguous_array))
})
}
@@ -487,8 +473,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> OwnedTensorArrayData<T> for Arr
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayView<'_, T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
@@ -501,8 +487,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayVie
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayViewMut<'_, T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
@@ -515,8 +501,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayVie
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for ArrayViewMut<'_, T, D> {
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [T], Option<Box<dyn Any>>)> {
let shape = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice_mut()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
@@ -529,8 +515,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for Array
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for &mut Array<T, D> {
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [T], Option<Box<dyn Any>>)> {
let shape = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice_mut()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
@@ -541,7 +527,7 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for &mut
}
impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &[T]) {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
Ok((shape, self.1, None))
}
@@ -550,7 +536,7 @@ impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &[T]) {
}
impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &mut [T]) {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
Ok((shape, self.1, None))
}
@@ -559,7 +545,7 @@ impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &mut [T]) {
}
impl<T: Clone + 'static, D: ToDimensions> TensorArrayDataMut<T> for (D, &mut [T]) {
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T], Option<Box<dyn Any>>)> {
fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
Ok((shape, self.1, None))
}
@@ -572,11 +558,10 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Vec<T>
let shape = self.0.to_dimensions(Some(self.1.len()))?;
// SAFETY: A `Vec` always has a non-null pointer.
let ptr = unsafe { NonNull::new_unchecked(self.1.as_mut_ptr()) };
let num_elements: usize = self.1.len();
assert_eq!(shape.num_elements(), self.1.len());
Ok(TensorArrayDataParts {
shape,
ptr,
num_elements,
guard: Some(Box::new(self.1))
})
}
@@ -589,11 +574,10 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Box<[T
let shape = self.0.to_dimensions(Some(self.1.len()))?;
// SAFETY: A `Box` always has a non-null pointer.
let ptr = unsafe { NonNull::new_unchecked(self.1.as_mut_ptr()) };
let num_elements: usize = self.1.len();
assert_eq!(shape.num_elements(), self.1.len());
Ok(TensorArrayDataParts {
shape,
ptr,
num_elements,
guard: Some(Box::new(self.1))
})
}
@@ -602,7 +586,7 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Box<[T
}
impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<[T]>) {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
let data = &*self.1;
Ok((shape, data, Some(Box::new(self.1.clone()))))
@@ -612,7 +596,7 @@ impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<[T]>) {
}
impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<Box<[T]>>) {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
let data = &*self.1;
Ok((shape, data, Some(Box::new(self.1.clone()))))

View File

@@ -12,8 +12,7 @@ use crate::{
error::{Error, ErrorCode, Result},
memory::MemoryInfo,
ortsys,
tensor::{PrimitiveTensorElementType, TensorElementType},
util::element_count,
tensor::{PrimitiveTensorElementType, Shape, TensorElementType},
value::{Value, ValueType}
};
@@ -49,11 +48,8 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub fn try_extract_tensor<T: PrimitiveTensorElementType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
use ndarray::IntoDimension;
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| {
let shape = dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>().into_dimension();
Ok(unsafe { ndarray::ArrayView::from_shape_ptr(shape, data_ptr(ptr)?.cast::<T>()) })
})
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type())
.and_then(|(ptr, shape)| Ok(unsafe { ndarray::ArrayView::from_shape_ptr(shape.to_ixdyn(), data_ptr(ptr)?.cast::<T>()) }))
}
/// Attempt to extract the scalar from a tensor of type `T`.
@@ -80,11 +76,11 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
///
/// [`DynValue`]: crate::value::DynValue
pub fn try_extract_scalar<T: PrimitiveTensorElementType + Copy>(&self) -> Result<T> {
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| {
if !dimensions.is_empty() {
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, shape)| {
if !shape.is_empty() {
return Err(Error::new_with_code(
ErrorCode::InvalidArgument,
format!("Cannot extract scalar {} from a tensor of dimensionality {}", T::into_tensor_element_type(), dimensions.len())
format!("Cannot extract scalar {} from a tensor of dimensionality {}", T::into_tensor_element_type(), shape.len())
));
}
@@ -122,14 +118,11 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub fn try_extract_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
use ndarray::IntoDimension;
extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| {
let shape = dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>().into_dimension();
Ok(unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape, data_ptr(ptr)?.cast::<T>()) })
})
extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type())
.and_then(|(ptr, shape)| Ok(unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape.to_ixdyn(), data_ptr(ptr)?.cast::<T>()) }))
}
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's shape and an
/// immutable view into its data.
///
/// See also:
@@ -145,7 +138,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
///
/// let (extracted_shape, extracted_data) = value.try_extract_raw_tensor::<i64>()?;
/// assert_eq!(extracted_data, &array);
/// assert_eq!(extracted_shape, [5]);
/// assert_eq!(**extracted_shape, [5]);
/// # Ok(())
/// # }
/// ```
@@ -157,12 +150,12 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// - The provided type `T` does not match the tensor's element type.
///
/// [`DynValue`]: crate::value::DynValue
pub fn try_extract_raw_tensor<T: PrimitiveTensorElementType>(&self) -> Result<(&[i64], &[T])> {
pub fn try_extract_raw_tensor<T: PrimitiveTensorElementType>(&self) -> Result<(&Shape, &[T])> {
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type())
.and_then(|(ptr, dimensions)| Ok((dimensions, unsafe { slice::from_raw_parts(data_ptr(ptr)?.cast::<T>(), element_count(dimensions)) })))
.and_then(|(ptr, shape)| Ok((shape, unsafe { slice::from_raw_parts(data_ptr(ptr)?.cast::<T>(), shape.num_elements()) })))
}
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's shape and a
/// mutable view into its data.
///
/// See also the infallible counterpart, [`Tensor::extract_raw_tensor_mut`], for typed [`Tensor<T>`]s.
@@ -175,7 +168,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
///
/// let (extracted_shape, extracted_data) = value.try_extract_raw_tensor_mut::<i64>()?;
/// assert_eq!(extracted_data, &array);
/// assert_eq!(extracted_shape, [5]);
/// assert_eq!(**extracted_shape, [5]);
/// # Ok(())
/// # }
/// ```
@@ -187,9 +180,9 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// - The provided type `T` does not match the tensor's element type.
///
/// [`DynValue`]: crate::value::DynValue
pub fn try_extract_raw_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<(&[i64], &mut [T])> {
pub fn try_extract_raw_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<(&Shape, &mut [T])> {
extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type())
.and_then(|(ptr, dimensions)| Ok((dimensions, unsafe { slice::from_raw_parts_mut(data_ptr(ptr)?.cast::<T>(), element_count(dimensions)) })))
.and_then(|(ptr, shape)| Ok((shape, unsafe { slice::from_raw_parts_mut(data_ptr(ptr)?.cast::<T>(), shape.num_elements()) })))
}
/// Attempt to extract the underlying data into a Rust `ndarray`.
@@ -208,15 +201,13 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub fn try_extract_string_tensor(&self) -> Result<ndarray::ArrayD<String>> {
use ndarray::IntoDimension;
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, dimensions)| {
let strings = extract_strings(ptr, dimensions)?;
Ok(ndarray::Array::from_shape_vec(dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>().into_dimension(), strings)
.expect("Shape extracted from tensor didn't match tensor contents"))
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, shape)| {
let strings = extract_strings(ptr, shape)?;
Ok(ndarray::Array::from_shape_vec(shape.to_ixdyn(), strings).expect("Shape extracted from tensor didn't match tensor contents"))
})
}
/// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's dimensions and
/// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's shape and
/// an owned `Vec` of its data.
///
/// ```
@@ -227,14 +218,14 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
///
/// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?;
/// assert_eq!(extracted_data, array);
/// assert_eq!(extracted_shape, [2]);
/// assert_eq!(**extracted_shape, [2]);
/// # Ok(())
/// # }
/// ```
pub fn try_extract_raw_string_tensor(&self) -> Result<(&[i64], Vec<String>)> {
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, dimensions)| {
let strings = extract_strings(ptr, dimensions)?;
Ok((dimensions, strings))
pub fn try_extract_raw_string_tensor(&self) -> Result<(&Shape, Vec<String>)> {
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, shape)| {
let strings = extract_strings(ptr, shape)?;
Ok((shape, strings))
})
}
@@ -246,13 +237,13 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// # let allocator = Allocator::default();
/// let tensor = Tensor::<f32>::new(&allocator, [1, 128, 128, 3])?;
///
/// assert_eq!(tensor.shape(), [1, 128, 128, 3]);
/// assert_eq!(**tensor.shape(), [1, 128, 128, 3]);
/// # Ok(())
/// # }
/// ```
pub fn shape(&self) -> &[i64] {
pub fn shape(&self) -> &Shape {
match self.dtype() {
ValueType::Tensor { dimensions, .. } => dimensions,
ValueType::Tensor { shape, .. } => shape,
_ => unreachable!()
}
}
@@ -263,9 +254,9 @@ fn extract_tensor<'t>(
dtype: &'t ValueType,
memory_info: &MemoryInfo,
expected_ty: TensorElementType
) -> Result<(*mut ort_sys::OrtValue, &'t [i64])> {
) -> Result<(*mut ort_sys::OrtValue, &'t Shape)> {
match dtype {
ValueType::Tensor { ty, dimensions, .. } => {
ValueType::Tensor { ty, shape, .. } => {
if !memory_info.is_cpu_accessible() {
return Err(Error::new(format!(
"Cannot extract from value on device `{}`, which is not CPU accessible",
@@ -274,7 +265,7 @@ fn extract_tensor<'t>(
}
if *ty == expected_ty {
Ok((ptr, dimensions))
Ok((ptr, shape))
} else {
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<{}> from Tensor<{}>", expected_ty, ty)))
}
@@ -289,9 +280,8 @@ unsafe fn data_ptr(ptr: *mut ort_sys::OrtValue) -> Result<*mut c_void> {
Ok(output_array_ptr)
}
fn extract_strings(ptr: *mut ort_sys::OrtValue, dimensions: &[i64]) -> Result<Vec<String>> {
let len = element_count(dimensions);
fn extract_strings(ptr: *mut ort_sys::OrtValue, shape: &Shape) -> Result<Vec<String>> {
let len = shape.num_elements();
// Total length of string data, not including \0 suffix
let mut total_length = 0;
ortsys![unsafe GetStringTensorDataLength(ptr, &mut total_length)?];
@@ -368,7 +358,7 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
self.try_extract_tensor_mut().expect("Failed to extract tensor")
}
/// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an immutable
/// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's shapes and an immutable
/// view into its data.
///
/// ```
@@ -379,15 +369,15 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
///
/// let (extracted_shape, extracted_data) = tensor.extract_raw_tensor();
/// assert_eq!(extracted_data, &array);
/// assert_eq!(extracted_shape, [5]);
/// assert_eq!(**extracted_shape, [5]);
/// # Ok(())
/// # }
/// ```
pub fn extract_raw_tensor(&self) -> (&[i64], &[T]) {
pub fn extract_raw_tensor(&self) -> (&Shape, &[T]) {
self.try_extract_raw_tensor().expect("Failed to extract tensor")
}
/// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a mutable view
/// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's shapes and a mutable view
/// into its data.
///
/// ```
@@ -403,7 +393,7 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
/// # Ok(())
/// # }
/// ```
pub fn extract_raw_tensor_mut(&mut self) -> (&[i64], &mut [T]) {
pub fn extract_raw_tensor_mut(&mut self) -> (&Shape, &mut [T]) {
self.try_extract_raw_tensor_mut().expect("Failed to extract tensor")
}
}

View File

@@ -1,7 +1,7 @@
mod create;
mod extract;
use alloc::{string::String, sync::Arc, vec};
use alloc::sync::Arc;
use core::{
fmt::{self, Debug},
marker::PhantomData,
@@ -16,8 +16,7 @@ use crate::{
error::Result,
memory::{Allocator, MemoryInfo},
ortsys,
tensor::{IntoTensorElementType, TensorElementType},
util::element_count
tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType}
};
pub trait TensorValueTypeMarker: ValueTypeMarker {
@@ -93,10 +92,13 @@ impl DynTensor {
/// # Ok(())
/// # }
/// ```
pub fn new(allocator: &Allocator, data_type: TensorElementType, shape: impl ToDimensions) -> Result<DynTensor> {
pub fn new(allocator: &Allocator, data_type: TensorElementType, shape: impl Into<Shape>) -> Result<DynTensor> {
Self::new_inner(allocator, data_type, shape.into())
}
fn new_inner(allocator: &Allocator, data_type: TensorElementType, shape: Shape) -> Result<DynTensor> {
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
let shape = shape.to_dimensions(None)?;
let shape_ptr: *const i64 = shape.as_ptr();
let shape_len = shape.len();
@@ -118,7 +120,7 @@ impl DynTensor {
let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];
unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(element_count(&shape))) };
unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(shape.num_elements())) };
}
Ok(Value {
@@ -126,8 +128,8 @@ impl DynTensor {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
dtype: ValueType::Tensor {
ty: data_type,
dimensions: shape,
dimension_symbols: vec![String::default(); shape_len]
shape,
dimension_symbols: SymbolicDimensions::empty(shape_len)
},
drop: true,
memory_info: MemoryInfo::from_value(value_ptr),
@@ -338,7 +340,7 @@ mod tests {
use super::Tensor;
use crate::{
memory::Allocator,
tensor::TensorElementType,
tensor::{Shape, SymbolicDimensions, TensorElementType},
value::{TensorRef, ValueType}
};
@@ -350,12 +352,12 @@ mod tests {
assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Float32));
assert_eq!(value.dtype(), &ValueType::Tensor {
ty: TensorElementType::Float32,
dimensions: vec![v.len() as i64],
dimension_symbols: vec![String::default()]
shape: Shape::new([v.len() as i64]),
dimension_symbols: SymbolicDimensions::empty(1)
});
let (shape, data) = value.extract_raw_tensor();
assert_eq!(shape, vec![v.len() as i64]);
assert_eq!(&**shape, [v.len() as i64]);
assert_eq!(data, &v);
Ok(())
@@ -411,7 +413,7 @@ mod tests {
let value = Tensor::from_string_array((vec![v.len() as i64], &*v))?;
let (extracted_shape, extracted_view) = value.try_extract_raw_string_tensor()?;
assert_eq!(extracted_shape, [v.len() as i64]);
assert_eq!(&**extracted_shape, [v.len() as i64]);
assert_eq!(extracted_view, v);
Ok(())
@@ -437,7 +439,7 @@ mod tests {
#[test]
fn test_tensor_index() -> crate::Result<()> {
let mut tensor = Tensor::new(&Allocator::default(), [1, 3, 224, 224])?;
let mut tensor = Tensor::new(&Allocator::default(), Shape::new([1, 3, 224, 224]))?;
tensor[[0, 2, 42, 42]] = 1.0;
assert_eq!(tensor[[0, 2, 42, 42]], 1.0);

View File

@@ -1,39 +1,45 @@
use alloc::{
boxed::Box,
string::{String, ToString},
vec,
vec::Vec
};
use core::{
ffi::{CStr, c_char},
fmt, ptr
string::{String, ToString}
};
use core::{ffi::CStr, fmt, ptr};
use crate::{ortsys, tensor::TensorElementType, util::with_cstr_ptr_array};
use smallvec::{SmallVec, smallvec};
use crate::{
ortsys,
tensor::{Shape, SymbolicDimensions, TensorElementType},
util::with_cstr_ptr_array
};
/// The type of a [`Value`][super::Value], or a session input/output.
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::Session, value::{ValueType, Tensor}, tensor::TensorElementType};
/// # use ort::{session::Session, tensor::{Shape, SymbolicDimensions}, value::{ValueType, Tensor}, tensor::TensorElementType};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// // `ValueType`s can be obtained from session inputs/outputs:
/// let input = &session.inputs[0];
/// assert_eq!(input.input_type, ValueType::Tensor {
/// ty: TensorElementType::Float32,
/// // Our model has 3 dynamic dimensions, represented by -1
/// dimensions: vec![-1, -1, -1, 3],
/// // Our model's input has 3 dynamic dimensions, represented by -1
/// shape: Shape::new([-1, -1, -1, 3]),
/// // Dynamic dimensions may also have names.
/// dimension_symbols: vec!["unk__31".to_string(), "unk__32".to_string(), "unk__33".to_string(), String::default()]
/// dimension_symbols: SymbolicDimensions::new([
/// "unk__31".to_string(),
/// "unk__32".to_string(),
/// "unk__33".to_string(),
/// String::default()
/// ])
/// });
///
/// // ...or by `Value`s created in Rust or output by a session.
/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?;
/// assert_eq!(value.dtype(), &ValueType::Tensor {
/// ty: TensorElementType::Int64,
/// dimensions: vec![5],
/// dimension_symbols: vec![String::default()]
/// shape: Shape::new([5]),
/// dimension_symbols: SymbolicDimensions::new([String::default()])
/// });
/// # Ok(())
/// # }
@@ -44,15 +50,16 @@ pub enum ValueType {
Tensor {
/// Element type of the tensor.
ty: TensorElementType,
/// Dimensions of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an
/// Shape of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an
/// [`Input`]/[`Output`]), the dimension will be `-1`.
///
/// Actual tensor values, which have a known dimension, will always have positive (>1) dimensions.
/// Actual tensor values (i.e. not [`Input`] or [`Output`] definitions), which have a known dimension, will
/// always have non-negative dimensions.
///
/// [`Input`]: crate::session::Input
/// [`Output`]: crate::session::Output
dimensions: Vec<i64>,
dimension_symbols: Vec<String>
shape: Shape,
dimension_symbols: SymbolicDimensions
},
/// A sequence (vector) of other `Value`s.
///
@@ -136,11 +143,11 @@ impl ValueType {
pub(crate) fn to_tensor_type_info(&self) -> Option<*mut ort_sys::OrtTensorTypeAndShapeInfo> {
match self {
Self::Tensor { ty, dimensions, dimension_symbols } => {
Self::Tensor { ty, shape, dimension_symbols } => {
let mut info_ptr = ptr::null_mut();
ortsys![unsafe CreateTensorTypeAndShapeInfo(&mut info_ptr).expect("infallible")];
ortsys![unsafe SetTensorElementType(info_ptr, (*ty).into()).expect("infallible")];
ortsys![unsafe SetDimensions(info_ptr, dimensions.as_ptr(), dimensions.len()).expect("infallible")];
ortsys![unsafe SetDimensions(info_ptr, shape.as_ptr(), shape.len()).expect("infallible")];
with_cstr_ptr_array(dimension_symbols, &|ptrs| {
ortsys![unsafe SetSymbolicDimensions(info_ptr, ptrs.as_ptr().cast_mut(), dimension_symbols.len()).expect("infallible")];
Ok(())
@@ -152,20 +159,22 @@ impl ValueType {
}
}
/// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map.
/// Returns the shape of this value type if it is a tensor, or `None` if it is a sequence or map.
///
/// ```
/// # use ort::value::Tensor;
/// # use ort::value::{Tensor, DynValue};
/// # fn main() -> ort::Result<()> {
/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?;
/// assert_eq!(value.dtype().tensor_dimensions(), Some(&vec![5]));
/// let value: DynValue = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?.into_dyn();
///
/// let shape = value.dtype().tensor_shape().unwrap();
/// assert_eq!(**shape, [5]);
/// # Ok(())
/// # }
/// ```
#[must_use]
pub fn tensor_dimensions(&self) -> Option<&Vec<i64>> {
pub fn tensor_shape(&self) -> Option<&Shape> {
match self {
ValueType::Tensor { dimensions, .. } => Some(dimensions),
ValueType::Tensor { shape, .. } => Some(shape),
_ => None
}
}
@@ -213,9 +222,9 @@ impl ValueType {
impl fmt::Display for ValueType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ValueType::Tensor { ty, dimensions, dimension_symbols } => {
ValueType::Tensor { ty, shape, dimension_symbols } => {
write!(f, "Tensor<{ty}>(")?;
for (i, dimension) in dimensions.iter().copied().enumerate() {
for (i, dimension) in shape.iter().copied().enumerate() {
if dimension == -1 {
let sym = &dimension_symbols[i];
if sym.is_empty() {
@@ -226,7 +235,7 @@ impl fmt::Display for ValueType {
} else {
dimension.fmt(f)?;
}
if i != dimensions.len() - 1 {
if i != shape.len() - 1 {
f.write_str(", ")?;
}
}
@@ -248,10 +257,10 @@ pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys
let mut num_dims = 0;
ortsys![unsafe GetDimensionsCount(info_ptr, &mut num_dims).expect("infallible")];
let mut node_dims: Vec<i64> = vec![0; num_dims];
let mut node_dims = Shape::empty(num_dims);
ortsys![unsafe GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims).expect("infallible")];
let mut symbolic_dims: Vec<*const c_char> = vec![ptr::null(); num_dims];
let mut symbolic_dims: SmallVec<_, 4> = smallvec![ptr::null(); num_dims];
ortsys![unsafe GetSymbolicDimensions(info_ptr, symbolic_dims.as_mut_ptr(), num_dims).expect("infallible")];
let dimension_symbols = symbolic_dims
@@ -261,7 +270,7 @@ pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys
ValueType::Tensor {
ty: type_sys.into(),
dimensions: node_dims,
shape: node_dims,
dimension_symbols
}
}
@@ -288,14 +297,17 @@ unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeIn
#[cfg(test)]
mod tests {
use super::ValueType;
use crate::{ortsys, tensor::TensorElementType};
use crate::{
ortsys,
tensor::{Shape, SymbolicDimensions, TensorElementType}
};
#[test]
fn test_to_from_tensor_info() -> crate::Result<()> {
let ty = ValueType::Tensor {
ty: TensorElementType::Float32,
dimensions: vec![-1, 32, 4, 32],
dimension_symbols: vec!["d1".to_string(), String::default(), String::default(), String::default()]
shape: Shape::new([-1, 32, 4, 32]),
dimension_symbols: SymbolicDimensions::new(["d1".to_string(), String::default(), String::default(), String::default()])
};
let ty_ptr = ty.to_tensor_type_info().expect("");
let ty_d = unsafe { super::extract_data_type_from_tensor_info(ty_ptr) };

View File

@@ -26,11 +26,11 @@ fn mnist_5() -> ort::Result<()> {
assert_eq!(metadata.name()?, "CNTKGraph");
assert_eq!(metadata.producer()?, "CNTK");
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
let input0_shape = session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type");
let output0_shape = session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type");
assert_eq!(input0_shape, &[1, 1, 28, 28]);
assert_eq!(output0_shape, &[1, 10]);
assert_eq!(&**input0_shape, &[1, 1, 28, 28]);
assert_eq!(&**output0_shape, &[1, 10]);
input0_shape
};

View File

@@ -33,11 +33,11 @@ fn squeezenet_mushroom() -> ort::Result<()> {
assert_eq!(metadata.name()?, "main_graph");
assert_eq!(metadata.producer()?, "pytorch");
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
let input0_shape = session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type");
let output0_shape = session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type");
assert_eq!(input0_shape, &[1, 3, 224, 224]);
assert_eq!(output0_shape, &[1, 1000]);
assert_eq!(&**input0_shape, [1, 3, 224, 224]);
assert_eq!(&**output0_shape, [1, 1000]);
input0_shape
};

View File

@@ -63,8 +63,8 @@ fn upsample() -> ort::Result<()> {
assert_eq!(metadata.name()?, "tf2onnx");
assert_eq!(metadata.producer()?, "tf2onnx");
assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
assert_eq!(&**session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type"), [-1, -1, -1, 3]);
assert_eq!(&**session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type"), [-1, -1, -1, 3]);
}
// Load image, converting to RGB format
@@ -101,8 +101,8 @@ fn upsample_with_ort_model() -> ort::Result<()> {
.commit_from_memory_directly(&session_data) // Zero-copy.
.expect("Could not read model from memory");
assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
assert_eq!(&**session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type"), [-1, -1, -1, 3]);
assert_eq!(&**session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type"), [-1, -1, -1, 3]);
// Load image, converting to RGB format
let image_buffer = load_input_image(IMAGE_TO_LOAD);