refactor!: allow zero-copy from_array for array views with TensorRef

This has all sorts of fun breaking changes:
- `ort::inputs!` no longer yields an `ort::Result<...>` (thank God)
- `Tensor::from_array` now only accepts owned data.
- Introduce `TensorRef::from_array_view` and `TensorRefMut::from_array_view_mut`.
- `TryFrom<A>` is no longer implemented for `Tensor<T>` for any variants.

This opens the door to new optimizations on top of fixing a few unsoundness issues.

TODO: update docs
This commit is contained in:
Carson M.
2024-12-21 00:24:54 -06:00
parent d7d4493c3e
commit 9ea18d815b
23 changed files with 379 additions and 381 deletions

View File

@@ -4,7 +4,8 @@ use image::{ImageBuffer, Luma, Pixel, imageops::FilterType};
use ort::{
inputs,
session::{Session, builder::GraphOptimizationLevel},
tensor::ArrayExtensions
tensor::ArrayExtensions,
value::TensorRef
};
use test_log::test;
@@ -45,7 +46,7 @@ fn mnist_5() -> ort::Result<()> {
});
// Perform the inference
let outputs = session.run(inputs![array]?)?;
let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?;
let mut probabilities: Vec<(usize, f32)> = outputs[0]
.try_extract_tensor()?

View File

@@ -10,7 +10,8 @@ use ndarray::s;
use ort::{
Error, inputs,
session::{Session, builder::GraphOptimizationLevel},
tensor::ArrayExtensions
tensor::ArrayExtensions,
value::TensorRef
};
use test_log::test;
@@ -70,7 +71,7 @@ fn squeezenet_mushroom() -> ort::Result<()> {
}
// Perform the inference
let outputs = session.run(inputs![array]?)?;
let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?;
// Downloaded model does not have a softmax as final layer; call softmax on second axis
// and iterate on resulting probabilities, creating an index to later access labels.

View File

@@ -4,7 +4,8 @@ use image::RgbImage;
use ndarray::{Array, ArrayViewD, CowArray, Ix4};
use ort::{
inputs,
session::{Session, builder::GraphOptimizationLevel}
session::{Session, builder::GraphOptimizationLevel},
value::TensorRef
};
use test_log::test;
@@ -69,7 +70,7 @@ fn upsample() -> ort::Result<()> {
let array = convert_image_to_cow_array(&image_buffer);
// Perform the inference
let outputs = session.run(inputs![&array]?)?;
let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?;
assert_eq!(outputs.len(), 1);
let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;
@@ -106,7 +107,7 @@ fn upsample_with_ort_model() -> ort::Result<()> {
let array = convert_image_to_cow_array(&image_buffer);
// Perform the inference
let outputs = session.run(inputs![&array]?)?;
let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?;
assert_eq!(outputs.len(), 1);
let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;

View File

@@ -26,11 +26,7 @@ fn vectorizer() -> ort::Result<()> {
let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap());
// Just one input
let input_tensor_values = inputs![Tensor::from_string_array(&array)?]?;
// Perform the inference
let outputs = session.run(input_tensor_values)?;
let outputs = session.run(inputs![Tensor::from_string_array(&array)?])?;
assert_eq!(outputs[0].try_extract_tensor::<f32>()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap());
Ok(())