mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor: create environment in global OnceLock
This commit is contained in:
@@ -1,22 +1,16 @@
|
||||
use std::path::Path;
|
||||
|
||||
use image::{imageops::FilterType, ImageBuffer, Luma, Pixel};
|
||||
use ort::{
|
||||
download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor
|
||||
};
|
||||
use ort::{download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
fn mnist_5() -> ort::Result<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mnist_5.jpg";
|
||||
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded(DomainBasedImageClassification::Mnist)
|
||||
|
||||
@@ -7,22 +7,16 @@ use std::{
|
||||
|
||||
use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb};
|
||||
use ndarray::s;
|
||||
use ort::{
|
||||
download::vision::ImageClassification, inputs, ArrayExtensions, Environment, FetchModelError, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor
|
||||
};
|
||||
use ort::{download::vision::ImageClassification, inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
fn squeezenet_mushroom() -> ort::Result<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mushroom.png";
|
||||
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded(ImageClassification::SqueezeNet)
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::path::Path;
|
||||
|
||||
use image::RgbImage;
|
||||
use ndarray::{Array, CowArray, Ix4};
|
||||
use ort::{inputs, Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor};
|
||||
use ort::{inputs, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
|
||||
use test_log::test;
|
||||
|
||||
fn load_input_image<P: AsRef<Path>>(name: P) -> RgbImage {
|
||||
@@ -44,15 +44,11 @@ fn convert_image_to_cow_array(img: &RgbImage) -> CowArray<'_, f32, Ix4> {
|
||||
fn upsample() -> ort::Result<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mushroom.png";
|
||||
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
|
||||
|
||||
let session_data =
|
||||
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_memory(&session_data)
|
||||
@@ -89,15 +85,11 @@ fn upsample() -> ort::Result<()> {
|
||||
fn upsample_with_ort_model() -> ort::Result<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mushroom.png";
|
||||
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
|
||||
|
||||
let session_data =
|
||||
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.ort")).expect("Could not open model from file");
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_memory_directly(&session_data) // Zero-copy.
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
use std::path::Path;
|
||||
|
||||
use ndarray::{ArrayD, IxDyn};
|
||||
use ort::{inputs, Environment, GraphOptimizationLevel, SessionBuilder, Value};
|
||||
use ort::{inputs, GraphOptimizationLevel, Session, Value};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
#[cfg(not(target_arch = "aarch64"))]
|
||||
fn vectorizer() -> ort::Result<()> {
|
||||
let environment = Environment::default().into_arc();
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("vectorizer.onnx"))
|
||||
|
||||
Reference in New Issue
Block a user