mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
108 lines
3.4 KiB
Rust
108 lines
3.4 KiB
Rust
use std::{
|
|
fs,
|
|
io::{BufRead, BufReader},
|
|
path::Path
|
|
};
|
|
|
|
use image::{ImageBuffer, Pixel, Rgb, imageops::FilterType};
|
|
use ndarray::s;
|
|
use ort::{
|
|
Error, inputs,
|
|
session::{Session, builder::GraphOptimizationLevel},
|
|
util::ArrayExt,
|
|
value::TensorRef
|
|
};
|
|
|
|
#[test]
fn squeezenet_mushroom() -> ort::Result<()> {
	const IMAGE_TO_LOAD: &str = "mushroom.png";

	ort::init().with_name("integration_test").commit();

	let mut session = Session::builder()?
		.with_optimization_level(GraphOptimizationLevel::Level1)?
		.with_intra_threads(1)?
		// Propagate download/load failures as `ort::Result` instead of panicking with a
		// misleading message — this is fetched from a URL, not a local file.
		.commit_from_url("https://cdn.pyke.io/0/pyke:ort-rs/example-models@0.0.0/squeezenet.onnx")?;

	let class_labels = get_imagenet_labels()?;

	let input0_shape = {
		let metadata = session.metadata()?;
		assert_eq!(metadata.name().as_deref(), Some("main_graph"));
		assert_eq!(metadata.producer().as_deref(), Some("pytorch"));

		let input0_shape = session.inputs()[0].dtype().tensor_shape().expect("input0 to be a tensor type");
		let output0_shape = session.outputs()[0].dtype().tensor_shape().expect("output0 to be a tensor type");

		// SqueezeNet expects NCHW input of 1x3x224x224 and emits 1000 class logits.
		assert_eq!(&**input0_shape, [1, 3, 224, 224]);
		assert_eq!(&**output0_shape, [1, 1000]);

		input0_shape
	};

	// Load image and resize to model's shape, converting to RGB format.
	// `expect` (rather than a bare `unwrap`) so a missing/corrupt test asset is diagnosable.
	let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))
		.expect("failed to open test image")
		.resize(input0_shape[2] as u32, input0_shape[3] as u32, FilterType::Nearest)
		.to_rgb8();

	// Python:
	// # image[y, x, RGB]
	// # x==0 --> left
	// # y==0 --> top

	// See https://github.com/onnx/models/blob/master/vision/classification/imagenet_inference.ipynb
	// for pre-processing image.
	// WARNING: Note order of declaration of arguments: (_,c,j,i)
	let mut array = ndarray::Array::from_shape_fn((1, 3, 224, 224), |(_, c, j, i)| {
		let pixel = image_buffer.get_pixel(i as u32, j as u32);
		let channels = pixel.channels();

		// range [0, 255] -> range [0, 1]
		(channels[c] as f32) / 255.0
	});

	// Normalize channels to mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]
	// (the standard ImageNet per-channel statistics this model was trained with).
	let mean = [0.485, 0.456, 0.406];
	let std = [0.229, 0.224, 0.225];
	for c in 0..3 {
		let mut channel_array = array.slice_mut(s![0, c, .., ..]);
		channel_array -= mean[c];
		channel_array /= std[c];
	}

	// Perform the inference
	let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?;

	// Downloaded model does not have a softmax as final layer; call softmax on second axis
	// and iterate on resulting probabilities, creating an index to later access labels.
	let mut probabilities: Vec<(usize, f32)> = outputs[0]
		.try_extract_array()?
		.softmax(ndarray::Axis(1))
		.iter()
		.copied()
		.enumerate()
		.collect::<Vec<_>>();
	// Sort probabilities so highest is at beginning of vector.
	probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

	assert_eq!(class_labels[probabilities[0].0], "n07734744 mushroom", "Expecting class for {} to be a mushroom", IMAGE_TO_LOAD);
	assert_eq!(probabilities[0].0, 947, "Expecting class for {} to be a mushroom (index 947 in labels file)", IMAGE_TO_LOAD);

	Ok(())
}
|
|
|
|
fn get_imagenet_labels() -> ort::Result<Vec<String>> {
|
|
let file = BufReader::new(
|
|
fs::File::open(
|
|
Path::new(env!("CARGO_MANIFEST_DIR"))
|
|
.join("tests")
|
|
.join("data")
|
|
.join("imagenet-classes.txt")
|
|
)
|
|
.unwrap()
|
|
);
|
|
file.lines().map(|line| line.map_err(Error::wrap)).collect()
|
|
}
|