feat: value specialization (#178)

This commit is contained in:
Carson M
2024-03-27 19:45:41 -05:00
committed by GitHub
parent 414befaf3f
commit 393f25f6e4
22 changed files with 958 additions and 287 deletions

View File

@@ -44,7 +44,7 @@ fn mnist_5() -> ort::Result<()> {
let outputs = session.run(inputs![array]?)?;
let mut probabilities: Vec<(usize, f32)> = outputs[0]
.extract_tensor()?
.try_extract_tensor()?
.softmax(ndarray::Axis(1))
.iter()
.copied()

View File

@@ -71,7 +71,7 @@ fn squeezenet_mushroom() -> ort::Result<()> {
// Downloaded model does not have a softmax as final layer; call softmax on second axis
// and iterate on resulting probabilities, creating an index to later access labels.
let mut probabilities: Vec<(usize, f32)> = outputs[0]
.extract_tensor()?
.try_extract_tensor()?
.softmax(ndarray::Axis(1))
.iter()
.copied()

View File

@@ -69,7 +69,7 @@ fn upsample() -> ort::Result<()> {
let outputs = session.run(inputs![&array]?)?;
assert_eq!(outputs.len(), 1);
let output: ArrayViewD<f32> = outputs[0].extract_tensor()?;
let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;
// The image should have doubled in size
assert_eq!(output.shape(), [1, 448, 448, 3]);
@@ -106,7 +106,7 @@ fn upsample_with_ort_model() -> ort::Result<()> {
let outputs = session.run(inputs![&array]?)?;
assert_eq!(outputs.len(), 1);
let output: ArrayViewD<f32> = outputs[0].extract_tensor()?;
let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;
// The image should have doubled in size
assert_eq!(output.shape(), [1, 448, 448, 3]);

View File

@@ -3,7 +3,7 @@
use std::path::Path;
use ndarray::{ArrayD, IxDyn};
use ort::{inputs, GraphOptimizationLevel, Session, Value};
use ort::{inputs, DynTensor, GraphOptimizationLevel, Session};
use test_log::test;
#[test]
@@ -22,11 +22,11 @@ fn vectorizer() -> ort::Result<()> {
let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap());
// Just one input
let input_tensor_values = inputs![Value::from_string_array(session.allocator(), &array)?]?;
let input_tensor_values = inputs![DynTensor::from_string_array(session.allocator(), &array)?]?;
// Perform the inference
let outputs = session.run(input_tensor_values)?;
assert_eq!(outputs[0].extract_tensor::<f32>()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap());
assert_eq!(outputs[0].try_extract_tensor::<f32>()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap());
Ok(())
}