mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: value specialization (#178)
This commit is contained in:
@@ -44,7 +44,7 @@ fn mnist_5() -> ort::Result<()> {
|
||||
let outputs = session.run(inputs![array]?)?;
|
||||
|
||||
let mut probabilities: Vec<(usize, f32)> = outputs[0]
|
||||
.extract_tensor()?
|
||||
.try_extract_tensor()?
|
||||
.softmax(ndarray::Axis(1))
|
||||
.iter()
|
||||
.copied()
|
||||
|
||||
@@ -71,7 +71,7 @@ fn squeezenet_mushroom() -> ort::Result<()> {
|
||||
// Downloaded model does not have a softmax as final layer; call softmax on second axis
|
||||
// and iterate on resulting probabilities, creating an index to later access labels.
|
||||
let mut probabilities: Vec<(usize, f32)> = outputs[0]
|
||||
.extract_tensor()?
|
||||
.try_extract_tensor()?
|
||||
.softmax(ndarray::Axis(1))
|
||||
.iter()
|
||||
.copied()
|
||||
|
||||
@@ -69,7 +69,7 @@ fn upsample() -> ort::Result<()> {
|
||||
let outputs = session.run(inputs![&array]?)?;
|
||||
|
||||
assert_eq!(outputs.len(), 1);
|
||||
let output: ArrayViewD<f32> = outputs[0].extract_tensor()?;
|
||||
let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;
|
||||
|
||||
// The image should have doubled in size
|
||||
assert_eq!(output.shape(), [1, 448, 448, 3]);
|
||||
@@ -106,7 +106,7 @@ fn upsample_with_ort_model() -> ort::Result<()> {
|
||||
let outputs = session.run(inputs![&array]?)?;
|
||||
|
||||
assert_eq!(outputs.len(), 1);
|
||||
let output: ArrayViewD<f32> = outputs[0].extract_tensor()?;
|
||||
let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;
|
||||
|
||||
// The image should have doubled in size
|
||||
assert_eq!(output.shape(), [1, 448, 448, 3]);
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use std::path::Path;
|
||||
|
||||
use ndarray::{ArrayD, IxDyn};
|
||||
use ort::{inputs, GraphOptimizationLevel, Session, Value};
|
||||
use ort::{inputs, DynTensor, GraphOptimizationLevel, Session};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
@@ -22,11 +22,11 @@ fn vectorizer() -> ort::Result<()> {
|
||||
let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap());
|
||||
|
||||
// Just one input
|
||||
let input_tensor_values = inputs![Value::from_string_array(session.allocator(), &array)?]?;
|
||||
let input_tensor_values = inputs![DynTensor::from_string_array(session.allocator(), &array)?]?;
|
||||
|
||||
// Perform the inference
|
||||
let outputs = session.run(input_tensor_values)?;
|
||||
assert_eq!(outputs[0].extract_tensor::<f32>()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap());
|
||||
assert_eq!(outputs[0].try_extract_tensor::<f32>()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user