From 25a6760783077d3c016f07ab71b95885db84e5ae Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Thu, 15 Jan 2026 02:03:23 -0600 Subject: [PATCH] refactor: merge `tensor` module into `value` --- backends/tract/tests/session.rs | 2 +- examples/cudarc/cudarc.rs | 3 +- examples/custom-ops/custom-ops.rs | 3 +- src/editor/tests.rs | 3 +- src/lib.rs | 1 - src/operator/io.rs | 2 +- src/operator/kernel.rs | 3 +- src/operator/tests.rs | 3 +- src/session/mod.rs | 10 ++-- src/test_util.rs | 2 +- src/training/trainer.rs | 3 +- src/util/mod.rs | 5 ++ src/{tensor => util}/ndarray.rs | 14 +++--- src/value/impl_map.rs | 5 +- src/value/impl_sequence.rs | 3 +- src/value/impl_tensor/create.rs | 3 +- .../types.rs => value/impl_tensor/element.rs} | 0 src/value/impl_tensor/extract.rs | 3 +- src/value/impl_tensor/mod.rs | 16 +++--- .../mod.rs => value/impl_tensor/shape.rs} | 8 --- src/value/mod.rs | 6 +-- src/value/type.rs | 50 +++++++++++-------- tests/leak-check/main.rs | 3 +- tests/mnist.rs | 2 +- tests/squeezenet.rs | 2 +- 25 files changed, 75 insertions(+), 80 deletions(-) rename src/{tensor => util}/ndarray.rs (91%) rename src/{tensor/types.rs => value/impl_tensor/element.rs} (100%) rename src/{tensor/mod.rs => value/impl_tensor/shape.rs} (92%) diff --git a/backends/tract/tests/session.rs b/backends/tract/tests/session.rs index 5dc2aed..b839090 100644 --- a/backends/tract/tests/session.rs +++ b/backends/tract/tests/session.rs @@ -4,7 +4,7 @@ use image::{ImageBuffer, Luma, Pixel, imageops::FilterType}; use ort::{ inputs, session::{Session, builder::GraphOptimizationLevel}, - tensor::ArrayExtensions, + util::ArrayExt, value::TensorRef }; diff --git a/examples/cudarc/cudarc.rs b/examples/cudarc/cudarc.rs index 6657057..88daa24 100644 --- a/examples/cudarc/cudarc.rs +++ b/examples/cudarc/cudarc.rs @@ -7,8 +7,7 @@ use ort::{ ep, memory::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType}, session::Session, - tensor::Shape, - value::TensorRefMut + value::{Shape, TensorRefMut} }; use show_image::{AsImageView, WindowOptions, event}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/examples/custom-ops/custom-ops.rs b/examples/custom-ops/custom-ops.rs index f2f32b7..bfe6b03 100644 --- a/examples/custom-ops/custom-ops.rs +++ b/examples/custom-ops/custom-ops.rs @@ -5,8 +5,7 @@ use ort::{ kernel::{Kernel, KernelAttributes, KernelContext} }, session::Session, - tensor::TensorElementType, - value::Tensor + value::{Tensor, TensorElementType} }; struct CustomOpOne; diff --git a/src/editor/tests.rs b/src/editor/tests.rs index ee04b2d..993c95d 100644 --- a/src/editor/tests.rs +++ b/src/editor/tests.rs @@ -5,8 +5,7 @@ use crate::{ inputs, memory::Allocator, session::builder::SessionBuilder, - tensor::{Shape, SymbolicDimensions, TensorElementType}, - value::{Tensor, ValueType} + value::{Shape, SymbolicDimensions, Tensor, TensorElementType, ValueType} }; #[test] diff --git a/src/lib.rs b/src/lib.rs index 652cc2d..86456e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,7 +36,6 @@ pub mod logging; pub mod memory; pub mod operator; pub mod session; -pub mod tensor; #[cfg(feature = "training")] #[cfg_attr(docsrs, doc(cfg(feature = "training")))] pub mod training; diff --git a/src/operator/io.rs b/src/operator/io.rs index 16d0e93..a9281e5 100644 --- a/src/operator/io.rs +++ b/src/operator/io.rs @@ -1,4 +1,4 @@ -use crate::{memory::MemoryType, tensor::TensorElementType}; +use crate::{memory::MemoryType, value::TensorElementType}; #[repr(i32)] #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 8c07129..48997ea 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -12,9 +12,8 @@ use crate::{ logging::Logger, memory::{Allocator, MemoryInfo, MemoryType}, ortsys, - tensor::Shape, util::with_cstr, - value::{DowncastableTarget, DynValue, Outlet, Value, ValueRef, ValueRefMut, ValueType} + value::{DowncastableTarget, DynValue, Outlet, Shape, Value, ValueRef, ValueRefMut, ValueType} }; pub trait Kernel { diff --git a/src/operator/tests.rs b/src/operator/tests.rs index f59f483..5cd61b8 100644 --- a/src/operator/tests.rs +++ b/src/operator/tests.rs @@ -6,8 +6,7 @@ use crate::{ kernel::{Kernel, KernelAttributes, KernelContext} }, session::Session, - tensor::TensorElementType, - value::Tensor + value::{Tensor, TensorElementType} }; struct CustomOpOne; diff --git a/src/session/mod.rs b/src/session/mod.rs index 67d29b8..ce04382 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -194,7 +194,7 @@ impl Session { /// /// ``` /// # use std::sync::Arc; - /// # use ort::{session::{run_options::RunOptions, Session}, tensor::TensorElementType, value::{Value, ValueType, TensorRef}}; + /// # use ort::{session::{run_options::RunOptions, Session}, value::{Value, ValueType, TensorRef, TensorElementType}}; /// # fn main() -> ort::Result<()> { /// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); @@ -223,7 +223,7 @@ impl Session { /// ```no_run /// # // no_run because upsample.onnx is too simple of a model for the termination signal to be reliable enough /// # use std::sync::Arc; - /// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType}; + /// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef, TensorElementType}}; /// # fn main() -> ort::Result<()> { /// # let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; @@ -385,7 +385,7 @@ impl Session { /// /// ``` /// # use std::sync::Arc; - /// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType}; + /// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef, TensorElementType}}; /// # fn main() -> ort::Result<()> { tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(async { /// let mut session = Session::builder()?.with_intra_threads(2)?.commit_from_file("tests/data/upsample.onnx")?; /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); @@ -501,7 +501,7 @@ impl Session { /// /// ``` /// # use std::sync::Arc; - /// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType}; + /// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef, TensorElementType}}; /// # fn main() -> ort::Result<()> { tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(async { /// let mut session = Session::builder()?.with_intra_threads(2)?.commit_from_file("tests/data/upsample.onnx")?; /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); @@ -614,7 +614,7 @@ impl Session { /// /// ``` /// # use std::sync::Arc; - /// # use ort::{session::{run_options::RunOptions, Session, WorkloadType}, tensor::TensorElementType, value::{Value, ValueType, TensorRef}}; + /// # use ort::{session::{run_options::RunOptions, Session, WorkloadType}, value::{Value, ValueType, TensorRef, TensorElementType}}; /// # fn main() -> ort::Result<()> { /// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// session.set_workload_type(WorkloadType::Efficient)?; diff --git a/src/test_util.rs b/src/test_util.rs index 96c7bd5..9420057 100644 --- a/src/test_util.rs +++ b/src/test_util.rs @@ -5,7 +5,7 @@ pub mod mnist { use crate::{ Result, - tensor::ArrayExtensions, + util::ArrayExt, value::{TensorValueTypeMarker, Value} }; diff --git a/src/training/trainer.rs b/src/training/trainer.rs index d9913ce..333508e 100644 --- a/src/training/trainer.rs +++ b/src/training/trainer.rs @@ -15,9 +15,8 @@ use crate::{ memory::Allocator, ortsys, session::{RunOptions, SessionInputValue, SessionInputs, SessionOutputs, builder::SessionBuilder}, - tensor::IntoTensorElementType, util::{char_p_to_string, with_cstr_ptr_array}, - value::{Tensor, Value} + value::{IntoTensorElementType, Tensor, Value} }; #[derive(Debug)] diff --git a/src/util/mod.rs b/src/util/mod.rs index c2d6cf8..9b711f4 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -24,6 +24,11 @@ pub(crate) use self::{ once_lock::OnceLock, stack::* }; +#[cfg(feature = "ndarray")] +mod ndarray; +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +pub use self::ndarray::ArrayExt; /// Preloads the dynamic library at the given `path`. /// diff --git a/src/tensor/ndarray.rs b/src/util/ndarray.rs similarity index 91% rename from src/tensor/ndarray.rs rename to src/util/ndarray.rs index d056005..cb61de1 100644 --- a/src/tensor/ndarray.rs +++ b/src/util/ndarray.rs @@ -1,10 +1,8 @@ -//! Helper traits to extend [`ndarray`] functionality. - /// Trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html) /// with useful tensor operations. -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -pub trait ArrayExtensions { +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +pub trait ArrayExt { /// Calculate the [softmax](https://en.wikipedia.org/wiki/Softmax_function) of the tensor along a given axis. fn softmax(&self, axis: ndarray::Axis) -> ndarray::Array where @@ -14,9 +12,9 @@ pub trait ArrayExtensions { T: ndarray::NdFloat + core::ops::SubAssign + core::ops::DivAssign; } -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl ArrayExtensions for ndarray::ArrayBase +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl ArrayExt for ndarray::ArrayBase where D: ndarray::RemoveAxis, S: ndarray::RawData + ndarray::Data + ndarray::RawData, diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index ccf53c8..09b6511 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -13,14 +13,13 @@ use std::collections::HashMap; use super::{ DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker, - impl_tensor::{DynTensor, Tensor} + impl_tensor::{DynTensor, IntoTensorElementType, PrimitiveTensorElementType, Tensor, TensorElementType} }; use crate::{ AsPointer, ErrorCode, error::{Error, Result}, memory::Allocator, - ortsys, - tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType} + ortsys }; pub trait MapValueTypeMarker: ValueTypeMarker { diff --git a/src/value/impl_sequence.rs b/src/value/impl_sequence.rs index 46b1098..fd3c9c3 100644 --- a/src/value/impl_sequence.rs +++ b/src/value/impl_sequence.rs @@ -109,8 +109,9 @@ impl Value` must be either a [`Tensor`] or [`Map`]. /// /// ``` - /// # use ort::value::{Sequence, Tensor}; + /// # use ort::{memory:: Allocator, value::{Sequence, Tensor}}; /// # fn main() -> ort::Result<()> { + /// # let allocator = Allocator::default(); /// let tensor1 = Tensor::::new(&allocator, [1_usize, 128, 128, 3])?; /// let tensor2 = Tensor::::new(&allocator, [1_usize, 224, 224, 3])?; /// let value = Sequence::new([tensor1, tensor2])?; diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index d1eadbe..2d03efa 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -11,13 +11,12 @@ use core::{ #[cfg(feature = "ndarray")] use ndarray::{ArcArray, Array, ArrayView, ArrayViewMut, CowArray, Dimension}; -use super::{DynTensor, Tensor, TensorRef, TensorRefMut}; +use super::{DynTensor, PrimitiveTensorElementType, Shape, SymbolicDimensions, Tensor, TensorElementType, TensorRef, TensorRefMut, Utf8Data}; use crate::{ AsPointer, error::{Error, ErrorCode, Result}, memory::{Allocator, MemoryInfo}, ortsys, - tensor::{PrimitiveTensorElementType, Shape, SymbolicDimensions, TensorElementType, Utf8Data}, value::{Value, ValueInner, ValueType} }; diff --git a/src/tensor/types.rs b/src/value/impl_tensor/element.rs similarity index 100% rename from src/tensor/types.rs rename to src/value/impl_tensor/element.rs diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index e704545..ee0f697 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -11,12 +11,11 @@ use core::{ slice }; -use super::{DynTensor, Tensor, TensorValueTypeMarker}; +use super::{DynTensor, PrimitiveTensorElementType, Shape, Tensor, TensorElementType, TensorValueTypeMarker}; use crate::{ AsPointer, error::{Error, ErrorCode, Result}, ortsys, - tensor::{PrimitiveTensorElementType, Shape, TensorElementType}, value::{Value, ValueType} }; diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index eed85c4..d043db9 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -1,7 +1,9 @@ #[cfg(not(target_arch = "wasm32"))] // `ort-web` does not support synchronous `Run`, which `.clone()` depends on. mod copy; mod create; +mod element; mod extract; +mod shape; use alloc::sync::Arc; use core::{ @@ -11,14 +13,17 @@ use core::{ ptr::{self} }; -pub use self::create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts, ToShape}; +pub use self::{ + create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts, ToShape}, + element::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}, + shape::{Shape, SymbolicDimensions} +}; use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker}; use crate::{ AsPointer, error::Result, memory::{Allocator, MemoryInfo}, - ortsys, - tensor::{IntoTensorElementType, Shape, SymbolicDimensions, TensorElementType} + ortsys }; pub trait TensorValueTypeMarker: ValueTypeMarker { @@ -86,7 +91,7 @@ impl DynTensor { /// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned /// (CPU) memory for use with CUDA: /// ```no_run - /// # use ort::{memory::{Allocator, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}, session::Session, tensor::TensorElementType, value::DynTensor}; + /// # use ort::{memory::{Allocator, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}, session::Session, value::{DynTensor, TensorElementType}}; /// # fn main() -> ort::Result<()> { /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// let allocator = Allocator::new( @@ -345,10 +350,9 @@ mod tests { #[cfg(feature = "ndarray")] use ndarray::{ArcArray1, Array1, CowArray}; - use super::Tensor; + use super::{Shape, SymbolicDimensions, Tensor, TensorElementType}; use crate::{ memory::Allocator, - tensor::{Shape, SymbolicDimensions, TensorElementType}, value::{TensorRef, ValueType} }; diff --git a/src/tensor/mod.rs b/src/value/impl_tensor/shape.rs similarity index 92% rename from src/tensor/mod.rs rename to src/value/impl_tensor/shape.rs index 1e739cc..f1148ed 100644 --- a/src/tensor/mod.rs +++ b/src/value/impl_tensor/shape.rs @@ -1,9 +1,5 @@ //! Traits and types related to [`Tensor`](crate::value::Tensor)s. -#[cfg(feature = "ndarray")] -mod ndarray; -mod types; - use alloc::{string::String, vec::Vec}; use core::{ fmt, @@ -12,10 +8,6 @@ use core::{ use smallvec::{SmallVec, smallvec}; -#[cfg(all(feature = "ndarray", feature = "std"))] -pub use self::ndarray::ArrayExtensions; -pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; - #[derive(Default, Clone, PartialEq, Eq)] pub struct Shape { inner: SmallVec<[i64; 4]> diff --git a/src/value/mod.rs b/src/value/mod.rs index 40d1c3f..e23415a 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -37,8 +37,9 @@ pub use self::{ DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker }, impl_tensor::{ - DefiniteTensorValueTypeMarker, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, OwnedTensorArrayData, Tensor, TensorArrayData, - TensorArrayDataMut, TensorArrayDataParts, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker, ToShape + DefiniteTensorValueTypeMarker, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, IntoTensorElementType, OwnedTensorArrayData, + PrimitiveTensorElementType, Shape, SymbolicDimensions, Tensor, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts, TensorElementType, TensorRef, + TensorRefMut, TensorValueType, TensorValueTypeMarker, ToShape, Utf8Data }, r#type::{Outlet, ValueType} }; @@ -481,7 +482,6 @@ impl AsPointer for Value { #[cfg(test)] mod tests { use super::{DynTensorValueType, Map, Sequence, Tensor, TensorRef, TensorRefMut, TensorValueType}; - use crate::memory::Allocator; #[test] fn test_casting_tensor() -> crate::Result<()> { diff --git a/src/value/type.rs b/src/value/type.rs index f103d30..289fae9 100644 --- a/src/value/type.rs +++ b/src/value/type.rs @@ -12,39 +12,45 @@ use smallvec::{SmallVec, smallvec}; use crate::{ Result, ortsys, - tensor::{Shape, SymbolicDimensions, TensorElementType}, - util::{self, run_on_drop, with_cstr, with_cstr_ptr_array} + util::{self, run_on_drop, with_cstr, with_cstr_ptr_array}, + value::{Shape, SymbolicDimensions, TensorElementType} }; /// The type of a [`Value`][super::Value], or a session input/output. /// /// ``` /// # use std::sync::Arc; -/// # use ort::{session::Session, tensor::{Shape, SymbolicDimensions}, value::{ValueType, Tensor}, tensor::TensorElementType}; +/// # use ort::{session::Session, value::{ValueType, Tensor, Shape, SymbolicDimensions, TensorElementType}}; /// # fn main() -> ort::Result<()> { /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// // `ValueType`s can be obtained from session inputs/outputs: /// let input = &session.inputs()[0]; -/// assert_eq!(input.dtype(), &ValueType::Tensor { -/// ty: TensorElementType::Float32, -/// // Our model's input has 3 dynamic dimensions, represented by -1 -/// shape: Shape::new([-1, -1, -1, 3]), -/// // Dynamic dimensions may also have names. -/// dimension_symbols: SymbolicDimensions::new([ -/// "unk__31".to_string(), -/// "unk__32".to_string(), -/// "unk__33".to_string(), -/// String::default() -/// ]) -/// }); +/// assert_eq!( +/// input.dtype(), +/// &ValueType::Tensor { +/// ty: TensorElementType::Float32, +/// // Our model's input has 3 dynamic dimensions, represented by -1 +/// shape: Shape::new([-1, -1, -1, 3]), +/// // Dynamic dimensions may also have names. +/// dimension_symbols: SymbolicDimensions::new([ +/// "unk__31".to_string(), +/// "unk__32".to_string(), +/// "unk__33".to_string(), +/// String::default() +/// ]) +/// } +/// ); /// /// // ...or by `Value`s created in Rust or output by a session. /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; -/// assert_eq!(value.dtype(), &ValueType::Tensor { -/// ty: TensorElementType::Int64, -/// shape: Shape::new([5]), -/// dimension_symbols: SymbolicDimensions::new([String::default()]) -/// }); +/// assert_eq!( +/// value.dtype(), +/// &ValueType::Tensor { +/// ty: TensorElementType::Int64, +/// shape: Shape::new([5]), +/// dimension_symbols: SymbolicDimensions::new([String::default()]) +/// } +/// ); /// # Ok(()) /// # } /// ``` @@ -215,7 +221,7 @@ impl ValueType { /// Returns the element type of this value type if it is a tensor, or `None` if it is a sequence or map. /// /// ``` - /// # use ort::{tensor::TensorElementType, value::Tensor}; + /// # use ort::value::{Tensor, TensorElementType}; /// # fn main() -> ort::Result<()> { /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; /// assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Int64)); @@ -374,7 +380,7 @@ mod tests { use super::ValueType; use crate::{ ortsys, - tensor::{Shape, SymbolicDimensions, TensorElementType} + value::{Shape, SymbolicDimensions, TensorElementType} }; #[test] diff --git a/tests/leak-check/main.rs b/tests/leak-check/main.rs index a64d544..c314c02 100644 --- a/tests/leak-check/main.rs +++ b/tests/leak-check/main.rs @@ -8,8 +8,7 @@ use ort::{ kernel::{Kernel, KernelAttributes, KernelContext} }, session::{RunOptions, Session}, - tensor::TensorElementType, - value::Tensor + value::{Tensor, TensorElementType} }; struct CustomOpOne; diff --git a/tests/mnist.rs b/tests/mnist.rs index 8c766ae..59ce4e0 100644 --- a/tests/mnist.rs +++ b/tests/mnist.rs @@ -4,7 +4,7 @@ use image::{ImageBuffer, Luma, Pixel, imageops::FilterType}; use ort::{ inputs, session::{Session, builder::GraphOptimizationLevel}, - tensor::ArrayExtensions, + util::ArrayExt, value::TensorRef }; diff --git a/tests/squeezenet.rs b/tests/squeezenet.rs index 59d4b63..183544e 100644 --- a/tests/squeezenet.rs +++ b/tests/squeezenet.rs @@ -9,7 +9,7 @@ use ndarray::s; use ort::{ Error, inputs, session::{Session, builder::GraphOptimizationLevel}, - tensor::ArrayExtensions, + util::ArrayExt, value::TensorRef };