diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0ff5a07..ca462fc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -68,6 +68,11 @@ jobs: $VCINSTALLDIR = Join-Path $VSINSTALLDIR "VC" $LLVM_ROOT = Join-Path $VCINSTALLDIR "Tools\Llvm\x64" echo "PATH=$Env:PATH;${LLVM_ROOT}\bin" >> $Env:GITHUB_ENV + - name: Install native dependencies for Linux + if: matrix.platform.os == 'ubuntu-latest' + shell: bash + run: | + apt install pkg-config libfreetype6-dev libexpat1-dev libxcb-composite0-dev libssl-dev libx11-dev libfontconfig1-dev - name: Build/test uses: houseabsolute/actions-rust-cross@v0 with: diff --git a/Cargo.toml b/Cargo.toml index bc2ef4b..87f275a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,7 +76,6 @@ once_cell = "1.18" tracing = "0.1" half = { version = "2.1", optional = true } - [target.'cfg(unix)'.dependencies] libc = "0.2" diff --git a/examples/gpt.rs b/examples/gpt.rs index 6f236ef..7740b9c 100644 --- a/examples/gpt.rs +++ b/examples/gpt.rs @@ -1,6 +1,6 @@ use std::io::{self, Write}; -use ndarray::{array, concatenate, s, Array1, Axis, CowArray}; +use ndarray::{array, concatenate, s, Array1, Axis}; use ort::{ download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProviderOptions, Environment, ExecutionProvider, GraphOptimizationLevel, OrtOwnedTensor, OrtResult, SessionBuilder @@ -33,15 +33,14 @@ fn main() -> OrtResult<()> { let tokens = tokenizer.encode(PROMPT, false).unwrap(); let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); - let mut tokens = CowArray::from(Array1::from_iter(tokens.iter().cloned())); + let mut tokens = Array1::from_iter(tokens.iter().cloned()); print!("{PROMPT}"); stdout.flush().unwrap(); for _ in 0..GEN_TOKENS { - let n_tokens = tokens.shape()[0]; - let array = tokens.clone().insert_axis(Axis(0)).into_shape((1, 1, n_tokens)).unwrap().into_dyn(); - let outputs = session.run(inputs![&array]?)?; + let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1)); + let outputs = session.run(inputs![array]?)?; let generated_tokens: OrtOwnedTensor = outputs["output1"].extract_tensor()?; let generated_tokens = generated_tokens.view(); @@ -56,7 +55,7 @@ fn main() -> OrtResult<()> { probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); let token = probabilities[rng.gen_range(0..=TOP_K)].0; - tokens = CowArray::from(concatenate![Axis(0), tokens, array![token.try_into().unwrap()]]); + tokens = concatenate![Axis(0), tokens, array![token.try_into().unwrap()]]; let token_str = tokenizer.decode(&[token as _], true).unwrap(); print!("{}", token_str); diff --git a/examples/yolov8.rs b/examples/yolov8.rs index 5a75a87..58d41e0 100644 --- a/examples/yolov8.rs +++ b/examples/yolov8.rs @@ -74,7 +74,7 @@ fn main() -> OrtResult<()> { let original_img = image::open("tests/data/baseball.jpg").unwrap(); let (img_width, img_height) = (original_img.width(), original_img.height()); let img = original_img.resize_exact(640, 640, FilterType::CatmullRom); - let mut input = Array::zeros((1, 3, 640, 640)).into_dyn(); + let mut input = Array::zeros((1, 3, 640, 640)); for pixel in img.pixels() { let x = pixel.0 as _; let y = pixel.1 as _; @@ -89,8 +89,7 @@ fn main() -> OrtResult<()> { let model = SessionBuilder::new(&env).unwrap().with_model_from_file(path).unwrap(); // Run YOLOv8 inference - let input_as_values = &input.as_standard_layout(); - let outputs = model.run(inputs!["images" => input_as_values]).unwrap(); + let outputs = model.run(inputs!["images" => input]?).unwrap(); let output = outputs["output0"].extract_tensor::().unwrap().view().t().into_owned(); let mut boxes = Vec::new(); diff --git a/src/session/input.rs b/src/session/input.rs index 35ebcb7..685570a 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -34,15 +34,71 @@ impl<'s> From> for SessionInputs<'s> { } } +/// Construct the inputs to a session from an array, a map, or an IO binding. +/// +/// The result of this macro is an `Result`, so make sure you `?` on the result. +/// +/// For tensors, note that using certain array structures can have performance implications. +/// - `&CowArray`, `ArrayView` will **always** be copied. +/// - `Array`, `&mut ArcArray` will only be copied **if the tensor is not contiguous** (i.e. has been reshaped). +/// +/// # Example +/// +/// ## Array of tensors +/// +/// ```no_run +/// # use std::{error::Error, sync::Arc}; +/// # use ndarray::Array1; +/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder}; +/// # fn main() -> Result<(), Box> { +/// # let environment = Environment::default().into_arc(); +/// let mut session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?; +/// let _ = session.run(ort::inputs![Array1::from_vec(vec![1, 2, 3, 4, 5])]?); +/// # Ok(()) +/// # } +/// ``` +/// +/// ## Map of named tensors +/// +/// ```no_run +/// # use std::{error::Error, sync::Arc}; +/// # use ndarray::Array1; +/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder}; +/// # fn main() -> Result<(), Box> { +/// # let environment = Environment::default().into_arc(); +/// let mut session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?; +/// let _ = session.run(ort::inputs! { +/// "tokens" => Array1::from_vec(vec![1, 2, 3, 4, 5]) +/// }?); +/// # Ok(()) +/// # } +/// ``` +/// +/// ## IOBinding +/// +/// ```no_run +/// # use std::{error::Error, sync::Arc}; +/// # use ndarray::Array1; +/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder}; +/// # fn main() -> Result<(), Box> { +/// # let environment = Environment::default().into_arc(); +/// let mut session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?; +/// let mut binding = session.create_binding()?; +/// let _ = session.run(ort::inputs!(bind = binding)?); +/// # Ok(()) +/// # } +/// ``` #[macro_export] macro_rules! inputs { - (bind = $($v:expr),+) => ($v); - ($($v:expr),+ $(,)?) => ([$(std::convert::TryInto::<$crate::Value>::try_into($v).map_err($crate::OrtError::from),)+].into_iter().collect::<$crate::OrtResult<$crate::smallvec::SmallVec<_>>>()); + (bind = $v:expr) => ($crate::OrtResult::Ok($v)); + ($($v:expr),+ $(,)?) => ( + [$(std::convert::TryInto::<$crate::Value>::try_into($v).map_err($crate::OrtError::from),)+] + .into_iter() + .collect::<$crate::OrtResult<$crate::smallvec::SmallVec<_>>>() + ); ($($n:expr => $v:expr),+ $(,)?) => {{ - let mut inputs = std::collections::HashMap::<_, $crate::Value>::new(); - $( - inputs.insert($n, std::convert::TryInto::<$crate::Value>::try_into($v)?); - )+ - inputs + [$(std::convert::TryInto::<$crate::Value>::try_into($v).map_err($crate::OrtError::from).map(|v| ($n, v)),)+] + .into_iter() + .collect::<$crate::OrtResult>>() }}; } diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index f2779d9..f811d3c 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -179,16 +179,6 @@ impl<'a> Utf8Data for &'a str { } } -impl IntoTensorElementDataType for T { - fn tensor_element_data_type() -> TensorElementDataType { - TensorElementDataType::String - } - - fn try_utf8_bytes(&self) -> Option<&[u8]> { - Some(self.utf8_bytes()) - } -} - /// Trait used to map ONNX Runtime types to Rust types. pub trait TensorDataToType: Sized + fmt::Debug + Clone { /// The tensor element type that this type can extract from. @@ -326,3 +316,13 @@ pub enum DataType { Sequence(Box), Map { key: TensorElementDataType, value: TensorElementDataType } } + +impl DataType { + /// Returns the dimensions of this data type if it is a tensor, or `None` if it is a sequence or map. + pub fn tensor_dimensions(&self) -> Option<&Vec> { + match self { + DataType::Tensor { dimensions, .. } => Some(dimensions), + _ => None + } + } +} diff --git a/src/value.rs b/src/value.rs index c9b789f..f872528 100644 --- a/src/value.rs +++ b/src/value.rs @@ -7,7 +7,7 @@ use crate::{ memory::{Allocator, MemoryInfo}, ortsys, session::SharedSessionInner, - tensor::{IntoTensorElementDataType, OrtOwnedTensor, TensorDataToType, TensorElementDataType}, + tensor::{IntoTensorElementDataType, OrtOwnedTensor, TensorDataToType, TensorElementDataType, Utf8Data}, AllocatorType, MemType, OrtError, OrtResult }; @@ -32,29 +32,6 @@ pub enum DynArrayRef<'v> { String(CowArray<'v, String, IxDyn>) } -impl<'v> DynArrayRef<'v> { - pub fn shape(&self) -> &[usize] { - match self { - DynArrayRef::Float(x) => x.shape(), - #[cfg(feature = "half")] - DynArrayRef::Float16(x) => x.shape(), - #[cfg(feature = "half")] - DynArrayRef::Bfloat16(x) => x.shape(), - DynArrayRef::Uint8(x) => x.shape(), - DynArrayRef::Int8(x) => x.shape(), - DynArrayRef::Uint16(x) => x.shape(), - DynArrayRef::Int16(x) => x.shape(), - DynArrayRef::Int32(x) => x.shape(), - DynArrayRef::Int64(x) => x.shape(), - DynArrayRef::Bool(x) => x.shape(), - DynArrayRef::Double(x) => x.shape(), - DynArrayRef::Uint32(x) => x.shape(), - DynArrayRef::Uint64(x) => x.shape(), - DynArrayRef::String(x) => x.shape() - } - } -} - macro_rules! impl_convert_trait { ($type_:ty, $variant:expr) => { impl<'v, D: Dimension> From> for DynArrayRef<'v> { @@ -179,18 +156,15 @@ impl Value { pub trait OrtInput { type Item; - fn get(&self) -> (Vec, &[Self::Item]); - fn get_mut(&mut self) -> (Vec, *mut Self::Item, usize, Box); + fn ref_parts(&self) -> (Vec, &[Self::Item]); + fn into_parts(self) -> (Vec, *mut Self::Item, usize, Box); } impl Value { - /// Construct a [`Value`] from a Rust-owned [`CowArray`]. + /// Construct a [`Value`] from a Rust-owned array. /// /// `allocator` is required to be `Some` when converting a String tensor. See [`crate::Session::allocator`]. - pub fn from_array( - allocator: Option<&Allocator>, - mut input: impl OrtInput - ) -> OrtResult { + pub fn from_array(input: impl OrtInput) -> OrtResult { let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemType::Default)?; let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); @@ -210,7 +184,7 @@ impl Value { | TensorElementDataType::Bool => { // primitive data is already suitably laid out in memory; provide it to // onnxruntime as is - let (shape, ptr, ptr_len, guard) = input.get_mut(); + let (shape, ptr, ptr_len, guard) = input.into_parts(); let shape_ptr: *const i64 = shape.as_ptr(); let shape_len = shape.len(); @@ -238,7 +212,7 @@ impl Value { #[cfg(feature = "half")] TensorElementDataType::Bfloat16 | TensorElementDataType::Float16 => { // f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime - let (shape, ptr, ptr_len, guard) = input.get_mut(); + let (shape, ptr, ptr_len, guard) = input.into_parts(); let shape_ptr: *const i64 = shape.as_ptr(); let shape_len = shape.len(); @@ -263,34 +237,7 @@ impl Value { assert_eq!(is_tensor, 1); guard } - TensorElementDataType::String => { - let allocator = allocator.ok_or(OrtError::StringTensorRequiresAllocator)?; - - let (shape, data) = input.get(); - let shape_ptr: *const i64 = shape.as_ptr(); - let shape_len = shape.len(); - - // create tensor without data -- data is filled in later - ortsys![ - unsafe CreateTensorAsOrtValue(allocator.ptr, shape_ptr, shape_len as _, T::tensor_element_data_type().into(), value_ptr_ptr) - -> OrtError::CreateTensor - ]; - - // create null-terminated copies of each string, as per `FillStringTensor` docs - let null_terminated_copies: Vec = data - .iter() - .map(|elt| { - let slice = elt.try_utf8_bytes().expect("String data type must provide utf8 bytes"); - ffi::CString::new(slice) - }) - .collect::, _>>() - .map_err(OrtError::FfiStringNull)?; - - let string_pointers = null_terminated_copies.iter().map(|cstring| cstring.as_ptr()).collect::>(); - - ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> OrtError::FillStringTensor]; - Box::new(()) - } + TensorElementDataType::String => unreachable!() }; assert_non_null_pointer(value_ptr, "Value")?; @@ -304,6 +251,50 @@ impl Value { }) } + /// Construct a [`Value`] from a Rust-owned [`CowArray`]. + /// + /// `allocator` is required to be `Some` when converting a String tensor. See [`crate::Session::allocator`]. + pub fn from_string_array(allocator: &Allocator, input: impl OrtInput) -> OrtResult { + let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemType::Default)?; + + let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); + let value_ptr_ptr: *mut *mut ort_sys::OrtValue = &mut value_ptr; + + let (shape, data) = input.ref_parts(); + let shape_ptr: *const i64 = shape.as_ptr(); + let shape_len = shape.len(); + + // create tensor without data -- data is filled in later + ortsys![ + unsafe CreateTensorAsOrtValue(allocator.ptr, shape_ptr, shape_len as _, TensorElementDataType::String.into(), value_ptr_ptr) + -> OrtError::CreateTensor + ]; + + // create null-terminated copies of each string, as per `FillStringTensor` docs + let null_terminated_copies: Vec = data + .iter() + .map(|elt| { + let slice = elt.utf8_bytes(); + ffi::CString::new(slice) + }) + .collect::, _>>() + .map_err(OrtError::FfiStringNull)?; + + let string_pointers = null_terminated_copies.iter().map(|cstring| cstring.as_ptr()).collect::>(); + + ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> OrtError::FillStringTensor]; + + assert_non_null_pointer(value_ptr, "Value")?; + + Ok(Value { + inner: ValueInner::RustOwned { + ptr: value_ptr, + _array: Box::new(()), + _memory_info: memory_info + } + }) + } + pub(crate) fn ptr(&self) -> *mut ort_sys::OrtValue { match &self.inner { ValueInner::CppOwned { ptr, .. } => *ptr, @@ -319,32 +310,19 @@ impl Value { } } -impl<'i, 'v, T> TryFrom<&'i CowArray<'v, T, IxDyn>> for Value -where - 'i: 'v, - T: IntoTensorElementDataType + Debug + Clone + 'static, - DynArrayRef<'v>: From> -{ - type Error = OrtError; - - fn try_from(value: &'i CowArray<'v, T, IxDyn>) -> OrtResult { - Value::from_array(None, value) - } -} - -impl<'i, 'v, T: Clone + 'static> OrtInput for &'i CowArray<'v, T, IxDyn> +impl<'i, 'v, T: Clone + 'static, D: Dimension + 'static> OrtInput for &'i CowArray<'v, T, D> where 'i: 'v { type Item = T; - fn get(&self) -> (Vec, &[Self::Item]) { + fn ref_parts(&self) -> (Vec, &[Self::Item]) { let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); let data = self.as_slice().expect("tensor should be contiguous"); (shape, data) } - fn get_mut(&mut self) -> (Vec, *mut Self::Item, usize, Box) { + fn into_parts(self) -> (Vec, *mut Self::Item, usize, Box) { // This will result in a copy in either form of the CowArray let mut contiguous_array = self.as_standard_layout().into_owned(); let shape: Vec = contiguous_array.shape().iter().map(|d| *d as i64).collect(); @@ -355,16 +333,16 @@ where } } -impl OrtInput for &mut ArcArray { +impl OrtInput for &mut ArcArray { type Item = T; - fn get(&self) -> (Vec, &[Self::Item]) { + fn ref_parts(&self) -> (Vec, &[Self::Item]) { let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); let data = self.as_slice().expect("tensor should be contiguous"); (shape, data) } - fn get_mut(&mut self) -> (Vec, *mut Self::Item, usize, Box) { + fn into_parts(self) -> (Vec, *mut Self::Item, usize, Box) { if self.is_standard_layout() { // We can avoid the copy here and use the data as is let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); @@ -384,24 +362,104 @@ impl OrtInput for &mut ArcArray { } } +impl OrtInput for Array { + type Item = T; + + fn ref_parts(&self) -> (Vec, &[Self::Item]) { + let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + let data = self.as_slice().expect("tensor should be contiguous"); + (shape, data) + } + + fn into_parts(self) -> (Vec, *mut Self::Item, usize, Box) { + if self.is_standard_layout() { + // We can avoid the copy here and use the data as is + let mut guard = Box::new(self); + let shape: Vec = guard.shape().iter().map(|d| *d as i64).collect(); + let ptr = guard.as_mut_ptr(); + let ptr_len = guard.len(); + (shape, ptr, ptr_len, guard) + } else { + // Need to do a copy here to get data in to standard layout + let mut contiguous_array = self.as_standard_layout().into_owned(); + let shape: Vec = contiguous_array.shape().iter().map(|d| *d as i64).collect(); + let ptr = contiguous_array.as_mut_ptr(); + let ptr_len: usize = contiguous_array.len(); + let guard = Box::new(contiguous_array); + (shape, ptr, ptr_len, guard) + } + } +} + +impl<'v, T: Clone + 'static, D: Dimension + 'static> OrtInput for ArrayView<'v, T, D> { + type Item = T; + + fn ref_parts(&self) -> (Vec, &[Self::Item]) { + let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + let data = self.as_slice().expect("tensor should be contiguous"); + (shape, data) + } + + fn into_parts(self) -> (Vec, *mut Self::Item, usize, Box) { + // This will result in a copy in either form of the ArrayView + let mut contiguous_array = self.as_standard_layout().into_owned(); + let shape: Vec = contiguous_array.shape().iter().map(|d| *d as i64).collect(); + let ptr = contiguous_array.as_mut_ptr(); + let ptr_len = contiguous_array.len(); + let guard = Box::new(contiguous_array); + (shape, ptr, ptr_len, guard) + } +} + impl OrtInput for (Vec, Arc>) { type Item = T; - fn get(&self) -> (Vec, &[Self::Item]) { + fn ref_parts(&self) -> (Vec, &[Self::Item]) { let shape = self.0.clone(); let data = self.1.deref(); (shape, data) } - fn get_mut(&mut self) -> (Vec, *mut Self::Item, usize, Box) { + fn into_parts(mut self) -> (Vec, *mut Self::Item, usize, Box) { let shape = self.0.clone(); let ptr = std::sync::Arc::>::make_mut(&mut self.1).as_mut_ptr(); let ptr_len: usize = self.1.len(); - let guard = Box::new(self.clone()); + let guard = Box::new(Arc::clone(&self.1)); (shape, ptr, ptr_len, guard) } } +impl<'i, 'v, T: IntoTensorElementDataType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Value +where + 'i: 'v +{ + type Error = OrtError; + fn try_from(arr: &'i CowArray<'v, T, D>) -> Result { + Value::from_array(arr) + } +} + +impl TryFrom<&mut ArcArray> for Value { + type Error = OrtError; + fn try_from(arr: &mut ArcArray) -> Result { + Value::from_array(arr) + } +} + +impl TryFrom> for Value { + type Error = OrtError; + fn try_from(arr: Array) -> Result { + Value::from_array(arr) + } +} + +impl<'v, T: IntoTensorElementDataType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Value { + type Error = OrtError; + fn try_from(arr: ArrayView<'v, T, D>) -> Result { + Value::from_array(arr) + } +} + impl Drop for Value { fn drop(&mut self) { ortsys![unsafe ReleaseValue(self.ptr())]; diff --git a/tests/mnist.rs b/tests/mnist.rs index c32e7ec..1fe6e8f 100644 --- a/tests/mnist.rs +++ b/tests/mnist.rs @@ -27,11 +27,11 @@ fn mnist_5() -> OrtResult<()> { assert_eq!(metadata.name()?, "CNTKGraph"); assert_eq!(metadata.producer()?, "CNTK"); - let input0_shape: Vec = session.inputs[0].map(|d| d.unwrap()).collect(); - let output0_shape: Vec = session.outputs[0].map(|d| d.unwrap()).collect(); + let input0_shape: &Vec = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"); + let output0_shape: &Vec = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"); - assert_eq!(input0_shape, [1, 1, 28, 28]); - assert_eq!(output0_shape, [1, 10]); + assert_eq!(input0_shape, &[1, 1, 28, 28]); + assert_eq!(output0_shape, &[1, 10]); // Load image and resize to model's shape, converting to RGB format let image_buffer: ImageBuffer, Vec> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD)) @@ -39,19 +39,16 @@ fn mnist_5() -> OrtResult<()> { .resize(input0_shape[2] as u32, input0_shape[3] as u32, FilterType::Nearest) .to_luma8(); - let array = ndarray::CowArray::from( - ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { - let pixel = image_buffer.get_pixel(i as u32, j as u32); - let channels = pixel.channels(); + let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); - // range [0, 255] -> range [0, 1] - (channels[c] as f32) / 255.0 - }) - .into_dyn() - ); + // range [0, 255] -> range [0, 1] + (channels[c] as f32) / 255.0 + }); // Perform the inference - let outputs = session.run(inputs![&array]?)?; + let outputs = session.run(inputs![array]?)?; let output: OrtOwnedTensor<_> = outputs[0].extract_tensor()?; let mut probabilities: Vec<(usize, f32)> = output.view().softmax(ndarray::Axis(1)).iter().copied().enumerate().collect::>(); diff --git a/tests/squeezenet.rs b/tests/squeezenet.rs index e60bdeb..0de8a61 100644 --- a/tests/squeezenet.rs +++ b/tests/squeezenet.rs @@ -35,11 +35,11 @@ fn squeezenet_mushroom() -> OrtResult<()> { let class_labels = get_imagenet_labels()?; - let input0_shape: Vec = session.inputs[0].map(|d| d.unwrap()).collect(); - let output0_shape: Vec = session.outputs[0].map(|d| d.unwrap()).collect(); + let input0_shape: &Vec = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"); + let output0_shape: &Vec = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"); - assert_eq!(input0_shape, [1, 3, 224, 224]); - assert_eq!(output0_shape, [1, 1000]); + assert_eq!(input0_shape, &[1, 3, 224, 224]); + assert_eq!(output0_shape, &[1, 1000]); // Load image and resize to model's shape, converting to RGB format let image_buffer: ImageBuffer, Vec> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD)) @@ -55,16 +55,13 @@ fn squeezenet_mushroom() -> OrtResult<()> { // See https://github.com/onnx/models/blob/master/vision/classification/imagenet_inference.ipynb // for pre-processing image. // WARNING: Note order of declaration of arguments: (_,c,j,i) - let mut array = ndarray::CowArray::from( - ndarray::Array::from_shape_fn((1, 3, 224, 224), |(_, c, j, i)| { - let pixel = image_buffer.get_pixel(i as u32, j as u32); - let channels = pixel.channels(); + let mut array = ndarray::Array::from_shape_fn((1, 3, 224, 224), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); - // range [0, 255] -> range [0, 1] - (channels[c] as f32) / 255.0 - }) - .into_dyn() - ); + // range [0, 255] -> range [0, 1] + (channels[c] as f32) / 255.0 + }); // Normalize channels to mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] let mean = [0.485, 0.456, 0.406]; @@ -76,7 +73,7 @@ fn squeezenet_mushroom() -> OrtResult<()> { } // Perform the inference - let outputs = session.run(inputs![&array]?)?; + let outputs = session.run(inputs![array]?)?; // Downloaded model does not have a softmax as final layer; call softmax on second axis // and iterate on resulting probabilities, creating an index to later access labels. diff --git a/tests/upsample.rs b/tests/upsample.rs index 6caff22..6f10350 100644 --- a/tests/upsample.rs +++ b/tests/upsample.rs @@ -1,7 +1,7 @@ use std::path::Path; use image::RgbImage; -use ndarray::{Array, CowArray, IxDyn}; +use ndarray::{Array, CowArray, Ix4}; use ort::{inputs, Environment, GraphOptimizationLevel, LoggingLevel, OrtOwnedTensor, OrtResult, SessionBuilder}; use test_log::test; @@ -12,11 +12,10 @@ fn load_input_image>(name: P) -> RgbImage { .to_rgb8() } -fn convert_image_to_cow_array(img: &RgbImage) -> CowArray<'_, f32, IxDyn> { +fn convert_image_to_cow_array(img: &RgbImage) -> CowArray<'_, f32, Ix4> { let array = Array::from_shape_vec((1, img.height() as usize, img.width() as usize, 3), img.to_vec()) .unwrap() - .map(|x| *x as f32 / 255.0) - .into_dyn(); + .map(|x| *x as f32 / 255.0); CowArray::from(array) } @@ -63,18 +62,15 @@ fn upsample() -> OrtResult<()> { assert_eq!(metadata.name()?, "tf2onnx"); assert_eq!(metadata.producer()?, "tf2onnx"); - assert_eq!(session.inputs[0].collect::>(), [None, None, None, Some(3)]); - assert_eq!(session.outputs[0].collect::>(), [None, None, None, Some(3)]); + assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]); + assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]); // Load image, converting to RGB format let image_buffer = load_input_image(IMAGE_TO_LOAD); let array = convert_image_to_cow_array(&image_buffer); - // Just one input - let input_tensor_values = inputs![&array]?; - // Perform the inference - let outputs = session.run(input_tensor_values)?; + let outputs = session.run(inputs![&array]?)?; assert_eq!(outputs.len(), 1); let output: OrtOwnedTensor = outputs[0].extract_tensor()?; @@ -107,8 +103,8 @@ fn upsample_with_ort_model() -> OrtResult<()> { .with_model_from_memory_directly(&session_data) // Zero-copy. .expect("Could not read model from memory"); - assert_eq!(session.inputs[0].collect::>(), [None, None, None, Some(3)]); - assert_eq!(session.outputs[0].collect::>(), [None, None, None, Some(3)]); + assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]); + assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]); // Load image, converting to RGB format let image_buffer = load_input_image(IMAGE_TO_LOAD); diff --git a/tests/vectorizer.rs b/tests/vectorizer.rs index dd12122..d768290 100644 --- a/tests/vectorizer.rs +++ b/tests/vectorizer.rs @@ -20,10 +20,10 @@ fn vectorizer() -> OrtResult<()> { assert_eq!(metadata.description()?, "test description"); assert_eq!(metadata.custom("custom_key")?.as_deref(), Some("custom_value")); - let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap().into_dyn()); + let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap()); // Just one input - let input_tensor_values = inputs![Value::from_array(Some(session.allocator()), &array)?]?; + let input_tensor_values = inputs![Value::from_string_array(session.allocator(), &array)?]?; // Perform the inference let outputs = session.run(input_tensor_values)?;