refactor: merge tensor module into value

This commit is contained in:
Carson M.
2026-01-15 02:03:23 -06:00
parent c959e1fde5
commit 25a6760783
25 changed files with 75 additions and 80 deletions

View File

@@ -4,7 +4,7 @@ use image::{ImageBuffer, Luma, Pixel, imageops::FilterType};
use ort::{
inputs,
session::{Session, builder::GraphOptimizationLevel},
tensor::ArrayExtensions,
util::ArrayExt,
value::TensorRef
};

View File

@@ -7,8 +7,7 @@ use ort::{
ep,
memory::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType},
session::Session,
tensor::Shape,
value::TensorRefMut
value::{Shape, TensorRefMut}
};
use show_image::{AsImageView, WindowOptions, event};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

View File

@@ -5,8 +5,7 @@ use ort::{
kernel::{Kernel, KernelAttributes, KernelContext}
},
session::Session,
tensor::TensorElementType,
value::Tensor
value::{Tensor, TensorElementType}
};
struct CustomOpOne;

View File

@@ -5,8 +5,7 @@ use crate::{
inputs,
memory::Allocator,
session::builder::SessionBuilder,
tensor::{Shape, SymbolicDimensions, TensorElementType},
value::{Tensor, ValueType}
value::{Shape, SymbolicDimensions, Tensor, TensorElementType, ValueType}
};
#[test]

View File

@@ -36,7 +36,6 @@ pub mod logging;
pub mod memory;
pub mod operator;
pub mod session;
pub mod tensor;
#[cfg(feature = "training")]
#[cfg_attr(docsrs, doc(cfg(feature = "training")))]
pub mod training;

View File

@@ -1,4 +1,4 @@
use crate::{memory::MemoryType, tensor::TensorElementType};
use crate::{memory::MemoryType, value::TensorElementType};
#[repr(i32)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]

View File

@@ -12,9 +12,8 @@ use crate::{
logging::Logger,
memory::{Allocator, MemoryInfo, MemoryType},
ortsys,
tensor::Shape,
util::with_cstr,
value::{DowncastableTarget, DynValue, Outlet, Value, ValueRef, ValueRefMut, ValueType}
value::{DowncastableTarget, DynValue, Outlet, Shape, Value, ValueRef, ValueRefMut, ValueType}
};
pub trait Kernel {

View File

@@ -6,8 +6,7 @@ use crate::{
kernel::{Kernel, KernelAttributes, KernelContext}
},
session::Session,
tensor::TensorElementType,
value::Tensor
value::{Tensor, TensorElementType}
};
struct CustomOpOne;

View File

@@ -194,7 +194,7 @@ impl Session {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::{run_options::RunOptions, Session}, tensor::TensorElementType, value::{Value, ValueType, TensorRef}};
/// # use ort::{session::{run_options::RunOptions, Session}, value::{Value, ValueType, TensorRef, TensorElementType}};
/// # fn main() -> ort::Result<()> {
/// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
@@ -223,7 +223,7 @@ impl Session {
/// ```no_run
/// # // no_run because upsample.onnx is too simple of a model for the termination signal to be reliable enough
/// # use std::sync::Arc;
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType};
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef, TensorElementType}};
/// # fn main() -> ort::Result<()> {
/// # let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// # let input = Value::from_array(ndarray::Array4::<f32>::zeros((1, 64, 64, 3)))?;
@@ -385,7 +385,7 @@ impl Session {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType};
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef, TensorElementType}};
/// # fn main() -> ort::Result<()> { tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(async {
/// let mut session = Session::builder()?.with_intra_threads(2)?.commit_from_file("tests/data/upsample.onnx")?;
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
@@ -501,7 +501,7 @@ impl Session {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType};
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef, TensorElementType}};
/// # fn main() -> ort::Result<()> { tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(async {
/// let mut session = Session::builder()?.with_intra_threads(2)?.commit_from_file("tests/data/upsample.onnx")?;
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
@@ -614,7 +614,7 @@ impl Session {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::{run_options::RunOptions, Session, WorkloadType}, tensor::TensorElementType, value::{Value, ValueType, TensorRef}};
/// # use ort::{session::{run_options::RunOptions, Session, WorkloadType}, value::{Value, ValueType, TensorRef, TensorElementType}};
/// # fn main() -> ort::Result<()> {
/// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// session.set_workload_type(WorkloadType::Efficient)?;

View File

@@ -5,7 +5,7 @@ pub mod mnist {
use crate::{
Result,
tensor::ArrayExtensions,
util::ArrayExt,
value::{TensorValueTypeMarker, Value}
};

View File

@@ -15,9 +15,8 @@ use crate::{
memory::Allocator,
ortsys,
session::{RunOptions, SessionInputValue, SessionInputs, SessionOutputs, builder::SessionBuilder},
tensor::IntoTensorElementType,
util::{char_p_to_string, with_cstr_ptr_array},
value::{Tensor, Value}
value::{IntoTensorElementType, Tensor, Value}
};
#[derive(Debug)]

View File

@@ -24,6 +24,11 @@ pub(crate) use self::{
once_lock::OnceLock,
stack::*
};
#[cfg(feature = "ndarray")]
mod ndarray;
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub use self::ndarray::ArrayExt;
/// Preloads the dynamic library at the given `path`.
///

View File

@@ -1,10 +1,8 @@
//! Helper traits to extend [`ndarray`] functionality.
/// Trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)
/// with useful tensor operations.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub trait ArrayExtensions<S, T, D> {
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub trait ArrayExt<S, T, D> {
/// Calculate the [softmax](https://en.wikipedia.org/wiki/Softmax_function) of the tensor along a given axis.
fn softmax(&self, axis: ndarray::Axis) -> ndarray::Array<T, D>
where
@@ -14,9 +12,9 @@ pub trait ArrayExtensions<S, T, D> {
T: ndarray::NdFloat + core::ops::SubAssign + core::ops::DivAssign;
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<S, T, D> ArrayExtensions<S, T, D> for ndarray::ArrayBase<S, D>
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<S, T, D> ArrayExt<S, T, D> for ndarray::ArrayBase<S, D>
where
D: ndarray::RemoveAxis,
S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>,

View File

@@ -13,14 +13,13 @@ use std::collections::HashMap;
use super::{
DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker,
impl_tensor::{DynTensor, Tensor}
impl_tensor::{DynTensor, IntoTensorElementType, PrimitiveTensorElementType, Tensor, TensorElementType}
};
use crate::{
AsPointer, ErrorCode,
error::{Error, Result},
memory::Allocator,
ortsys,
tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType}
ortsys
};
pub trait MapValueTypeMarker: ValueTypeMarker {

View File

@@ -109,8 +109,9 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized + 'static> Value<Se
/// This `Value<T>` must be either a [`Tensor`] or [`Map`].
///
/// ```
/// # use ort::value::{Sequence, Tensor};
/// # use ort::{memory:: Allocator, value::{Sequence, Tensor}};
/// # fn main() -> ort::Result<()> {
/// # let allocator = Allocator::default();
/// let tensor1 = Tensor::<f32>::new(&allocator, [1_usize, 128, 128, 3])?;
/// let tensor2 = Tensor::<f32>::new(&allocator, [1_usize, 224, 224, 3])?;
/// let value = Sequence::new([tensor1, tensor2])?;

View File

@@ -11,13 +11,12 @@ use core::{
#[cfg(feature = "ndarray")]
use ndarray::{ArcArray, Array, ArrayView, ArrayViewMut, CowArray, Dimension};
use super::{DynTensor, Tensor, TensorRef, TensorRefMut};
use super::{DynTensor, PrimitiveTensorElementType, Shape, SymbolicDimensions, Tensor, TensorElementType, TensorRef, TensorRefMut, Utf8Data};
use crate::{
AsPointer,
error::{Error, ErrorCode, Result},
memory::{Allocator, MemoryInfo},
ortsys,
tensor::{PrimitiveTensorElementType, Shape, SymbolicDimensions, TensorElementType, Utf8Data},
value::{Value, ValueInner, ValueType}
};

View File

@@ -11,12 +11,11 @@ use core::{
slice
};
use super::{DynTensor, Tensor, TensorValueTypeMarker};
use super::{DynTensor, PrimitiveTensorElementType, Shape, Tensor, TensorElementType, TensorValueTypeMarker};
use crate::{
AsPointer,
error::{Error, ErrorCode, Result},
ortsys,
tensor::{PrimitiveTensorElementType, Shape, TensorElementType},
value::{Value, ValueType}
};

View File

@@ -1,7 +1,9 @@
#[cfg(not(target_arch = "wasm32"))] // `ort-web` does not support synchronous `Run`, which `.clone()` depends on.
mod copy;
mod create;
mod element;
mod extract;
mod shape;
use alloc::sync::Arc;
use core::{
@@ -11,14 +13,17 @@ use core::{
ptr::{self}
};
pub use self::create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts, ToShape};
pub use self::{
create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts, ToShape},
element::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data},
shape::{Shape, SymbolicDimensions}
};
use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker};
use crate::{
AsPointer,
error::Result,
memory::{Allocator, MemoryInfo},
ortsys,
tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType}
ortsys
};
pub trait TensorValueTypeMarker: ValueTypeMarker {
@@ -86,7 +91,7 @@ impl DynTensor {
/// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned
/// (CPU) memory for use with CUDA:
/// ```no_run
/// # use ort::{memory::{Allocator, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}, session::Session, tensor::TensorElementType, value::DynTensor};
/// # use ort::{memory::{Allocator, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}, session::Session, value::{DynTensor, TensorElementType}};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let allocator = Allocator::new(
@@ -345,10 +350,9 @@ mod tests {
#[cfg(feature = "ndarray")]
use ndarray::{ArcArray1, Array1, CowArray};
use super::Tensor;
use super::{Shape, SymbolicDimensions, Tensor, TensorElementType};
use crate::{
memory::Allocator,
tensor::{Shape, SymbolicDimensions, TensorElementType},
value::{TensorRef, ValueType}
};

View File

@@ -1,9 +1,5 @@
//! Traits and types related to [`Tensor`](crate::value::Tensor)s.
#[cfg(feature = "ndarray")]
mod ndarray;
mod types;
use alloc::{string::String, vec::Vec};
use core::{
fmt,
@@ -12,10 +8,6 @@ use core::{
use smallvec::{SmallVec, smallvec};
#[cfg(all(feature = "ndarray", feature = "std"))]
pub use self::ndarray::ArrayExtensions;
pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data};
#[derive(Default, Clone, PartialEq, Eq)]
pub struct Shape {
inner: SmallVec<[i64; 4]>

View File

@@ -37,8 +37,9 @@ pub use self::{
DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker
},
impl_tensor::{
DefiniteTensorValueTypeMarker, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, OwnedTensorArrayData, Tensor, TensorArrayData,
TensorArrayDataMut, TensorArrayDataParts, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker, ToShape
DefiniteTensorValueTypeMarker, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, IntoTensorElementType, OwnedTensorArrayData,
PrimitiveTensorElementType, Shape, SymbolicDimensions, Tensor, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts, TensorElementType, TensorRef,
TensorRefMut, TensorValueType, TensorValueTypeMarker, ToShape, Utf8Data
},
r#type::{Outlet, ValueType}
};
@@ -481,7 +482,6 @@ impl<Type: ValueTypeMarker + ?Sized> AsPointer for Value<Type> {
#[cfg(test)]
mod tests {
use super::{DynTensorValueType, Map, Sequence, Tensor, TensorRef, TensorRefMut, TensorValueType};
use crate::memory::Allocator;
#[test]
fn test_casting_tensor() -> crate::Result<()> {

View File

@@ -12,20 +12,22 @@ use smallvec::{SmallVec, smallvec};
use crate::{
Result, ortsys,
tensor::{Shape, SymbolicDimensions, TensorElementType},
util::{self, run_on_drop, with_cstr, with_cstr_ptr_array}
util::{self, run_on_drop, with_cstr, with_cstr_ptr_array},
value::{Shape, SymbolicDimensions, TensorElementType}
};
/// The type of a [`Value`][super::Value], or a session input/output.
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::Session, tensor::{Shape, SymbolicDimensions}, value::{ValueType, Tensor}, tensor::TensorElementType};
/// # use ort::{session::Session, value::{ValueType, Tensor, Shape, SymbolicDimensions, TensorElementType}};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// // `ValueType`s can be obtained from session inputs/outputs:
/// let input = &session.inputs()[0];
/// assert_eq!(input.dtype(), &ValueType::Tensor {
/// assert_eq!(
/// input.dtype(),
/// &ValueType::Tensor {
/// ty: TensorElementType::Float32,
/// // Our model's input has 3 dynamic dimensions, represented by -1
/// shape: Shape::new([-1, -1, -1, 3]),
@@ -36,15 +38,19 @@ use crate::{
/// "unk__33".to_string(),
/// String::default()
/// ])
/// });
/// }
/// );
///
/// // ...or by `Value`s created in Rust or output by a session.
/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?;
/// assert_eq!(value.dtype(), &ValueType::Tensor {
/// assert_eq!(
/// value.dtype(),
/// &ValueType::Tensor {
/// ty: TensorElementType::Int64,
/// shape: Shape::new([5]),
/// dimension_symbols: SymbolicDimensions::new([String::default()])
/// });
/// }
/// );
/// # Ok(())
/// # }
/// ```
@@ -215,7 +221,7 @@ impl ValueType {
/// Returns the element type of this value type if it is a tensor, or `None` if it is a sequence or map.
///
/// ```
/// # use ort::{tensor::TensorElementType, value::Tensor};
/// # use ort::value::{Tensor, TensorElementType};
/// # fn main() -> ort::Result<()> {
/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?;
/// assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Int64));
@@ -374,7 +380,7 @@ mod tests {
use super::ValueType;
use crate::{
ortsys,
tensor::{Shape, SymbolicDimensions, TensorElementType}
value::{Shape, SymbolicDimensions, TensorElementType}
};
#[test]

View File

@@ -8,8 +8,7 @@ use ort::{
kernel::{Kernel, KernelAttributes, KernelContext}
},
session::{RunOptions, Session},
tensor::TensorElementType,
value::Tensor
value::{Tensor, TensorElementType}
};
struct CustomOpOne;

View File

@@ -4,7 +4,7 @@ use image::{ImageBuffer, Luma, Pixel, imageops::FilterType};
use ort::{
inputs,
session::{Session, builder::GraphOptimizationLevel},
tensor::ArrayExtensions,
util::ArrayExt,
value::TensorRef
};

View File

@@ -9,7 +9,7 @@ use ndarray::s;
use ort::{
Error, inputs,
session::{Session, builder::GraphOptimizationLevel},
tensor::ArrayExtensions,
util::ArrayExt,
value::TensorRef
};