refactor!: make Session::run take &mut self

This commit is contained in:
Carson M.
2025-02-08 11:56:51 -06:00
parent 0f02d79a07
commit bd2aff711e
25 changed files with 193 additions and 168 deletions

View File

@@ -15,21 +15,25 @@ fn mnist_5() -> ort::Result<()> {
ort::init().with_name("integration_test").commit()?;
let session = Session::builder()?
let mut session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")
.expect("Could not download model from file");
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "CNTKGraph");
assert_eq!(metadata.producer()?, "CNTK");
let input0_shape = {
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "CNTKGraph");
assert_eq!(metadata.producer()?, "CNTK");
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
assert_eq!(input0_shape, &[1, 1, 28, 28]);
assert_eq!(output0_shape, &[1, 10]);
assert_eq!(input0_shape, &[1, 1, 28, 28]);
assert_eq!(output0_shape, &[1, 10]);
input0_shape
};
// Load image and resize to model's shape, converting to RGB format
let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))

View File

@@ -21,23 +21,27 @@ fn squeezenet_mushroom() -> ort::Result<()> {
ort::init().with_name("integration_test").commit()?;
let session = Session::builder()?
let mut session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/squeezenet.onnx")
.expect("Could not download model from file");
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "main_graph");
assert_eq!(metadata.producer()?, "pytorch");
let class_labels = get_imagenet_labels()?;
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
let input0_shape = {
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "main_graph");
assert_eq!(metadata.producer()?, "pytorch");
assert_eq!(input0_shape, &[1, 3, 224, 224]);
assert_eq!(output0_shape, &[1, 1000]);
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
assert_eq!(input0_shape, &[1, 3, 224, 224]);
assert_eq!(output0_shape, &[1, 1000]);
input0_shape
};
// Load image and resize to model's shape, converting to RGB format
let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))

View File

@@ -52,18 +52,20 @@ fn upsample() -> ort::Result<()> {
let session_data =
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
let session = Session::builder()?
let mut session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.commit_from_memory(&session_data)
.expect("Could not read model from memory");
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "tf2onnx");
assert_eq!(metadata.producer()?, "tf2onnx");
{
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "tf2onnx");
assert_eq!(metadata.producer()?, "tf2onnx");
assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
}
// Load image, converting to RGB format
let image_buffer = load_input_image(IMAGE_TO_LOAD);
@@ -93,7 +95,7 @@ fn upsample_with_ort_model() -> ort::Result<()> {
let session_data =
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.ort")).expect("Could not open model from file");
let session = Session::builder()?
let mut session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.commit_from_memory_directly(&session_data) // Zero-copy.

View File

@@ -12,17 +12,19 @@ use test_log::test;
#[test]
fn vectorizer() -> ort::Result<()> {
let session = Session::builder()?
let mut session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.commit_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("vectorizer.onnx"))
.expect("Could not load model");
let metadata = session.metadata()?;
assert_eq!(metadata.producer()?, "skl2onnx");
assert_eq!(metadata.description()?, "test description");
assert_eq!(metadata.custom_keys()?, ["custom_key"]);
assert_eq!(metadata.custom("custom_key")?.as_deref(), Some("custom_value"));
{
let metadata = session.metadata()?;
assert_eq!(metadata.producer()?, "skl2onnx");
assert_eq!(metadata.description()?, "test description");
assert_eq!(metadata.custom_keys()?, ["custom_key"]);
assert_eq!(metadata.custom("custom_key")?.as_deref(), Some("custom_value"));
}
let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap());