mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor!: make Session::run take &mut self
This commit is contained in:
@@ -15,21 +15,25 @@ fn mnist_5() -> ort::Result<()> {
|
||||
|
||||
ort::init().with_name("integration_test").commit()?;
|
||||
|
||||
let session = Session::builder()?
|
||||
let mut session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")
|
||||
.expect("Could not download model from file");
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
assert_eq!(metadata.name()?, "CNTKGraph");
|
||||
assert_eq!(metadata.producer()?, "CNTK");
|
||||
let input0_shape = {
|
||||
let metadata = session.metadata()?;
|
||||
assert_eq!(metadata.name()?, "CNTKGraph");
|
||||
assert_eq!(metadata.producer()?, "CNTK");
|
||||
|
||||
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
|
||||
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
|
||||
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
|
||||
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
|
||||
|
||||
assert_eq!(input0_shape, &[1, 1, 28, 28]);
|
||||
assert_eq!(output0_shape, &[1, 10]);
|
||||
assert_eq!(input0_shape, &[1, 1, 28, 28]);
|
||||
assert_eq!(output0_shape, &[1, 10]);
|
||||
|
||||
input0_shape
|
||||
};
|
||||
|
||||
// Load image and resize to model's shape, converting to RGB format
|
||||
let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))
|
||||
|
||||
@@ -21,23 +21,27 @@ fn squeezenet_mushroom() -> ort::Result<()> {
|
||||
|
||||
ort::init().with_name("integration_test").commit()?;
|
||||
|
||||
let session = Session::builder()?
|
||||
let mut session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/squeezenet.onnx")
|
||||
.expect("Could not download model from file");
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
assert_eq!(metadata.name()?, "main_graph");
|
||||
assert_eq!(metadata.producer()?, "pytorch");
|
||||
|
||||
let class_labels = get_imagenet_labels()?;
|
||||
|
||||
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
|
||||
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
|
||||
let input0_shape = {
|
||||
let metadata = session.metadata()?;
|
||||
assert_eq!(metadata.name()?, "main_graph");
|
||||
assert_eq!(metadata.producer()?, "pytorch");
|
||||
|
||||
assert_eq!(input0_shape, &[1, 3, 224, 224]);
|
||||
assert_eq!(output0_shape, &[1, 1000]);
|
||||
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
|
||||
let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
|
||||
|
||||
assert_eq!(input0_shape, &[1, 3, 224, 224]);
|
||||
assert_eq!(output0_shape, &[1, 1000]);
|
||||
|
||||
input0_shape
|
||||
};
|
||||
|
||||
// Load image and resize to model's shape, converting to RGB format
|
||||
let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))
|
||||
|
||||
@@ -52,18 +52,20 @@ fn upsample() -> ort::Result<()> {
|
||||
|
||||
let session_data =
|
||||
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
|
||||
let session = Session::builder()?
|
||||
let mut session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.commit_from_memory(&session_data)
|
||||
.expect("Could not read model from memory");
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
assert_eq!(metadata.name()?, "tf2onnx");
|
||||
assert_eq!(metadata.producer()?, "tf2onnx");
|
||||
{
|
||||
let metadata = session.metadata()?;
|
||||
assert_eq!(metadata.name()?, "tf2onnx");
|
||||
assert_eq!(metadata.producer()?, "tf2onnx");
|
||||
|
||||
assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
|
||||
assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
|
||||
assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
|
||||
assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
|
||||
}
|
||||
|
||||
// Load image, converting to RGB format
|
||||
let image_buffer = load_input_image(IMAGE_TO_LOAD);
|
||||
@@ -93,7 +95,7 @@ fn upsample_with_ort_model() -> ort::Result<()> {
|
||||
|
||||
let session_data =
|
||||
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.ort")).expect("Could not open model from file");
|
||||
let session = Session::builder()?
|
||||
let mut session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.commit_from_memory_directly(&session_data) // Zero-copy.
|
||||
|
||||
@@ -12,17 +12,19 @@ use test_log::test;
|
||||
|
||||
#[test]
|
||||
fn vectorizer() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
let mut session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.commit_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("vectorizer.onnx"))
|
||||
.expect("Could not load model");
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
assert_eq!(metadata.producer()?, "skl2onnx");
|
||||
assert_eq!(metadata.description()?, "test description");
|
||||
assert_eq!(metadata.custom_keys()?, ["custom_key"]);
|
||||
assert_eq!(metadata.custom("custom_key")?.as_deref(), Some("custom_value"));
|
||||
{
|
||||
let metadata = session.metadata()?;
|
||||
assert_eq!(metadata.producer()?, "skl2onnx");
|
||||
assert_eq!(metadata.description()?, "test description");
|
||||
assert_eq!(metadata.custom_keys()?, ["custom_key"]);
|
||||
assert_eq!(metadata.custom("custom_key")?.as_deref(), Some("custom_value"));
|
||||
}
|
||||
|
||||
let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user