mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: custom ops & mutable tensors
This commit is contained in:
18
tests/data/custom_op_test.onnx
Normal file
18
tests/data/custom_op_test.onnx
Normal file
@@ -0,0 +1,18 @@
|
||||
OnnxRuntime:Č
|
||||
8
|
||||
input_1
|
||||
input_2output_1"CustomOpOne:
|
||||
test.customop
|
||||
.
|
||||
output_1output"CustomOpTwo:
|
||||
test.customopCustomOpTestZ
|
||||
input_1
|
||||
|
||||
|
||||
Z
|
||||
input_2
|
||||
|
||||
|
||||
b
|
||||
output
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::Path;
|
||||
|
||||
use image::{imageops::FilterType, ImageBuffer, Luma, Pixel};
|
||||
use ort::{inputs, ArrayExtensions, GraphOptimizationLevel, Session, Tensor};
|
||||
use ort::{inputs, ArrayExtensions, GraphOptimizationLevel, Session};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
@@ -43,8 +43,13 @@ fn mnist_5() -> ort::Result<()> {
|
||||
// Perform the inference
|
||||
let outputs = session.run(inputs![array]?)?;
|
||||
|
||||
let output: Tensor<_> = outputs[0].extract_tensor()?;
|
||||
let mut probabilities: Vec<(usize, f32)> = output.view().softmax(ndarray::Axis(1)).iter().copied().enumerate().collect::<Vec<_>>();
|
||||
let mut probabilities: Vec<(usize, f32)> = outputs[0]
|
||||
.extract_tensor()?
|
||||
.softmax(ndarray::Axis(1))
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Sort probabilities so highest is at beginning of vector.
|
||||
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::{
|
||||
|
||||
use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb};
|
||||
use ndarray::s;
|
||||
use ort::{inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, Session, Tensor};
|
||||
use ort::{inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, Session};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
@@ -70,8 +70,13 @@ fn squeezenet_mushroom() -> ort::Result<()> {
|
||||
|
||||
// Downloaded model does not have a softmax as final layer; call softmax on second axis
|
||||
// and iterate on resulting probabilities, creating an index to later access labels.
|
||||
let output: Tensor<_> = outputs[0].extract_tensor()?;
|
||||
let mut probabilities: Vec<(usize, f32)> = output.view().softmax(ndarray::Axis(1)).iter().copied().enumerate().collect::<Vec<_>>();
|
||||
let mut probabilities: Vec<(usize, f32)> = outputs[0]
|
||||
.extract_tensor()?
|
||||
.softmax(ndarray::Axis(1))
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.collect::<Vec<_>>();
|
||||
// Sort probabilities so highest is at beginning of vector.
|
||||
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::path::Path;
|
||||
|
||||
use image::RgbImage;
|
||||
use ndarray::{Array, CowArray, Ix4};
|
||||
use ort::{inputs, GraphOptimizationLevel, Session, Tensor};
|
||||
use ndarray::{Array, ArrayViewD, CowArray, Ix4};
|
||||
use ort::{inputs, GraphOptimizationLevel, Session};
|
||||
use test_log::test;
|
||||
|
||||
fn load_input_image<P: AsRef<Path>>(name: P) -> RgbImage {
|
||||
@@ -69,10 +69,10 @@ fn upsample() -> ort::Result<()> {
|
||||
let outputs = session.run(inputs![&array]?)?;
|
||||
|
||||
assert_eq!(outputs.len(), 1);
|
||||
let output: Tensor<f32> = outputs[0].extract_tensor()?;
|
||||
let output: ArrayViewD<f32> = outputs[0].extract_tensor()?;
|
||||
|
||||
// The image should have doubled in size
|
||||
assert_eq!(output.view().shape(), [1, 448, 448, 3]);
|
||||
assert_eq!(output.shape(), [1, 448, 448, 3]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -106,10 +106,10 @@ fn upsample_with_ort_model() -> ort::Result<()> {
|
||||
let outputs = session.run(inputs![&array]?)?;
|
||||
|
||||
assert_eq!(outputs.len(), 1);
|
||||
let output: Tensor<f32> = outputs[0].extract_tensor()?;
|
||||
let output: ArrayViewD<f32> = outputs[0].extract_tensor()?;
|
||||
|
||||
// The image should have doubled in size
|
||||
assert_eq!(output.view().shape(), [1, 448, 448, 3]);
|
||||
assert_eq!(output.shape(), [1, 448, 448, 3]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -26,12 +26,7 @@ fn vectorizer() -> ort::Result<()> {
|
||||
|
||||
// Perform the inference
|
||||
let outputs = session.run(input_tensor_values)?;
|
||||
assert_eq!(
|
||||
*outputs[0].extract_tensor::<f32>()?.view(),
|
||||
ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
.unwrap()
|
||||
.view()
|
||||
);
|
||||
assert_eq!(outputs[0].extract_tensor::<f32>()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user