mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor!: opaque tensor shape, rename dimensions to shape
This commit is contained in:
@@ -49,7 +49,7 @@ fn main() -> ort::Result<()> {
|
||||
stdout.flush().unwrap();
|
||||
|
||||
for _ in 0..GEN_TOKENS {
|
||||
// Raw tensor construction takes a tuple of (dimensions, data).
|
||||
// Raw tensor construction takes a tuple of (shape, data).
|
||||
// The model expects our input to have shape [B, _, S]
|
||||
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
|
||||
let outputs = session.run(inputs![input])?;
|
||||
|
||||
@@ -38,6 +38,7 @@ pub mod tensor;
|
||||
#[cfg(feature = "training")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "training")))]
|
||||
pub mod training;
|
||||
#[doc(hidden)]
|
||||
pub mod util;
|
||||
pub mod value;
|
||||
|
||||
|
||||
@@ -6,14 +6,13 @@ use core::{
|
||||
slice
|
||||
};
|
||||
|
||||
use smallvec::SmallVec;
|
||||
|
||||
use crate::{
|
||||
AsPointer,
|
||||
error::{Error, Result},
|
||||
memory::{Allocator, MemoryInfo, MemoryType},
|
||||
ortsys,
|
||||
session::{Input, Output},
|
||||
tensor::Shape,
|
||||
util::with_cstr,
|
||||
value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType}
|
||||
};
|
||||
@@ -350,9 +349,9 @@ impl KernelContext {
|
||||
Ok(NonNull::new(value_ptr.cast_mut()).map(|c| ValueRef::new(unsafe { Value::from_ptr_nodrop(c, None) })))
|
||||
}
|
||||
|
||||
pub fn output(&self, idx: usize, shape: impl IntoIterator<Item = i64>) -> Result<Option<ValueRefMut<'_>>> {
|
||||
pub fn output(&self, idx: usize, shape: impl Into<Shape>) -> Result<Option<ValueRefMut<'_>>> {
|
||||
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
|
||||
let shape = shape.into_iter().collect::<SmallVec<i64, 4>>();
|
||||
let shape = shape.into();
|
||||
ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx, shape.as_ptr(), shape.len(), &mut value_ptr)?];
|
||||
Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) })))
|
||||
}
|
||||
|
||||
@@ -1,9 +1,126 @@
|
||||
//! Traits related to [`Tensor`](crate::value::Tensor)s.
|
||||
//! Traits and types related to [`Tensor`](crate::value::Tensor)s.
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
mod ndarray;
|
||||
mod types;
|
||||
|
||||
use alloc::{string::String, vec::Vec};
|
||||
use core::{
|
||||
fmt,
|
||||
ops::{Deref, DerefMut}
|
||||
};
|
||||
|
||||
use smallvec::{SmallVec, smallvec};
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub use self::ndarray::ArrayExtensions;
|
||||
pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data};
|
||||
|
||||
#[derive(Default, Clone, PartialEq, Eq)]
|
||||
pub struct Shape {
|
||||
inner: SmallVec<i64, 4>
|
||||
}
|
||||
|
||||
impl Shape {
|
||||
pub fn new(dims: impl IntoIterator<Item = i64>) -> Self {
|
||||
Self { inner: dims.into_iter().collect() }
|
||||
}
|
||||
|
||||
pub fn empty(rank: usize) -> Self {
|
||||
Self { inner: smallvec![0; rank] }
|
||||
}
|
||||
|
||||
#[doc(alias = "numel")]
|
||||
pub fn num_elements(&self) -> usize {
|
||||
let mut size = 1usize;
|
||||
for dim in &self.inner {
|
||||
if *dim < 0 {
|
||||
return 0;
|
||||
}
|
||||
size *= *dim as usize;
|
||||
}
|
||||
size
|
||||
}
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
pub fn to_ixdyn(&self) -> ::ndarray::IxDyn {
|
||||
use ::ndarray::IntoDimension;
|
||||
self.inner.iter().map(|d| *d as usize).collect::<Vec<usize>>().into_dimension()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Shape {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_list().entries(self.inner.iter()).finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Shape {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_list().entries(self.inner.iter()).finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<i64>> for Shape {
|
||||
fn from(value: Vec<i64>) -> Self {
|
||||
Self { inner: SmallVec::from(value) }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&[i64]> for Shape {
|
||||
fn from(value: &[i64]) -> Self {
|
||||
Self { inner: SmallVec::from(value) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> From<[i64; N]> for Shape {
|
||||
fn from(value: [i64; N]) -> Self {
|
||||
Self { inner: SmallVec::from(value) }
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<i64> for Shape {
|
||||
fn from_iter<T: IntoIterator<Item = i64>>(iter: T) -> Self {
|
||||
Self { inner: iter.into_iter().collect() }
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Shape {
|
||||
type Target = [i64];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Shape {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SymbolicDimensions(SmallVec<String, 4>);
|
||||
|
||||
impl SymbolicDimensions {
|
||||
pub fn new(dims: impl IntoIterator<Item = String>) -> Self {
|
||||
Self(dims.into_iter().collect())
|
||||
}
|
||||
|
||||
pub fn empty(rank: usize) -> Self {
|
||||
Self(smallvec![String::default(); rank])
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<String> for SymbolicDimensions {
|
||||
fn from_iter<T: IntoIterator<Item = String>>(iter: T) -> Self {
|
||||
Self(iter.into_iter().collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SymbolicDimensions {
|
||||
type Target = [String];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
11
src/util.rs
11
src/util.rs
@@ -303,17 +303,6 @@ impl<T> Drop for OnceLock<T> {
|
||||
#[doc(hidden)]
|
||||
pub fn cold() {}
|
||||
|
||||
pub fn element_count(shape: &[i64]) -> usize {
|
||||
let mut size = 1usize;
|
||||
for dim in shape {
|
||||
if *dim < 0 {
|
||||
return 0;
|
||||
}
|
||||
size *= *dim as usize;
|
||||
}
|
||||
size
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn with_cstr<T>(bytes: &[u8], f: &dyn Fn(&CStr) -> Result<T>) -> Result<T> {
|
||||
fn run_with_heap_cstr<T>(bytes: &[u8], f: &dyn Fn(&CStr) -> Result<T>) -> Result<T> {
|
||||
|
||||
@@ -20,8 +20,7 @@ use crate::{
|
||||
error::{Error, Result},
|
||||
memory::Allocator,
|
||||
ortsys,
|
||||
tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType},
|
||||
util::element_count
|
||||
tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType}
|
||||
};
|
||||
|
||||
pub trait MapValueTypeMarker: ValueTypeMarker {
|
||||
@@ -109,7 +108,7 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
|
||||
if K::into_tensor_element_type() != TensorElementType::String {
|
||||
let dtype = key_value.dtype();
|
||||
let (key_tensor_shape, key_tensor) = match dtype {
|
||||
ValueType::Tensor { ty, dimensions, .. } => {
|
||||
ValueType::Tensor { ty, shape, .. } => {
|
||||
let mem = key_value.memory_info();
|
||||
if !mem.is_cpu_accessible() {
|
||||
return Err(Error::new(format!(
|
||||
@@ -124,8 +123,7 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
|
||||
let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
|
||||
|
||||
let len = element_count(dimensions);
|
||||
(dimensions, unsafe { slice::from_raw_parts(output_array_ptr, len) })
|
||||
(shape, unsafe { slice::from_raw_parts(output_array_ptr, shape.num_elements()) })
|
||||
} else {
|
||||
return Err(Error::new_with_code(
|
||||
ErrorCode::InvalidArgument,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use alloc::{boxed::Box, ffi::CString, format, string::String, sync::Arc, vec, vec::Vec};
|
||||
use alloc::{boxed::Box, ffi::CString, format, string::String, sync::Arc, vec::Vec};
|
||||
use core::{
|
||||
any::Any,
|
||||
ffi::c_void,
|
||||
@@ -17,8 +17,7 @@ use crate::{
|
||||
error::{Error, ErrorCode, Result},
|
||||
memory::{Allocator, MemoryInfo},
|
||||
ortsys,
|
||||
tensor::{PrimitiveTensorElementType, TensorElementType, Utf8Data},
|
||||
util::element_count,
|
||||
tensor::{PrimitiveTensorElementType, Shape, SymbolicDimensions, TensorElementType, Utf8Data},
|
||||
value::{Value, ValueInner, ValueType}
|
||||
};
|
||||
|
||||
@@ -78,8 +77,8 @@ impl Tensor<String> {
|
||||
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
|
||||
dtype: ValueType::Tensor {
|
||||
ty: TensorElementType::String,
|
||||
dimensions: shape,
|
||||
dimension_symbols: vec![String::default(); shape_len]
|
||||
shape,
|
||||
dimension_symbols: SymbolicDimensions::empty(shape_len)
|
||||
},
|
||||
memory_info: MemoryInfo::from_value(value_ptr),
|
||||
drop: true,
|
||||
@@ -109,7 +108,7 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(allocator: &Allocator, shape: impl ToDimensions) -> Result<Tensor<T>> {
|
||||
pub fn new(allocator: &Allocator, shape: impl Into<Shape>) -> Result<Tensor<T>> {
|
||||
let tensor = DynTensor::new(allocator, T::into_tensor_element_type(), shape)?;
|
||||
Ok(unsafe { tensor.transmute_type() })
|
||||
}
|
||||
@@ -141,17 +140,16 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
|
||||
///
|
||||
/// Creating string tensors requires a separate method; see [`Tensor::from_string_array`].
|
||||
pub fn from_array(input: impl OwnedTensorArrayData<T>) -> Result<Tensor<T>> {
|
||||
let TensorArrayDataParts { shape, ptr, num_elements, guard } = input.into_parts()?;
|
||||
tensor_from_array(MemoryInfo::default(), shape, ptr.as_ptr().cast(), num_elements, size_of::<T>(), T::into_tensor_element_type(), guard)
|
||||
let TensorArrayDataParts { shape, ptr, guard } = input.into_parts()?;
|
||||
tensor_from_array(MemoryInfo::default(), shape, ptr.as_ptr().cast(), size_of::<T>(), T::into_tensor_element_type(), guard)
|
||||
.map(|tensor| unsafe { tensor.transmute_type() })
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_from_array(
|
||||
memory_info: MemoryInfo,
|
||||
shape: Vec<i64>,
|
||||
shape: Shape,
|
||||
data: *mut c_void,
|
||||
num_elements: usize,
|
||||
element_size: usize,
|
||||
element_type: TensorElementType,
|
||||
guard: Option<Box<dyn Any>>
|
||||
@@ -162,7 +160,7 @@ fn tensor_from_array(
|
||||
unsafe CreateTensorWithDataAsOrtValue(
|
||||
memory_info.ptr(),
|
||||
data,
|
||||
num_elements * element_size,
|
||||
shape.num_elements() * element_size,
|
||||
shape.as_ptr(),
|
||||
shape.len(),
|
||||
element_type.into(),
|
||||
@@ -176,8 +174,8 @@ fn tensor_from_array(
|
||||
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
|
||||
dtype: ValueType::Tensor {
|
||||
ty: element_type,
|
||||
dimension_symbols: vec![String::default(); shape.len()],
|
||||
dimensions: shape
|
||||
dimension_symbols: SymbolicDimensions::empty(shape.len()),
|
||||
shape
|
||||
},
|
||||
drop: true,
|
||||
memory_info: Some(memory_info),
|
||||
@@ -219,15 +217,11 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> {
|
||||
/// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout.
|
||||
pub fn from_array_view(input: impl TensorArrayData<T> + 'a) -> Result<TensorRef<'a, T>> {
|
||||
let (shape, data, guard) = input.ref_parts()?;
|
||||
let num_elements = element_count(&shape);
|
||||
|
||||
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, num_elements, size_of::<T>(), T::into_tensor_element_type(), guard).map(
|
||||
|tensor| {
|
||||
let mut tensor: TensorRef<'_, T> = TensorRef::new(unsafe { tensor.transmute_type() });
|
||||
tensor.upgradable = false;
|
||||
tensor
|
||||
}
|
||||
)
|
||||
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::<T>(), T::into_tensor_element_type(), guard).map(|tensor| {
|
||||
let mut tensor: TensorRef<'_, T> = TensorRef::new(unsafe { tensor.transmute_type() });
|
||||
tensor.upgradable = false;
|
||||
tensor
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,21 +256,17 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
|
||||
/// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout.
|
||||
pub fn from_array_view_mut(mut input: impl TensorArrayDataMut<T>) -> Result<TensorRefMut<'a, T>> {
|
||||
let (shape, data, guard) = input.ref_parts_mut()?;
|
||||
let num_elements = element_count(&shape);
|
||||
|
||||
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, num_elements, size_of::<T>(), T::into_tensor_element_type(), guard).map(
|
||||
|tensor| {
|
||||
let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() });
|
||||
tensor.upgradable = false;
|
||||
tensor
|
||||
}
|
||||
)
|
||||
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::<T>(), T::into_tensor_element_type(), guard).map(|tensor| {
|
||||
let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() });
|
||||
tensor.upgradable = false;
|
||||
tensor
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a mutable tensor view from a raw pointer and shape.
|
||||
///
|
||||
/// The length of data is determined by `T` and the given shape, so the given buffer must be at least
|
||||
/// `shape.iter().product() * size_of::<T>()` bytes.
|
||||
/// `shape.num_elements() * size_of::<T>()` bytes.
|
||||
///
|
||||
/// This function can be used to create data from raw device memory, e.g. to directly provide data to an execution
|
||||
/// provider. For instance, to create a tensor from a raw CUDA buffer using [`cudarc`](https://docs.rs/cudarc):
|
||||
@@ -288,7 +278,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
|
||||
/// TensorRefMut::from_raw(
|
||||
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?,
|
||||
/// (*device_data.device_ptr() as usize as *mut ()).cast(),
|
||||
/// vec![1, 3, 512, 512]
|
||||
/// Shape::new([1, 3, 512, 512])
|
||||
/// )?
|
||||
/// };
|
||||
/// ```
|
||||
@@ -296,9 +286,8 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
|
||||
/// # Safety
|
||||
/// - The pointer must be valid for the device description provided by `MemoryInfo`.
|
||||
/// - The returned tensor must outlive the data described by the data pointer.
|
||||
pub unsafe fn from_raw(info: MemoryInfo, data: *mut ort_sys::c_void, shape: Vec<i64>) -> Result<TensorRefMut<'a, T>> {
|
||||
let num_elements = element_count(&shape);
|
||||
tensor_from_array(info, shape, data, num_elements, size_of::<T>(), T::into_tensor_element_type(), None).map(|tensor| {
|
||||
pub unsafe fn from_raw(info: MemoryInfo, data: *mut ort_sys::c_void, shape: Shape) -> Result<TensorRefMut<'a, T>> {
|
||||
tensor_from_array(info, shape, data, size_of::<T>(), T::into_tensor_element_type(), None).map(|tensor| {
|
||||
let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() });
|
||||
tensor.upgradable = false;
|
||||
tensor
|
||||
@@ -308,14 +297,14 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
|
||||
|
||||
pub trait TensorArrayData<I> {
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[I], Option<Box<dyn Any>>)>;
|
||||
fn ref_parts(&self) -> Result<(Shape, &[I], Option<Box<dyn Any>>)>;
|
||||
|
||||
private_trait!();
|
||||
}
|
||||
|
||||
pub trait TensorArrayDataMut<I>: TensorArrayData<I> {
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [I], Option<Box<dyn Any>>)>;
|
||||
fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [I], Option<Box<dyn Any>>)>;
|
||||
|
||||
private_trait!();
|
||||
}
|
||||
@@ -327,20 +316,19 @@ pub trait OwnedTensorArrayData<I> {
|
||||
}
|
||||
|
||||
pub struct TensorArrayDataParts<I> {
|
||||
pub shape: Vec<i64>,
|
||||
pub shape: Shape,
|
||||
pub ptr: NonNull<I>,
|
||||
pub num_elements: usize,
|
||||
pub guard: Option<Box<dyn Any>>
|
||||
}
|
||||
|
||||
pub trait ToDimensions {
|
||||
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Vec<i64>>;
|
||||
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Shape>;
|
||||
}
|
||||
|
||||
macro_rules! impl_to_dimensions {
|
||||
(@inner) => {
|
||||
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Vec<i64>> {
|
||||
let v: Vec<i64> = self
|
||||
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Shape> {
|
||||
let v = self
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, c)| {
|
||||
@@ -353,13 +341,17 @@ macro_rules! impl_to_dimensions {
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
let sum = element_count(&v);
|
||||
.collect::<Result<Shape>>()?;
|
||||
if let Some(expected_size) = expected_size {
|
||||
if sum != expected_size {
|
||||
if v.num_elements() != expected_size {
|
||||
Err(Error::new_with_code(
|
||||
ErrorCode::InvalidArgument,
|
||||
format!("Cannot create a tensor from raw data; shape {:?} ({} elements) is larger than the length of the data provided ({} elements)", v, sum, expected_size)
|
||||
format!(
|
||||
"Cannot create a tensor from raw data; shape {:?} ({} elements) is larger than the length of the data provided ({} elements)",
|
||||
v,
|
||||
v.num_elements(),
|
||||
expected_size
|
||||
)
|
||||
))
|
||||
} else {
|
||||
Ok(v)
|
||||
@@ -382,21 +374,21 @@ macro_rules! impl_to_dimensions {
|
||||
}
|
||||
|
||||
impl ToDimensions for () {
|
||||
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Vec<i64>> {
|
||||
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Shape> {
|
||||
match expected_size {
|
||||
Some(1) | None => Ok(vec![]),
|
||||
Some(1) | None => Ok(Shape::default()),
|
||||
Some(_) => Err(Error::new_with_code(ErrorCode::InvalidArgument, "Expected data to have a length of exactly 1 for scalar shape"))
|
||||
}
|
||||
}
|
||||
}
|
||||
impl_to_dimensions!(for &[usize], for &[i32], for &[i64], for Vec<usize>, for Vec<i32>, for Vec<i64>);
|
||||
impl_to_dimensions!(for Shape, for &[usize], for &[i32], for &[i64], for Vec<usize>, for Vec<i32>, for Vec<i64>);
|
||||
impl_to_dimensions!(<N> for [usize; N], for [i32; N], for [i64; N]);
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &CowArray<'_, T, D> {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.shape().iter().map(|d| *d as i64).collect();
|
||||
let data = self
|
||||
.as_slice()
|
||||
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
|
||||
@@ -409,8 +401,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &CowArra
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArcArray<T, D> {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.shape().iter().map(|d| *d as i64).collect();
|
||||
let data = self
|
||||
.as_slice()
|
||||
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
|
||||
@@ -423,8 +415,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArcArray
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &Array<T, D> {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.shape().iter().map(|d| *d as i64).collect();
|
||||
let data = self
|
||||
.as_slice()
|
||||
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
|
||||
@@ -437,8 +429,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &Array<T
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &mut Array<T, D> {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.shape().iter().map(|d| *d as i64).collect();
|
||||
let data = self
|
||||
.as_slice()
|
||||
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
|
||||
@@ -455,27 +447,21 @@ impl<T: Clone + 'static, D: Dimension + 'static> OwnedTensorArrayData<T> for Arr
|
||||
if self.is_standard_layout() {
|
||||
// We can avoid the copy here and use the data as is
|
||||
let mut this = Box::new(self);
|
||||
let shape: Vec<i64> = this.shape().iter().map(|d| *d as i64).collect();
|
||||
let shape: Shape = this.shape().iter().map(|d| *d as i64).collect();
|
||||
// SAFETY: ndarrays internally store their pointer as NonNull
|
||||
let ptr = unsafe { NonNull::new_unchecked(this.as_mut_ptr()) };
|
||||
let num_elements = this.len();
|
||||
Ok(TensorArrayDataParts {
|
||||
shape,
|
||||
ptr,
|
||||
num_elements,
|
||||
guard: Some(this)
|
||||
})
|
||||
assert_eq!(this.len(), shape.num_elements());
|
||||
Ok(TensorArrayDataParts { shape, ptr, guard: Some(this) })
|
||||
} else {
|
||||
// Need to do a copy here to get data in to standard layout
|
||||
let mut contiguous_array = self.as_standard_layout().into_owned();
|
||||
let shape: Vec<i64> = contiguous_array.shape().iter().map(|d| *d as i64).collect();
|
||||
let shape: Shape = contiguous_array.shape().iter().map(|d| *d as i64).collect();
|
||||
// SAFETY: ndarrays internally store their pointer as NonNull
|
||||
let ptr = unsafe { NonNull::new_unchecked(contiguous_array.as_mut_ptr()) };
|
||||
let num_elements: usize = contiguous_array.len();
|
||||
assert_eq!(contiguous_array.len(), shape.num_elements());
|
||||
Ok(TensorArrayDataParts {
|
||||
shape,
|
||||
ptr,
|
||||
num_elements,
|
||||
guard: Some(Box::new(contiguous_array))
|
||||
})
|
||||
}
|
||||
@@ -487,8 +473,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> OwnedTensorArrayData<T> for Arr
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayView<'_, T, D> {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.shape().iter().map(|d| *d as i64).collect();
|
||||
let data = self
|
||||
.as_slice()
|
||||
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
|
||||
@@ -501,8 +487,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayVie
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayViewMut<'_, T, D> {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.shape().iter().map(|d| *d as i64).collect();
|
||||
let data = self
|
||||
.as_slice()
|
||||
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
|
||||
@@ -515,8 +501,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayVie
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for ArrayViewMut<'_, T, D> {
|
||||
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T], Option<Box<dyn Any>>)> {
|
||||
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
|
||||
fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.shape().iter().map(|d| *d as i64).collect();
|
||||
let data = self
|
||||
.as_slice_mut()
|
||||
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
|
||||
@@ -529,8 +515,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for Array
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for &mut Array<T, D> {
|
||||
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T], Option<Box<dyn Any>>)> {
|
||||
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
|
||||
fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.shape().iter().map(|d| *d as i64).collect();
|
||||
let data = self
|
||||
.as_slice_mut()
|
||||
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
|
||||
@@ -541,7 +527,7 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for &mut
|
||||
}
|
||||
|
||||
impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &[T]) {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.0.to_dimensions(Some(self.1.len()))?;
|
||||
Ok((shape, self.1, None))
|
||||
}
|
||||
@@ -550,7 +536,7 @@ impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &[T]) {
|
||||
}
|
||||
|
||||
impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &mut [T]) {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.0.to_dimensions(Some(self.1.len()))?;
|
||||
Ok((shape, self.1, None))
|
||||
}
|
||||
@@ -559,7 +545,7 @@ impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &mut [T]) {
|
||||
}
|
||||
|
||||
impl<T: Clone + 'static, D: ToDimensions> TensorArrayDataMut<T> for (D, &mut [T]) {
|
||||
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T], Option<Box<dyn Any>>)> {
|
||||
fn ref_parts_mut(&mut self) -> Result<(Shape, &mut [T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.0.to_dimensions(Some(self.1.len()))?;
|
||||
Ok((shape, self.1, None))
|
||||
}
|
||||
@@ -572,11 +558,10 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Vec<T>
|
||||
let shape = self.0.to_dimensions(Some(self.1.len()))?;
|
||||
// SAFETY: A `Vec` always has a non-null pointer.
|
||||
let ptr = unsafe { NonNull::new_unchecked(self.1.as_mut_ptr()) };
|
||||
let num_elements: usize = self.1.len();
|
||||
assert_eq!(shape.num_elements(), self.1.len());
|
||||
Ok(TensorArrayDataParts {
|
||||
shape,
|
||||
ptr,
|
||||
num_elements,
|
||||
guard: Some(Box::new(self.1))
|
||||
})
|
||||
}
|
||||
@@ -589,11 +574,10 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Box<[T
|
||||
let shape = self.0.to_dimensions(Some(self.1.len()))?;
|
||||
// SAFETY: A `Box` always has a non-null pointer.
|
||||
let ptr = unsafe { NonNull::new_unchecked(self.1.as_mut_ptr()) };
|
||||
let num_elements: usize = self.1.len();
|
||||
assert_eq!(shape.num_elements(), self.1.len());
|
||||
Ok(TensorArrayDataParts {
|
||||
shape,
|
||||
ptr,
|
||||
num_elements,
|
||||
guard: Some(Box::new(self.1))
|
||||
})
|
||||
}
|
||||
@@ -602,7 +586,7 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Box<[T
|
||||
}
|
||||
|
||||
impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<[T]>) {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.0.to_dimensions(Some(self.1.len()))?;
|
||||
let data = &*self.1;
|
||||
Ok((shape, data, Some(Box::new(self.1.clone()))))
|
||||
@@ -612,7 +596,7 @@ impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<[T]>) {
|
||||
}
|
||||
|
||||
impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<Box<[T]>>) {
|
||||
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
|
||||
fn ref_parts(&self) -> Result<(Shape, &[T], Option<Box<dyn Any>>)> {
|
||||
let shape = self.0.to_dimensions(Some(self.1.len()))?;
|
||||
let data = &*self.1;
|
||||
Ok((shape, data, Some(Box::new(self.1.clone()))))
|
||||
|
||||
@@ -12,8 +12,7 @@ use crate::{
|
||||
error::{Error, ErrorCode, Result},
|
||||
memory::MemoryInfo,
|
||||
ortsys,
|
||||
tensor::{PrimitiveTensorElementType, TensorElementType},
|
||||
util::element_count,
|
||||
tensor::{PrimitiveTensorElementType, Shape, TensorElementType},
|
||||
value::{Value, ValueType}
|
||||
};
|
||||
|
||||
@@ -49,11 +48,8 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
pub fn try_extract_tensor<T: PrimitiveTensorElementType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
|
||||
use ndarray::IntoDimension;
|
||||
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| {
|
||||
let shape = dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>().into_dimension();
|
||||
Ok(unsafe { ndarray::ArrayView::from_shape_ptr(shape, data_ptr(ptr)?.cast::<T>()) })
|
||||
})
|
||||
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type())
|
||||
.and_then(|(ptr, shape)| Ok(unsafe { ndarray::ArrayView::from_shape_ptr(shape.to_ixdyn(), data_ptr(ptr)?.cast::<T>()) }))
|
||||
}
|
||||
|
||||
/// Attempt to extract the scalar from a tensor of type `T`.
|
||||
@@ -80,11 +76,11 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
///
|
||||
/// [`DynValue`]: crate::value::DynValue
|
||||
pub fn try_extract_scalar<T: PrimitiveTensorElementType + Copy>(&self) -> Result<T> {
|
||||
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| {
|
||||
if !dimensions.is_empty() {
|
||||
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, shape)| {
|
||||
if !shape.is_empty() {
|
||||
return Err(Error::new_with_code(
|
||||
ErrorCode::InvalidArgument,
|
||||
format!("Cannot extract scalar {} from a tensor of dimensionality {}", T::into_tensor_element_type(), dimensions.len())
|
||||
format!("Cannot extract scalar {} from a tensor of dimensionality {}", T::into_tensor_element_type(), shape.len())
|
||||
));
|
||||
}
|
||||
|
||||
@@ -122,14 +118,11 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
pub fn try_extract_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
|
||||
use ndarray::IntoDimension;
|
||||
extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| {
|
||||
let shape = dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>().into_dimension();
|
||||
Ok(unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape, data_ptr(ptr)?.cast::<T>()) })
|
||||
})
|
||||
extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type())
|
||||
.and_then(|(ptr, shape)| Ok(unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape.to_ixdyn(), data_ptr(ptr)?.cast::<T>()) }))
|
||||
}
|
||||
|
||||
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an
|
||||
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's shape and an
|
||||
/// immutable view into its data.
|
||||
///
|
||||
/// See also:
|
||||
@@ -145,7 +138,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
///
|
||||
/// let (extracted_shape, extracted_data) = value.try_extract_raw_tensor::<i64>()?;
|
||||
/// assert_eq!(extracted_data, &array);
|
||||
/// assert_eq!(extracted_shape, [5]);
|
||||
/// assert_eq!(**extracted_shape, [5]);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -157,12 +150,12 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// - The provided type `T` does not match the tensor's element type.
|
||||
///
|
||||
/// [`DynValue`]: crate::value::DynValue
|
||||
pub fn try_extract_raw_tensor<T: PrimitiveTensorElementType>(&self) -> Result<(&[i64], &[T])> {
|
||||
pub fn try_extract_raw_tensor<T: PrimitiveTensorElementType>(&self) -> Result<(&Shape, &[T])> {
|
||||
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type())
|
||||
.and_then(|(ptr, dimensions)| Ok((dimensions, unsafe { slice::from_raw_parts(data_ptr(ptr)?.cast::<T>(), element_count(dimensions)) })))
|
||||
.and_then(|(ptr, shape)| Ok((shape, unsafe { slice::from_raw_parts(data_ptr(ptr)?.cast::<T>(), shape.num_elements()) })))
|
||||
}
|
||||
|
||||
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a
|
||||
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's shape and a
|
||||
/// mutable view into its data.
|
||||
///
|
||||
/// See also the infallible counterpart, [`Tensor::extract_raw_tensor_mut`], for typed [`Tensor<T>`]s.
|
||||
@@ -175,7 +168,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
///
|
||||
/// let (extracted_shape, extracted_data) = value.try_extract_raw_tensor_mut::<i64>()?;
|
||||
/// assert_eq!(extracted_data, &array);
|
||||
/// assert_eq!(extracted_shape, [5]);
|
||||
/// assert_eq!(**extracted_shape, [5]);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -187,9 +180,9 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// - The provided type `T` does not match the tensor's element type.
|
||||
///
|
||||
/// [`DynValue`]: crate::value::DynValue
|
||||
pub fn try_extract_raw_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<(&[i64], &mut [T])> {
|
||||
pub fn try_extract_raw_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<(&Shape, &mut [T])> {
|
||||
extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type())
|
||||
.and_then(|(ptr, dimensions)| Ok((dimensions, unsafe { slice::from_raw_parts_mut(data_ptr(ptr)?.cast::<T>(), element_count(dimensions)) })))
|
||||
.and_then(|(ptr, shape)| Ok((shape, unsafe { slice::from_raw_parts_mut(data_ptr(ptr)?.cast::<T>(), shape.num_elements()) })))
|
||||
}
|
||||
|
||||
/// Attempt to extract the underlying data into a Rust `ndarray`.
|
||||
@@ -208,15 +201,13 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
pub fn try_extract_string_tensor(&self) -> Result<ndarray::ArrayD<String>> {
|
||||
use ndarray::IntoDimension;
|
||||
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, dimensions)| {
|
||||
let strings = extract_strings(ptr, dimensions)?;
|
||||
Ok(ndarray::Array::from_shape_vec(dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>().into_dimension(), strings)
|
||||
.expect("Shape extracted from tensor didn't match tensor contents"))
|
||||
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, shape)| {
|
||||
let strings = extract_strings(ptr, shape)?;
|
||||
Ok(ndarray::Array::from_shape_vec(shape.to_ixdyn(), strings).expect("Shape extracted from tensor didn't match tensor contents"))
|
||||
})
|
||||
}
|
||||
|
||||
/// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's dimensions and
|
||||
/// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's shape and
|
||||
/// an owned `Vec` of its data.
|
||||
///
|
||||
/// ```
|
||||
@@ -227,14 +218,14 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
///
|
||||
/// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?;
|
||||
/// assert_eq!(extracted_data, array);
|
||||
/// assert_eq!(extracted_shape, [2]);
|
||||
/// assert_eq!(**extracted_shape, [2]);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn try_extract_raw_string_tensor(&self) -> Result<(&[i64], Vec<String>)> {
|
||||
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, dimensions)| {
|
||||
let strings = extract_strings(ptr, dimensions)?;
|
||||
Ok((dimensions, strings))
|
||||
pub fn try_extract_raw_string_tensor(&self) -> Result<(&Shape, Vec<String>)> {
|
||||
extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, shape)| {
|
||||
let strings = extract_strings(ptr, shape)?;
|
||||
Ok((shape, strings))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -246,13 +237,13 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// # let allocator = Allocator::default();
|
||||
/// let tensor = Tensor::<f32>::new(&allocator, [1, 128, 128, 3])?;
|
||||
///
|
||||
/// assert_eq!(tensor.shape(), [1, 128, 128, 3]);
|
||||
/// assert_eq!(**tensor.shape(), [1, 128, 128, 3]);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn shape(&self) -> &[i64] {
|
||||
pub fn shape(&self) -> &Shape {
|
||||
match self.dtype() {
|
||||
ValueType::Tensor { dimensions, .. } => dimensions,
|
||||
ValueType::Tensor { shape, .. } => shape,
|
||||
_ => unreachable!()
|
||||
}
|
||||
}
|
||||
@@ -263,9 +254,9 @@ fn extract_tensor<'t>(
|
||||
dtype: &'t ValueType,
|
||||
memory_info: &MemoryInfo,
|
||||
expected_ty: TensorElementType
|
||||
) -> Result<(*mut ort_sys::OrtValue, &'t [i64])> {
|
||||
) -> Result<(*mut ort_sys::OrtValue, &'t Shape)> {
|
||||
match dtype {
|
||||
ValueType::Tensor { ty, dimensions, .. } => {
|
||||
ValueType::Tensor { ty, shape, .. } => {
|
||||
if !memory_info.is_cpu_accessible() {
|
||||
return Err(Error::new(format!(
|
||||
"Cannot extract from value on device `{}`, which is not CPU accessible",
|
||||
@@ -274,7 +265,7 @@ fn extract_tensor<'t>(
|
||||
}
|
||||
|
||||
if *ty == expected_ty {
|
||||
Ok((ptr, dimensions))
|
||||
Ok((ptr, shape))
|
||||
} else {
|
||||
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<{}> from Tensor<{}>", expected_ty, ty)))
|
||||
}
|
||||
@@ -289,9 +280,8 @@ unsafe fn data_ptr(ptr: *mut ort_sys::OrtValue) -> Result<*mut c_void> {
|
||||
Ok(output_array_ptr)
|
||||
}
|
||||
|
||||
fn extract_strings(ptr: *mut ort_sys::OrtValue, dimensions: &[i64]) -> Result<Vec<String>> {
|
||||
let len = element_count(dimensions);
|
||||
|
||||
fn extract_strings(ptr: *mut ort_sys::OrtValue, shape: &Shape) -> Result<Vec<String>> {
|
||||
let len = shape.num_elements();
|
||||
// Total length of string data, not including \0 suffix
|
||||
let mut total_length = 0;
|
||||
ortsys![unsafe GetStringTensorDataLength(ptr, &mut total_length)?];
|
||||
@@ -368,7 +358,7 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
|
||||
self.try_extract_tensor_mut().expect("Failed to extract tensor")
|
||||
}
|
||||
|
||||
/// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an immutable
|
||||
/// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's shapes and an immutable
|
||||
/// view into its data.
|
||||
///
|
||||
/// ```
|
||||
@@ -379,15 +369,15 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
|
||||
///
|
||||
/// let (extracted_shape, extracted_data) = tensor.extract_raw_tensor();
|
||||
/// assert_eq!(extracted_data, &array);
|
||||
/// assert_eq!(extracted_shape, [5]);
|
||||
/// assert_eq!(**extracted_shape, [5]);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn extract_raw_tensor(&self) -> (&[i64], &[T]) {
|
||||
pub fn extract_raw_tensor(&self) -> (&Shape, &[T]) {
|
||||
self.try_extract_raw_tensor().expect("Failed to extract tensor")
|
||||
}
|
||||
|
||||
/// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a mutable view
|
||||
/// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's shapes and a mutable view
|
||||
/// into its data.
|
||||
///
|
||||
/// ```
|
||||
@@ -403,7 +393,7 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn extract_raw_tensor_mut(&mut self) -> (&[i64], &mut [T]) {
|
||||
pub fn extract_raw_tensor_mut(&mut self) -> (&Shape, &mut [T]) {
|
||||
self.try_extract_raw_tensor_mut().expect("Failed to extract tensor")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
mod create;
|
||||
mod extract;
|
||||
|
||||
use alloc::{string::String, sync::Arc, vec};
|
||||
use alloc::sync::Arc;
|
||||
use core::{
|
||||
fmt::{self, Debug},
|
||||
marker::PhantomData,
|
||||
@@ -16,8 +16,7 @@ use crate::{
|
||||
error::Result,
|
||||
memory::{Allocator, MemoryInfo},
|
||||
ortsys,
|
||||
tensor::{IntoTensorElementType, TensorElementType},
|
||||
util::element_count
|
||||
tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType}
|
||||
};
|
||||
|
||||
pub trait TensorValueTypeMarker: ValueTypeMarker {
|
||||
@@ -93,10 +92,13 @@ impl DynTensor {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(allocator: &Allocator, data_type: TensorElementType, shape: impl ToDimensions) -> Result<DynTensor> {
|
||||
pub fn new(allocator: &Allocator, data_type: TensorElementType, shape: impl Into<Shape>) -> Result<DynTensor> {
|
||||
Self::new_inner(allocator, data_type, shape.into())
|
||||
}
|
||||
|
||||
fn new_inner(allocator: &Allocator, data_type: TensorElementType, shape: Shape) -> Result<DynTensor> {
|
||||
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
|
||||
|
||||
let shape = shape.to_dimensions(None)?;
|
||||
let shape_ptr: *const i64 = shape.as_ptr();
|
||||
let shape_len = shape.len();
|
||||
|
||||
@@ -118,7 +120,7 @@ impl DynTensor {
|
||||
let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
|
||||
ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];
|
||||
|
||||
unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(element_count(&shape))) };
|
||||
unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(shape.num_elements())) };
|
||||
}
|
||||
|
||||
Ok(Value {
|
||||
@@ -126,8 +128,8 @@ impl DynTensor {
|
||||
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
|
||||
dtype: ValueType::Tensor {
|
||||
ty: data_type,
|
||||
dimensions: shape,
|
||||
dimension_symbols: vec![String::default(); shape_len]
|
||||
shape,
|
||||
dimension_symbols: SymbolicDimensions::empty(shape_len)
|
||||
},
|
||||
drop: true,
|
||||
memory_info: MemoryInfo::from_value(value_ptr),
|
||||
@@ -338,7 +340,7 @@ mod tests {
|
||||
use super::Tensor;
|
||||
use crate::{
|
||||
memory::Allocator,
|
||||
tensor::TensorElementType,
|
||||
tensor::{Shape, SymbolicDimensions, TensorElementType},
|
||||
value::{TensorRef, ValueType}
|
||||
};
|
||||
|
||||
@@ -350,12 +352,12 @@ mod tests {
|
||||
assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Float32));
|
||||
assert_eq!(value.dtype(), &ValueType::Tensor {
|
||||
ty: TensorElementType::Float32,
|
||||
dimensions: vec![v.len() as i64],
|
||||
dimension_symbols: vec![String::default()]
|
||||
shape: Shape::new([v.len() as i64]),
|
||||
dimension_symbols: SymbolicDimensions::empty(1)
|
||||
});
|
||||
|
||||
let (shape, data) = value.extract_raw_tensor();
|
||||
assert_eq!(shape, vec![v.len() as i64]);
|
||||
assert_eq!(&**shape, [v.len() as i64]);
|
||||
assert_eq!(data, &v);
|
||||
|
||||
Ok(())
|
||||
@@ -411,7 +413,7 @@ mod tests {
|
||||
|
||||
let value = Tensor::from_string_array((vec![v.len() as i64], &*v))?;
|
||||
let (extracted_shape, extracted_view) = value.try_extract_raw_string_tensor()?;
|
||||
assert_eq!(extracted_shape, [v.len() as i64]);
|
||||
assert_eq!(&**extracted_shape, [v.len() as i64]);
|
||||
assert_eq!(extracted_view, v);
|
||||
|
||||
Ok(())
|
||||
@@ -437,7 +439,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_tensor_index() -> crate::Result<()> {
|
||||
let mut tensor = Tensor::new(&Allocator::default(), [1, 3, 224, 224])?;
|
||||
let mut tensor = Tensor::new(&Allocator::default(), Shape::new([1, 3, 224, 224]))?;
|
||||
|
||||
tensor[[0, 2, 42, 42]] = 1.0;
|
||||
assert_eq!(tensor[[0, 2, 42, 42]], 1.0);
|
||||
|
||||
@@ -1,39 +1,45 @@
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
string::{String, ToString},
|
||||
vec,
|
||||
vec::Vec
|
||||
};
|
||||
use core::{
|
||||
ffi::{CStr, c_char},
|
||||
fmt, ptr
|
||||
string::{String, ToString}
|
||||
};
|
||||
use core::{ffi::CStr, fmt, ptr};
|
||||
|
||||
use crate::{ortsys, tensor::TensorElementType, util::with_cstr_ptr_array};
|
||||
use smallvec::{SmallVec, smallvec};
|
||||
|
||||
use crate::{
|
||||
ortsys,
|
||||
tensor::{Shape, SymbolicDimensions, TensorElementType},
|
||||
util::with_cstr_ptr_array
|
||||
};
|
||||
|
||||
/// The type of a [`Value`][super::Value], or a session input/output.
|
||||
///
|
||||
/// ```
|
||||
/// # use std::sync::Arc;
|
||||
/// # use ort::{session::Session, value::{ValueType, Tensor}, tensor::TensorElementType};
|
||||
/// # use ort::{session::Session, tensor::{Shape, SymbolicDimensions}, value::{ValueType, Tensor}, tensor::TensorElementType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// // `ValueType`s can be obtained from session inputs/outputs:
|
||||
/// let input = &session.inputs[0];
|
||||
/// assert_eq!(input.input_type, ValueType::Tensor {
|
||||
/// ty: TensorElementType::Float32,
|
||||
/// // Our model has 3 dynamic dimensions, represented by -1
|
||||
/// dimensions: vec![-1, -1, -1, 3],
|
||||
/// // Our model's input has 3 dynamic dimensions, represented by -1
|
||||
/// shape: Shape::new([-1, -1, -1, 3]),
|
||||
/// // Dynamic dimensions may also have names.
|
||||
/// dimension_symbols: vec!["unk__31".to_string(), "unk__32".to_string(), "unk__33".to_string(), String::default()]
|
||||
/// dimension_symbols: SymbolicDimensions::new([
|
||||
/// "unk__31".to_string(),
|
||||
/// "unk__32".to_string(),
|
||||
/// "unk__33".to_string(),
|
||||
/// String::default()
|
||||
/// ])
|
||||
/// });
|
||||
///
|
||||
/// // ...or by `Value`s created in Rust or output by a session.
|
||||
/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?;
|
||||
/// assert_eq!(value.dtype(), &ValueType::Tensor {
|
||||
/// ty: TensorElementType::Int64,
|
||||
/// dimensions: vec![5],
|
||||
/// dimension_symbols: vec![String::default()]
|
||||
/// shape: Shape::new([5]),
|
||||
/// dimension_symbols: SymbolicDimensions::new([String::default()])
|
||||
/// });
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
@@ -44,15 +50,16 @@ pub enum ValueType {
|
||||
Tensor {
|
||||
/// Element type of the tensor.
|
||||
ty: TensorElementType,
|
||||
/// Dimensions of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an
|
||||
/// Shape of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an
|
||||
/// [`Input`]/[`Output`]), the dimension will be `-1`.
|
||||
///
|
||||
/// Actual tensor values, which have a known dimension, will always have positive (>1) dimensions.
|
||||
/// Actual tensor values (i.e. not [`Input`] or [`Output`] definitions), which have a known dimension, will
|
||||
/// always have non-negative dimensions.
|
||||
///
|
||||
/// [`Input`]: crate::session::Input
|
||||
/// [`Output`]: crate::session::Output
|
||||
dimensions: Vec<i64>,
|
||||
dimension_symbols: Vec<String>
|
||||
shape: Shape,
|
||||
dimension_symbols: SymbolicDimensions
|
||||
},
|
||||
/// A sequence (vector) of other `Value`s.
|
||||
///
|
||||
@@ -136,11 +143,11 @@ impl ValueType {
|
||||
|
||||
pub(crate) fn to_tensor_type_info(&self) -> Option<*mut ort_sys::OrtTensorTypeAndShapeInfo> {
|
||||
match self {
|
||||
Self::Tensor { ty, dimensions, dimension_symbols } => {
|
||||
Self::Tensor { ty, shape, dimension_symbols } => {
|
||||
let mut info_ptr = ptr::null_mut();
|
||||
ortsys![unsafe CreateTensorTypeAndShapeInfo(&mut info_ptr).expect("infallible")];
|
||||
ortsys![unsafe SetTensorElementType(info_ptr, (*ty).into()).expect("infallible")];
|
||||
ortsys![unsafe SetDimensions(info_ptr, dimensions.as_ptr(), dimensions.len()).expect("infallible")];
|
||||
ortsys![unsafe SetDimensions(info_ptr, shape.as_ptr(), shape.len()).expect("infallible")];
|
||||
with_cstr_ptr_array(dimension_symbols, &|ptrs| {
|
||||
ortsys![unsafe SetSymbolicDimensions(info_ptr, ptrs.as_ptr().cast_mut(), dimension_symbols.len()).expect("infallible")];
|
||||
Ok(())
|
||||
@@ -152,20 +159,22 @@ impl ValueType {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map.
|
||||
/// Returns the shape of this value type if it is a tensor, or `None` if it is a sequence or map.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::value::Tensor;
|
||||
/// # use ort::value::{Tensor, DynValue};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?;
|
||||
/// assert_eq!(value.dtype().tensor_dimensions(), Some(&vec![5]));
|
||||
/// let value: DynValue = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?.into_dyn();
|
||||
///
|
||||
/// let shape = value.dtype().tensor_shape().unwrap();
|
||||
/// assert_eq!(**shape, [5]);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn tensor_dimensions(&self) -> Option<&Vec<i64>> {
|
||||
pub fn tensor_shape(&self) -> Option<&Shape> {
|
||||
match self {
|
||||
ValueType::Tensor { dimensions, .. } => Some(dimensions),
|
||||
ValueType::Tensor { shape, .. } => Some(shape),
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
@@ -213,9 +222,9 @@ impl ValueType {
|
||||
impl fmt::Display for ValueType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ValueType::Tensor { ty, dimensions, dimension_symbols } => {
|
||||
ValueType::Tensor { ty, shape, dimension_symbols } => {
|
||||
write!(f, "Tensor<{ty}>(")?;
|
||||
for (i, dimension) in dimensions.iter().copied().enumerate() {
|
||||
for (i, dimension) in shape.iter().copied().enumerate() {
|
||||
if dimension == -1 {
|
||||
let sym = &dimension_symbols[i];
|
||||
if sym.is_empty() {
|
||||
@@ -226,7 +235,7 @@ impl fmt::Display for ValueType {
|
||||
} else {
|
||||
dimension.fmt(f)?;
|
||||
}
|
||||
if i != dimensions.len() - 1 {
|
||||
if i != shape.len() - 1 {
|
||||
f.write_str(", ")?;
|
||||
}
|
||||
}
|
||||
@@ -248,10 +257,10 @@ pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys
|
||||
let mut num_dims = 0;
|
||||
ortsys![unsafe GetDimensionsCount(info_ptr, &mut num_dims).expect("infallible")];
|
||||
|
||||
let mut node_dims: Vec<i64> = vec![0; num_dims];
|
||||
let mut node_dims = Shape::empty(num_dims);
|
||||
ortsys![unsafe GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims).expect("infallible")];
|
||||
|
||||
let mut symbolic_dims: Vec<*const c_char> = vec![ptr::null(); num_dims];
|
||||
let mut symbolic_dims: SmallVec<_, 4> = smallvec![ptr::null(); num_dims];
|
||||
ortsys![unsafe GetSymbolicDimensions(info_ptr, symbolic_dims.as_mut_ptr(), num_dims).expect("infallible")];
|
||||
|
||||
let dimension_symbols = symbolic_dims
|
||||
@@ -261,7 +270,7 @@ pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys
|
||||
|
||||
ValueType::Tensor {
|
||||
ty: type_sys.into(),
|
||||
dimensions: node_dims,
|
||||
shape: node_dims,
|
||||
dimension_symbols
|
||||
}
|
||||
}
|
||||
@@ -288,14 +297,17 @@ unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeIn
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ValueType;
|
||||
use crate::{ortsys, tensor::TensorElementType};
|
||||
use crate::{
|
||||
ortsys,
|
||||
tensor::{Shape, SymbolicDimensions, TensorElementType}
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_to_from_tensor_info() -> crate::Result<()> {
|
||||
let ty = ValueType::Tensor {
|
||||
ty: TensorElementType::Float32,
|
||||
dimensions: vec![-1, 32, 4, 32],
|
||||
dimension_symbols: vec!["d1".to_string(), String::default(), String::default(), String::default()]
|
||||
shape: Shape::new([-1, 32, 4, 32]),
|
||||
dimension_symbols: SymbolicDimensions::new(["d1".to_string(), String::default(), String::default(), String::default()])
|
||||
};
|
||||
let ty_ptr = ty.to_tensor_type_info().expect("");
|
||||
let ty_d = unsafe { super::extract_data_type_from_tensor_info(ty_ptr) };
|
||||
|
||||
@@ -26,11 +26,11 @@ fn mnist_5() -> ort::Result<()> {
|
||||
assert_eq!(metadata.name()?, "CNTKGraph");
|
||||
assert_eq!(metadata.producer()?, "CNTK");
|
||||
|
||||
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
|
||||
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
|
||||
let input0_shape = session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type");
|
||||
let output0_shape = session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type");
|
||||
|
||||
assert_eq!(input0_shape, &[1, 1, 28, 28]);
|
||||
assert_eq!(output0_shape, &[1, 10]);
|
||||
assert_eq!(&**input0_shape, &[1, 1, 28, 28]);
|
||||
assert_eq!(&**output0_shape, &[1, 10]);
|
||||
|
||||
input0_shape
|
||||
};
|
||||
|
||||
@@ -33,11 +33,11 @@ fn squeezenet_mushroom() -> ort::Result<()> {
|
||||
assert_eq!(metadata.name()?, "main_graph");
|
||||
assert_eq!(metadata.producer()?, "pytorch");
|
||||
|
||||
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
|
||||
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
|
||||
let input0_shape = session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type");
|
||||
let output0_shape = session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type");
|
||||
|
||||
assert_eq!(input0_shape, &[1, 3, 224, 224]);
|
||||
assert_eq!(output0_shape, &[1, 1000]);
|
||||
assert_eq!(&**input0_shape, [1, 3, 224, 224]);
|
||||
assert_eq!(&**output0_shape, [1, 1000]);
|
||||
|
||||
input0_shape
|
||||
};
|
||||
|
||||
@@ -63,8 +63,8 @@ fn upsample() -> ort::Result<()> {
|
||||
assert_eq!(metadata.name()?, "tf2onnx");
|
||||
assert_eq!(metadata.producer()?, "tf2onnx");
|
||||
|
||||
assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
|
||||
assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
|
||||
assert_eq!(&**session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type"), [-1, -1, -1, 3]);
|
||||
assert_eq!(&**session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type"), [-1, -1, -1, 3]);
|
||||
}
|
||||
|
||||
// Load image, converting to RGB format
|
||||
@@ -101,8 +101,8 @@ fn upsample_with_ort_model() -> ort::Result<()> {
|
||||
.commit_from_memory_directly(&session_data) // Zero-copy.
|
||||
.expect("Could not read model from memory");
|
||||
|
||||
assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
|
||||
assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
|
||||
assert_eq!(&**session.inputs[0].input_type.tensor_shape().expect("input0 to be a tensor type"), [-1, -1, -1, 3]);
|
||||
assert_eq!(&**session.outputs[0].output_type.tensor_shape().expect("output0 to be a tensor type"), [-1, -1, -1, 3]);
|
||||
|
||||
// Load image, converting to RGB format
|
||||
let image_buffer = load_input_image(IMAGE_TO_LOAD);
|
||||
|
||||
Reference in New Issue
Block a user