refactor: make Environment easier to use

This commit is contained in:
Carson M
2023-01-17 10:16:18 -06:00
parent b49ca6f94d
commit 1ebefceb0b
2 changed files with 71 additions and 20 deletions

View File

@@ -67,6 +67,11 @@ impl Environment {
self.env.lock().unwrap().name.to_string()
}
/// Wraps this environment in an `Arc` for use with `SessionBuilder`.
pub fn into_arc(self) -> Arc<Environment> {
Arc::new(self)
}
pub(crate) fn env_ptr(&self) -> *const sys::OrtEnv {
*self.env.lock().unwrap().env_ptr.get_mut()
}
@@ -127,6 +132,57 @@ impl Environment {
}
}
impl Default for Environment {
fn default() -> Self {
// NOTE: Because 'G_ENV' is a lazy_static, locking it will, initially, create
// a new Arc<Mutex<EnvironmentSingleton>> with a strong count of 1.
// Cloning it to embed it inside the 'Environment' to return
// will thus increase the strong count to 2.
let mut environment_guard = G_ENV.lock().expect("Failed to acquire lock: another thread panicked?");
let g_env_ptr = environment_guard.env_ptr.get_mut();
if g_env_ptr.is_null() {
debug!("Environment not yet initialized, creating a new one.");
let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new("default".to_string()).unwrap();
let create_env_with_custom_logger = ortsys![CreateEnvWithCustomLogger];
let status = unsafe { create_env_with_custom_logger(logging_function, logger_param, LoggingLevel::Warning.into(), cname.as_ptr(), &mut env_ptr) };
status_to_result(status).map_err(OrtError::CreateEnvironment).unwrap();
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created.");
*g_env_ptr = env_ptr;
environment_guard.name = "default".to_string();
// NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one.
// If this 'Environment' is the only one in the process, the strong count
// will be 2:
// * one lazy_static 'G_ENV'
// * one inside the 'Environment' returned
Environment {
env: G_ENV.clone(),
execution_providers: vec![]
}
} else {
// NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one.
// If this 'Environment' is the only one in the process, the strong count
// will be 2:
// * one lazy_static 'G_ENV'
// * one inside the 'Environment' returned
Environment {
env: G_ENV.clone(),
execution_providers: vec![]
}
}
}
}
impl Drop for Environment {
#[tracing::instrument]
fn drop(&mut self) {

View File

@@ -8,8 +8,6 @@ use std::{
use ort::error::OrtDownloadError;
mod download {
use std::sync::Arc;
use image::{imageops::FilterType, ImageBuffer, Luma, Pixel, Rgb};
use ndarray::s;
use ort::{
@@ -26,12 +24,11 @@ mod download {
fn squeezenet_mushroom() -> OrtResult<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";
let environment = Arc::new(
Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
);
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
.into_arc();
let session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level1)?
@@ -106,12 +103,11 @@ mod download {
fn mnist_5() -> OrtResult<()> {
const IMAGE_TO_LOAD: &str = "mnist_5.jpg";
let environment = Arc::new(
Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
);
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
.into_arc();
let session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level1)?
@@ -185,12 +181,11 @@ mod download {
fn upsample() -> OrtResult<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";
let environment = Arc::new(
Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
);
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
.into_arc();
let session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level1)?