mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor: make Environment easier to use
This commit is contained in:
@@ -67,6 +67,11 @@ impl Environment {
|
||||
self.env.lock().unwrap().name.to_string()
|
||||
}
|
||||
|
||||
/// Wraps this environment in an `Arc` for use with `SessionBuilder`.
|
||||
pub fn into_arc(self) -> Arc<Environment> {
|
||||
Arc::new(self)
|
||||
}
|
||||
|
||||
pub(crate) fn env_ptr(&self) -> *const sys::OrtEnv {
|
||||
*self.env.lock().unwrap().env_ptr.get_mut()
|
||||
}
|
||||
@@ -127,6 +132,57 @@ impl Environment {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Environment {
|
||||
fn default() -> Self {
|
||||
// NOTE: Because 'G_ENV' is a lazy_static, locking it will, initially, create
|
||||
// a new Arc<Mutex<EnvironmentSingleton>> with a strong count of 1.
|
||||
// Cloning it to embed it inside the 'Environment' to return
|
||||
// will thus increase the strong count to 2.
|
||||
let mut environment_guard = G_ENV.lock().expect("Failed to acquire lock: another thread panicked?");
|
||||
let g_env_ptr = environment_guard.env_ptr.get_mut();
|
||||
if g_env_ptr.is_null() {
|
||||
debug!("Environment not yet initialized, creating a new one.");
|
||||
|
||||
let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
|
||||
|
||||
let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
|
||||
// FIXME: What should go here?
|
||||
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
|
||||
|
||||
let cname = CString::new("default".to_string()).unwrap();
|
||||
|
||||
let create_env_with_custom_logger = ortsys![CreateEnvWithCustomLogger];
|
||||
let status = unsafe { create_env_with_custom_logger(logging_function, logger_param, LoggingLevel::Warning.into(), cname.as_ptr(), &mut env_ptr) };
|
||||
status_to_result(status).map_err(OrtError::CreateEnvironment).unwrap();
|
||||
|
||||
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created.");
|
||||
|
||||
*g_env_ptr = env_ptr;
|
||||
environment_guard.name = "default".to_string();
|
||||
|
||||
// NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one.
|
||||
// If this 'Environment' is the only one in the process, the strong count
|
||||
// will be 2:
|
||||
// * one lazy_static 'G_ENV'
|
||||
// * one inside the 'Environment' returned
|
||||
Environment {
|
||||
env: G_ENV.clone(),
|
||||
execution_providers: vec![]
|
||||
}
|
||||
} else {
|
||||
// NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one.
|
||||
// If this 'Environment' is the only one in the process, the strong count
|
||||
// will be 2:
|
||||
// * one lazy_static 'G_ENV'
|
||||
// * one inside the 'Environment' returned
|
||||
Environment {
|
||||
env: G_ENV.clone(),
|
||||
execution_providers: vec![]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Environment {
|
||||
#[tracing::instrument]
|
||||
fn drop(&mut self) {
|
||||
|
||||
@@ -8,8 +8,6 @@ use std::{
|
||||
use ort::error::OrtDownloadError;
|
||||
|
||||
mod download {
|
||||
use std::sync::Arc;
|
||||
|
||||
use image::{imageops::FilterType, ImageBuffer, Luma, Pixel, Rgb};
|
||||
use ndarray::s;
|
||||
use ort::{
|
||||
@@ -26,12 +24,11 @@ mod download {
|
||||
fn squeezenet_mushroom() -> OrtResult<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mushroom.png";
|
||||
|
||||
let environment = Arc::new(
|
||||
Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
);
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
@@ -106,12 +103,11 @@ mod download {
|
||||
fn mnist_5() -> OrtResult<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mnist_5.jpg";
|
||||
|
||||
let environment = Arc::new(
|
||||
Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
);
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
@@ -185,12 +181,11 @@ mod download {
|
||||
fn upsample() -> OrtResult<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mushroom.png";
|
||||
|
||||
let environment = Arc::new(
|
||||
Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
);
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
|
||||
Reference in New Issue
Block a user