diff --git a/examples/gpt2/examples/gpt2.rs b/examples/gpt2/examples/gpt2.rs index eee1b33..207b650 100644 --- a/examples/gpt2/examples/gpt2.rs +++ b/examples/gpt2/examples/gpt2.rs @@ -4,7 +4,7 @@ use std::{ }; use ndarray::{array, concatenate, s, Array1, Axis}; -use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, Environment, GraphOptimizationLevel, SessionBuilder, Tensor}; +use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Tensor}; use rand::Rng; use tokenizers::Tokenizer; @@ -23,17 +23,17 @@ fn main() -> ort::Result<()> { // Initialize tracing to receive debug messages from `ort` tracing_subscriber::fmt::init(); + // Create the ONNX Runtime environment, enabling CUDA execution providers for all sessions created in this process. + ort::init() + .with_name("GPT-2") + .with_execution_providers([CUDAExecutionProvider::default().build()]) + .commit()?; + let mut stdout = io::stdout(); let mut rng = rand::thread_rng(); - // Create the ONNX Runtime environment and session for the GPT-2 model. - let environment = Environment::builder() - .with_name("GPT-2") - .with_execution_providers([CUDAExecutionProvider::default().build()]) - .build()? - .into_arc(); - - let session = SessionBuilder::new(&environment)? + // Load our model + let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? .with_model_downloaded(GPT2::GPT2LmHead)?; diff --git a/examples/yolov8/examples/yolov8.rs b/examples/yolov8/examples/yolov8.rs index 447dbea..b7fa423 100644 --- a/examples/yolov8/examples/yolov8.rs +++ b/examples/yolov8/examples/yolov8.rs @@ -4,7 +4,7 @@ use std::path::Path; use image::{imageops::FilterType, GenericImageView}; use ndarray::{s, Array, Axis}; -use ort::{inputs, CUDAExecutionProvider, Environment, SessionBuilder, SessionOutputs}; +use ort::{inputs, CUDAExecutionProvider, Session, SessionOutputs}; use raqote::{DrawOptions, DrawTarget, LineJoin, PathBuilder, SolidSource, Source, StrokeStyle}; use show_image::{event, AsImageView, WindowOptions}; @@ -42,6 +42,10 @@ const YOLOV8_CLASS_LABELS: [&str; 80] = [ fn main() -> ort::Result<()> { tracing_subscriber::fmt::init(); + ort::init() + .with_execution_providers([CUDAExecutionProvider::default().build()]) + .commit()?; + let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("baseball.jpg")).unwrap(); let (img_width, img_height) = (original_img.width(), original_img.height()); let img = original_img.resize_exact(640, 640, FilterType::CatmullRom); @@ -55,14 +59,10 @@ fn main() -> ort::Result<()> { input[[0, 2, y, x]] = (b as f32) / 255.; } - let env = Environment::builder() - .with_execution_providers([CUDAExecutionProvider::default().build()]) - .build()? - .into_arc(); - let model = SessionBuilder::new(&env).unwrap().with_model_downloaded(YOLOV8M_URL).unwrap(); + let model = Session::builder()?.with_model_downloaded(YOLOV8M_URL)?; // Run YOLOv8 inference - let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?).unwrap(); + let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?)?; let output = outputs["output0"].extract_tensor::().unwrap().view().t().into_owned(); let mut boxes = Vec::new(); diff --git a/src/environment.rs b/src/environment.rs index 8b7e0bc..78a97b5 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -1,28 +1,31 @@ use std::{ ffi::CString, - sync::{atomic::AtomicPtr, Arc, Mutex} + sync::{atomic::AtomicPtr, OnceLock} }; -use once_cell::sync::Lazy; -use tracing::{debug, error, warn}; +use tracing::debug; use super::{ custom_logger, - error::{status_to_result, Error, Result}, - ort, ortsys, ExecutionProviderDispatch, LoggingLevel + error::{Error, Result}, + ortsys, ExecutionProviderDispatch, LoggingLevel }; -static G_ENV: Lazy>> = Lazy::new(|| { - Arc::new(Mutex::new(EnvironmentSingleton { - name: String::from("uninitialized"), - env_ptr: AtomicPtr::new(std::ptr::null_mut()) - })) -}); +static G_ENV: OnceLock = OnceLock::new(); #[derive(Debug)] -struct EnvironmentSingleton { - name: String, - env_ptr: AtomicPtr +pub(crate) struct EnvironmentSingleton { + pub(crate) execution_providers: Vec, + pub(crate) env_ptr: AtomicPtr +} + +pub(crate) fn get_environment() -> Result<&'static EnvironmentSingleton> { + if G_ENV.get().is_none() { + EnvironmentBuilder::default().commit()?; + Ok(G_ENV.get().unwrap()) + } else { + Ok(unsafe { G_ENV.get().unwrap_unchecked() }) + } } #[derive(Debug, Default, Clone)] @@ -33,145 +36,6 @@ pub struct EnvironmentGlobalThreadPoolOptions { pub intra_op_thread_affinity: Option } -/// An [`Environment`] is the main entry point of the ONNX Runtime. -/// -/// Only one ONNX environment can be created per process. A singleton is used to enforce this. -/// -/// Once an environment is created, a [`super::Session`] can be obtained from it. -/// -/// **NOTE**: While the [`Environment`] constructor takes a `name` parameter to name the environment, only the first -/// name will be considered if many environments are created. -/// -/// # Example -/// -/// ```no_run -/// # use std::error::Error; -/// # use ort::{Environment, LoggingLevel}; -/// # fn main() -> Result<(), Box> { -/// let environment = Environment::builder().with_name("test").with_log_level(LoggingLevel::Verbose).build()?; -/// # Ok(()) -/// # } -/// ``` -#[derive(Debug, Clone)] -pub struct Environment { - env: Arc>, - pub(crate) execution_providers: Vec -} - -unsafe impl Send for Environment {} -unsafe impl Sync for Environment {} - -impl Environment { - /// Create a new environment builder using default values - /// (name: `default`, log level: [`LoggingLevel::Warning`]) - pub fn builder() -> EnvironmentBuilder { - EnvironmentBuilder { - name: "default".into(), - log_level: LoggingLevel::Warning, - execution_providers: Vec::new(), - global_thread_pool_options: None - } - } - - /// Return the name of the current environment - pub fn name(&self) -> String { - self.env.lock().unwrap().name.to_string() - } - - /// Wraps this environment in an `Arc` for use with `SessionBuilder`. - pub fn into_arc(self) -> Arc { - Arc::new(self) - } - - pub(crate) fn ptr(&self) -> *const ort_sys::OrtEnv { - *self.env.lock().unwrap().env_ptr.get_mut() - } -} - -impl Default for Environment { - fn default() -> Self { - // NOTE: Because 'G_ENV' is `Lazy`, locking it will, initially, create - // a new Arc> with a strong count of 1. - // Cloning it to embed it inside the 'Environment' to return - // will thus increase the strong count to 2. - let mut environment_guard = G_ENV.lock().expect("Failed to acquire global environment lock: another thread panicked?"); - let g_env_ptr = environment_guard.env_ptr.get_mut(); - if g_env_ptr.is_null() { - debug!("Environment not yet initialized, creating a new one"); - - let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); - - let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); - // FIXME: What should go here? - let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); - - let cname = CString::new("default".to_string()).unwrap(); - - status_to_result( - ortsys![unsafe CreateEnvWithCustomLogger(logging_function, logger_param, LoggingLevel::Warning.into(), cname.as_ptr(), &mut env_ptr); nonNull(env_ptr)] - ) - .map_err(Error::CreateEnvironment) - .unwrap(); - - debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created"); - - *g_env_ptr = env_ptr; - environment_guard.name = "default".to_string(); - - // NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one. - // If this 'Environment' is the only one in the process, the strong count - // will be 2: - // * one `Lazy` 'G_ENV' - // * one inside the 'Environment' returned - Environment { - env: G_ENV.clone(), - execution_providers: vec![] - } - } else { - // NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one. - // If this 'Environment' is the only one in the process, the strong count - // will be 2: - // * one `Lazy` 'G_ENV' - // * one inside the 'Environment' returned - Environment { - env: G_ENV.clone(), - execution_providers: vec![] - } - } - } -} - -impl Drop for Environment { - #[tracing::instrument] - fn drop(&mut self) { - debug!(global_arc_count = Arc::strong_count(&G_ENV), "Dropping environment"); - - let mut environment_guard = self.env.lock().expect("Failed to acquire lock: another thread panicked?"); - - // NOTE: If we drop an 'Environment' we (obviously) have _at least_ - // one 'G_ENV' strong count (the one in the 'env' member). - // There is also the "original" 'G_ENV' which is a the `Lazy` global. - // If there is no other environment, the strong count should be two and we - // can properly free the sys::OrtEnv pointer. - if Arc::strong_count(&G_ENV) == 2 { - let release_env = ort().ReleaseEnv.unwrap(); - let env_ptr: *mut ort_sys::OrtEnv = *environment_guard.env_ptr.get_mut(); - - debug!(global_arc_count = Arc::strong_count(&G_ENV), "Releasing environment"); - - assert_ne!(env_ptr, std::ptr::null_mut()); - if env_ptr.is_null() { - error!("Environment pointer is null, not dropping!"); - } else { - unsafe { release_env(env_ptr) }; - } - - environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut()); - environment_guard.name = String::from("uninitialized"); - } - } -} - /// Struct used to build an environment [`Environment`]. /// /// This is ONNX Runtime's main entry point. An environment _must_ be created as the first step. An [`Environment`] can @@ -189,6 +53,17 @@ pub struct EnvironmentBuilder { global_thread_pool_options: Option } +impl Default for EnvironmentBuilder { + fn default() -> Self { + EnvironmentBuilder { + name: "default".to_string(), + log_level: LoggingLevel::Error, + execution_providers: vec![], + global_thread_pool_options: None + } + } +} + impl EnvironmentBuilder { /// Configure the environment with a given name /// @@ -263,14 +138,8 @@ impl EnvironmentBuilder { } /// Commit the configuration to a new [`Environment`]. - pub fn build(self) -> Result { - // NOTE: Because 'G_ENV' is a `Lazy`, locking it will, initially, create - // a new Arc> with a strong count of 1. - // Cloning it to embed it inside the 'Environment' to return - // will thus increase the strong count to 2. - let mut environment_guard = G_ENV.lock().expect("Failed to acquire global environment lock: another thread panicked?"); - let g_env_ptr = environment_guard.env_ptr.get_mut(); - if g_env_ptr.is_null() { + pub fn commit(self) -> Result<()> { + if G_ENV.get().is_none() { debug!("Environment not yet initialized, creating a new one"); let env_ptr = if let Some(global_thread_pool) = self.global_thread_pool_options { @@ -309,41 +178,22 @@ impl EnvironmentBuilder { }; debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created"); - *g_env_ptr = env_ptr; - environment_guard.name = self.name; - - // NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one. - // If this 'Environment' is the only one in the process, the strong count - // will be 2: - // * one `Lazy` 'G_ENV' - // * one inside the 'Environment' returned - Ok(Environment { - env: G_ENV.clone(), - execution_providers: self.execution_providers - }) - } else { - warn!( - name = environment_guard.name.as_str(), - env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(), - "Environment already initialized for this thread, reusing it", - ); - - // NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one. - // If this 'Environment' is the only one in the process, the strong count - // will be 2: - // * one `Lazy` 'G_ENV' - // * one inside the 'Environment' returned - Ok(Environment { - env: G_ENV.clone(), - execution_providers: self.execution_providers.clone() - }) + let _ = G_ENV.set(EnvironmentSingleton { + execution_providers: self.execution_providers, + env_ptr: AtomicPtr::new(env_ptr) + }); } + Ok(()) } } +pub fn init() -> EnvironmentBuilder { + EnvironmentBuilder::default() +} + #[cfg(test)] mod tests { - use std::sync::{RwLock, RwLockWriteGuard}; + use std::sync::{atomic::Ordering, Arc, RwLock, RwLockWriteGuard}; use once_cell::sync::Lazy; use test_log::test; @@ -351,11 +201,11 @@ mod tests { use super::*; fn is_env_initialized() -> bool { - Arc::strong_count(&G_ENV) >= 2 + G_ENV.get().is_some() && !G_ENV.get().unwrap().env_ptr.load(Ordering::Relaxed).is_null() } - fn env_ptr() -> *const ort_sys::OrtEnv { - *G_ENV.lock().unwrap().env_ptr.get_mut() + fn env_ptr() -> Option<*mut ort_sys::OrtEnv> { + G_ENV.get().map(|f| f.env_ptr.load(Ordering::Relaxed)) } struct ConcurrentTestRun { @@ -373,76 +223,14 @@ mod tests { let _run_lock = single_test_run(); assert!(!is_env_initialized()); - assert_eq!(env_ptr(), std::ptr::null_mut()); + assert_eq!(env_ptr(), None); - let env = Environment::builder() + EnvironmentBuilder::default() .with_name("env_is_initialized") .with_log_level(LoggingLevel::Warning) - .build() + .commit() .unwrap(); assert!(is_env_initialized()); - assert_ne!(env_ptr(), std::ptr::null_mut()); - - drop(env); - assert!(!is_env_initialized()); - assert_eq!(env_ptr(), std::ptr::null_mut()); - } - - #[ignore] - #[test] - fn sequential_environment_creation() { - let _concurrent_run_lock_guard = single_test_run(); - - let mut prev_env_ptr = env_ptr(); - - for i in 0..10 { - let name = format!("sequential_environment_creation: {}", i); - let env = Environment::builder() - .with_name(name.clone()) - .with_log_level(LoggingLevel::Warning) - .build() - .unwrap(); - let next_env_ptr = env_ptr(); - assert_ne!(next_env_ptr, prev_env_ptr); - prev_env_ptr = next_env_ptr; - - assert_eq!(env.name(), name); - } - } - - #[test] - fn concurrent_environment_creations() { - let _concurrent_run_lock_guard = single_test_run(); - - let initial_name = String::from("concurrent_environment_creation"); - let main_env = Environment::builder() - .with_name(&initial_name) - .with_log_level(LoggingLevel::Warning) - .build() - .unwrap(); - let main_env_ptr = main_env.ptr() as usize; - - assert_eq!(main_env.name(), initial_name); - assert_eq!(main_env.ptr() as usize, main_env_ptr); - - assert!( - (0..10) - .map(|t| { - let initial_name_cloned = initial_name.clone(); - std::thread::spawn(move || { - let name = format!("concurrent_environment_creation: {}", t); - let env = Environment::builder() - .with_name(name) - .with_log_level(LoggingLevel::Warning) - .build() - .unwrap(); - - assert_eq!(env.name(), initial_name_cloned); - assert_eq!(env.ptr() as usize, main_env_ptr); - }) - }) - .map(|child| child.join()) - .all(|r| Result::is_ok(&r)) - ); + assert_ne!(env_ptr(), None); } } diff --git a/src/error.rs b/src/error.rs index cda822c..9ff5679 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,7 +4,7 @@ use std::{convert::Infallible, io, path::PathBuf, string}; use thiserror::Error; -use super::{char_p_to_string, ort, tensor::TensorElementDataType}; +use super::{char_p_to_string, ortsys, tensor::TensorElementDataType}; /// Type alias for the Result type returned by ORT functions. pub type Result = std::result::Result; @@ -281,7 +281,7 @@ impl From for Result<(), ErrorInternal> { if status.0.is_null() { Ok(()) } else { - let raw: *const std::os::raw::c_char = unsafe { ort().GetErrorMessage.unwrap()(status.0) }; + let raw: *const std::os::raw::c_char = ortsys![unsafe GetErrorMessage(status.0)]; match char_p_to_string(raw) { Ok(msg) => Err(ErrorInternal::Msg(msg)), Err(err) => match err { @@ -295,7 +295,7 @@ impl From for Result<(), ErrorInternal> { impl Drop for OrtStatusWrapper { fn drop(&mut self) { - unsafe { ort().ReleaseStatus.unwrap()(self.0) } + ortsys![unsafe ReleaseStatus(self.0)]; } } diff --git a/src/execution_providers/openvino.rs b/src/execution_providers/openvino.rs index d610251..47d2d20 100644 --- a/src/execution_providers/openvino.rs +++ b/src/execution_providers/openvino.rs @@ -15,6 +15,9 @@ pub struct OpenVINOExecutionProvider { enable_vpu_fast_compile: bool } +unsafe impl Send for OpenVINOExecutionProvider {} +unsafe impl Sync for OpenVINOExecutionProvider {} + impl Default for OpenVINOExecutionProvider { fn default() -> Self { Self { diff --git a/src/execution_providers/rocm.rs b/src/execution_providers/rocm.rs index 4d808a0..dfbdf1c 100644 --- a/src/execution_providers/rocm.rs +++ b/src/execution_providers/rocm.rs @@ -17,6 +17,9 @@ pub struct ROCmExecutionProvider { tunable_op_max_tuning_duration_ms: i32 } +unsafe impl Send for ROCmExecutionProvider {} +unsafe impl Sync for ROCmExecutionProvider {} + impl Default for ROCmExecutionProvider { fn default() -> Self { Self { diff --git a/src/lib.rs b/src/lib.rs index 87e2eb4..c60638c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,7 +31,7 @@ use std::{ use once_cell::sync::Lazy; use tracing::warn; -pub use self::environment::{Environment, EnvironmentBuilder}; +pub use self::environment::{init, EnvironmentBuilder}; #[cfg(feature = "fetch-models")] #[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))] pub use self::error::FetchModelError; diff --git a/src/session/input.rs b/src/session/input.rs index 8996a2f..15b91a2 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -41,10 +41,9 @@ impl<'i, const N: usize> From<[Value; N]> for SessionInputs<'i, N> { /// ```no_run /// # use std::{error::Error, sync::Arc}; /// # use ndarray::Array1; -/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder}; +/// # use ort::{GraphOptimizationLevel, Session}; /// # fn main() -> Result<(), Box> { -/// # let environment = Environment::default().into_arc(); -/// let mut session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?; +/// let mut session = Session::builder()?.with_model_from_file("model.onnx")?; /// let _ = session.run(ort::inputs![Array1::from_vec(vec![1, 2, 3, 4, 5])]?); /// # Ok(()) /// # } @@ -55,10 +54,9 @@ impl<'i, const N: usize> From<[Value; N]> for SessionInputs<'i, N> { /// ```no_run /// # use std::{error::Error, sync::Arc}; /// # use ndarray::Array1; -/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder}; +/// # use ort::{GraphOptimizationLevel, Session}; /// # fn main() -> Result<(), Box> { -/// # let environment = Environment::default().into_arc(); -/// let mut session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?; +/// let mut session = Session::builder()?.with_model_from_file("model.onnx")?; /// let _ = session.run(ort::inputs! { /// "tokens" => Array1::from_vec(vec![1, 2, 3, 4, 5]) /// }?); diff --git a/src/session/mod.rs b/src/session/mod.rs index fe71304..fc5d486 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -6,11 +6,20 @@ use std::os::unix::ffi::OsStrExt; use std::os::windows::ffi::OsStrExt; #[cfg(feature = "fetch-models")] use std::{env, path::PathBuf, time::Duration}; -use std::{ffi::CString, fmt, marker::PhantomData, ops::Deref, os::raw::c_char, path::Path, ptr, sync::Arc}; +use std::{ + ffi::CString, + fmt, + marker::PhantomData, + ops::Deref, + os::raw::c_char, + path::Path, + ptr, + sync::{atomic::Ordering, Arc} +}; use super::{ char_p_to_string, - environment::Environment, + environment::get_environment, error::{assert_non_null_pointer, assert_null_pointer, status_to_result, Error, ErrorInternal, Result}, execution_providers::{apply_execution_providers, ExecutionProviderDispatch}, extern_system_fn, @@ -28,8 +37,8 @@ pub(crate) mod input; pub(crate) mod output; pub use self::{input::SessionInputs, output::SessionOutputs}; -/// Type used to create a session using the _builder pattern_. Once created, you can use the different methods to -/// configure the session. +/// Type used to create a session using the _builder pattern_. Once created with [`Session::builder`], you can use the +/// different methods to configure the session. /// /// Once configured, use the [`SessionBuilder::with_model_from_file`](crate::SessionBuilder::with_model_from_file) /// method to "commit" the builder configuration into a [`Session`]. @@ -38,14 +47,9 @@ pub use self::{input::SessionInputs, output::SessionOutputs}; /// /// ```no_run /// # use std::{error::Error, sync::Arc}; -/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder}; +/// # use ort::{GraphOptimizationLevel, Session}; /// # fn main() -> Result<(), Box> { -/// let environment = Environment::builder() -/// .with_name("test") -/// .with_log_level(LoggingLevel::Verbose) -/// .build()? -/// .into_arc(); -/// let mut session = SessionBuilder::new(&environment)? +/// let mut session = Session::builder()? /// .with_optimization_level(GraphOptimizationLevel::Level1)? /// .with_intra_threads(1)? /// .with_model_from_file("squeezenet.onnx")?; @@ -59,16 +63,12 @@ pub struct SessionBuilder { memory_type: MemType, #[cfg(feature = "custom-ops")] custom_runtime_handles: Vec<*mut std::os::raw::c_void>, - execution_providers: Vec, - - // env must be last to drop it after everything else - env: Arc + execution_providers: Vec } impl fmt::Debug for SessionBuilder { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { f.debug_struct("SessionBuilder") - .field("env", &self.env.name()) .field("allocator", &self.allocator) .field("memory_type", &self.memory_type) .finish() @@ -81,7 +81,6 @@ impl Clone for SessionBuilder { status_to_result(ortsys![unsafe CloneSessionOptions(self.session_options_ptr, &mut session_options_ptr as *mut _)]) .expect("error cloning session options"); Self { - env: Arc::clone(&self.env), session_options_ptr, allocator: self.allocator, memory_type: self.memory_type, @@ -108,12 +107,11 @@ impl Drop for SessionBuilder { impl SessionBuilder { /// Creates a new session builder in the given environment. - pub fn new(env: &Arc) -> Result { + pub fn new() -> Result { let mut session_options_ptr: *mut ort_sys::OrtSessionOptions = std::ptr::null_mut(); ortsys![unsafe CreateSessionOptions(&mut session_options_ptr) -> Error::CreateSessionOptions; nonNull(session_options_ptr)]; Ok(Self { - env: Arc::clone(env), session_options_ptr, allocator: AllocatorType::Device, memory_type: MemType::Default, @@ -384,9 +382,10 @@ impl SessionBuilder { .map(|b| *b as std::os::raw::c_char) .collect(); - apply_execution_providers(&self, self.execution_providers.iter().chain(&self.env.execution_providers).cloned()); + let env = get_environment()?; + apply_execution_providers(&self, self.execution_providers.iter().chain(&env.execution_providers).cloned()); - let env_ptr: *const ort_sys::OrtEnv = self.env.ptr(); + let env_ptr = env.env_ptr.load(Ordering::Relaxed); let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)]; @@ -404,90 +403,7 @@ impl SessionBuilder { .collect::>>()?; Ok(Session { - inner: Arc::new(SharedSessionInner { - env: Arc::clone(&self.env), - session_ptr, - allocator - }), - inputs, - outputs - }) - } - - /// Loads an ONNX model from a file, replacing external data with data provided in initializers. - /// - /// This will find initialized tensors with external data in the graph with the provided names and replace them with - /// the provided tensors. The replacement will occur before any optimizations take place, and the data will be - /// copied into the graph. Tensors replaced by this function must be using external data. (you cannot replace a - /// non-external tensor) - pub fn with_model_from_file_and_external_initializers<'v, 'i, P>(self, model_filepath_ref: P, initializers: &'i [(String, Value)]) -> Result - where - 'i: 'v, - P: AsRef - { - let model_filepath = model_filepath_ref.as_ref(); - if !model_filepath.exists() { - return Err(Error::FileDoesNotExist { - filename: model_filepath.to_path_buf() - }); - } - - // Build an OsString, then a vector of bytes to pass to C - let model_path = std::ffi::OsString::from(model_filepath); - #[cfg(target_family = "windows")] - let model_path: Vec = model_path - .encode_wide() - .chain(std::iter::once(0)) // Make sure we have a null terminated string - .collect(); - #[cfg(not(target_family = "windows"))] - let model_path: Vec = model_path - .as_bytes() - .iter() - .chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string - .map(|b| *b as std::os::raw::c_char) - .collect(); - - apply_execution_providers(&self, self.execution_providers.iter().chain(&self.env.execution_providers).cloned()); - - let env_ptr: *const ort_sys::OrtEnv = self.env.ptr(); - - let allocator = Allocator::default(); - - let initializer_names: Vec = initializers - .iter() - .map(|(name, _)| CString::new(name.as_str()).unwrap()) - .map(|n| CString::new(n).unwrap()) - .collect(); - let initializer_names_ptr: Vec<*const c_char> = initializer_names.iter().map(|n| n.as_ptr() as *const c_char).collect(); - - let initializers: Vec<*const ort_sys::OrtValue> = initializers.iter().map(|input_array_ort| input_array_ort.1.ptr() as *const _).collect(); - if !initializers.is_empty() { - assert_eq!(initializer_names.len(), initializers.len()); - ortsys![unsafe AddExternalInitializers(self.session_options_ptr, initializer_names_ptr.as_ptr(), initializers.as_ptr(), initializers.len() as _) -> Error::CreateSession]; - } - - let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); - ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)]; - - std::mem::drop(initializer_names); - std::mem::drop(initializers); - - // Extract input and output properties - let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; - let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; - let inputs = (0..num_input_nodes) - .map(|i| dangerous::extract_input(session_ptr, allocator.ptr, i)) - .collect::>>()?; - let outputs = (0..num_output_nodes) - .map(|i| dangerous::extract_output(session_ptr, allocator.ptr, i)) - .collect::>>()?; - - Ok(Session { - inner: Arc::new(SharedSessionInner { - env: Arc::clone(&self.env), - session_ptr, - allocator - }), + inner: Arc::new(SharedSessionInner { session_ptr, allocator }), inputs, outputs }) @@ -496,8 +412,9 @@ impl SessionBuilder { /// Load an ONNX graph from memory and commit the session /// For `.ort` models, we enable `session.use_ort_model_bytes_directly`. /// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array). - /// If you want to store the model file and the [`InMemorySession`] in same struct, - /// please check crates for creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros). + /// + /// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that + /// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros). pub fn with_model_from_memory_directly(self, model_bytes: &[u8]) -> Result> { let str_to_char = |s: &str| { s.as_bytes() @@ -519,9 +436,10 @@ impl SessionBuilder { pub fn with_model_from_memory(self, model_bytes: &[u8]) -> Result { let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); - let env_ptr: *const ort_sys::OrtEnv = self.env.ptr(); + let env = get_environment()?; + apply_execution_providers(&self, self.execution_providers.iter().chain(&env.execution_providers).cloned()); - apply_execution_providers(&self, self.execution_providers.iter().chain(&self.env.execution_providers).cloned()); + let env_ptr = env.env_ptr.load(Ordering::Relaxed); let model_data = model_bytes.as_ptr() as *const std::ffi::c_void; let model_data_length = model_bytes.len(); @@ -543,11 +461,7 @@ impl SessionBuilder { .collect::>>()?; let session = Session { - inner: Arc::new(SharedSessionInner { - env: Arc::clone(&self.env), - session_ptr, - allocator - }), + inner: Arc::new(SharedSessionInner { session_ptr, allocator }), inputs, outputs }; @@ -560,11 +474,7 @@ impl SessionBuilder { #[derive(Debug)] pub struct SharedSessionInner { pub(crate) session_ptr: *mut ort_sys::OrtSession, - allocator: Allocator, - // hold onto an environment arc to ensure the environment also stays alive - // env must be last to drop it after everything else - #[allow(dead_code)] - env: Arc + allocator: Allocator } unsafe impl Send for SharedSessionInner {} @@ -624,6 +534,10 @@ pub struct Output { } impl Session { + pub fn builder() -> Result { + SessionBuilder::new() + } + /// Returns this session's [`Allocator`]. pub fn allocator(&self) -> &Allocator { &self.inner.allocator diff --git a/src/value.rs b/src/value.rs index cadfb54..7a41138 100644 --- a/src/value.rs +++ b/src/value.rs @@ -330,9 +330,7 @@ impl Value { }) } - /// Construct a [`Value`] from a Rust-owned [`CowArray`]. - /// - /// `allocator` is required to be `Some` when converting a String tensor. See [`crate::Session::allocator`]. + /// Construct a [`Value`] from a Rust-owned array. pub fn from_string_array(allocator: &Allocator, input: impl OrtInput) -> Result { let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemType::Default)?; diff --git a/tests/mnist.rs b/tests/mnist.rs index c7c12f8..bb20c5e 100644 --- a/tests/mnist.rs +++ b/tests/mnist.rs @@ -1,22 +1,16 @@ use std::path::Path; use image::{imageops::FilterType, ImageBuffer, Luma, Pixel}; -use ort::{ - download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor -}; +use ort::{download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, GraphOptimizationLevel, LoggingLevel, Session, Tensor}; use test_log::test; #[test] fn mnist_5() -> ort::Result<()> { const IMAGE_TO_LOAD: &str = "mnist_5.jpg"; - let environment = Environment::builder() - .with_name("integration_test") - .with_log_level(LoggingLevel::Warning) - .build()? - .into_arc(); + ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?; - let session = SessionBuilder::new(&environment)? + let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? .with_model_downloaded(DomainBasedImageClassification::Mnist) diff --git a/tests/squeezenet.rs b/tests/squeezenet.rs index 871cd05..28d24e1 100644 --- a/tests/squeezenet.rs +++ b/tests/squeezenet.rs @@ -7,22 +7,16 @@ use std::{ use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb}; use ndarray::s; -use ort::{ - download::vision::ImageClassification, inputs, ArrayExtensions, Environment, FetchModelError, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor -}; +use ort::{download::vision::ImageClassification, inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, LoggingLevel, Session, Tensor}; use test_log::test; #[test] fn squeezenet_mushroom() -> ort::Result<()> { const IMAGE_TO_LOAD: &str = "mushroom.png"; - let environment = Environment::builder() - .with_name("integration_test") - .with_log_level(LoggingLevel::Warning) - .build()? - .into_arc(); + ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?; - let session = SessionBuilder::new(&environment)? + let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? .with_model_downloaded(ImageClassification::SqueezeNet) diff --git a/tests/upsample.rs b/tests/upsample.rs index 042297c..e99f20a 100644 --- a/tests/upsample.rs +++ b/tests/upsample.rs @@ -2,7 +2,7 @@ use std::path::Path; use image::RgbImage; use ndarray::{Array, CowArray, Ix4}; -use ort::{inputs, Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor}; +use ort::{inputs, GraphOptimizationLevel, LoggingLevel, Session, Tensor}; use test_log::test; fn load_input_image>(name: P) -> RgbImage { @@ -44,15 +44,11 @@ fn convert_image_to_cow_array(img: &RgbImage) -> CowArray<'_, f32, Ix4> { fn upsample() -> ort::Result<()> { const IMAGE_TO_LOAD: &str = "mushroom.png"; - let environment = Environment::builder() - .with_name("integration_test") - .with_log_level(LoggingLevel::Warning) - .build()? - .into_arc(); + ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?; let session_data = std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file"); - let session = SessionBuilder::new(&environment)? + let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? .with_model_from_memory(&session_data) @@ -89,15 +85,11 @@ fn upsample() -> ort::Result<()> { fn upsample_with_ort_model() -> ort::Result<()> { const IMAGE_TO_LOAD: &str = "mushroom.png"; - let environment = Environment::builder() - .with_name("integration_test") - .with_log_level(LoggingLevel::Warning) - .build()? - .into_arc(); + ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?; let session_data = std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.ort")).expect("Could not open model from file"); - let session = SessionBuilder::new(&environment)? + let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? .with_model_from_memory_directly(&session_data) // Zero-copy. diff --git a/tests/vectorizer.rs b/tests/vectorizer.rs index 357c3cc..700cf86 100644 --- a/tests/vectorizer.rs +++ b/tests/vectorizer.rs @@ -1,15 +1,13 @@ use std::path::Path; use ndarray::{ArrayD, IxDyn}; -use ort::{inputs, Environment, GraphOptimizationLevel, SessionBuilder, Value}; +use ort::{inputs, GraphOptimizationLevel, Session, Value}; use test_log::test; #[test] #[cfg(not(target_arch = "aarch64"))] fn vectorizer() -> ort::Result<()> { - let environment = Environment::default().into_arc(); - - let session = SessionBuilder::new(&environment)? + let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? .with_model_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("vectorizer.onnx"))