refactor: create environment in global OnceLock

This commit is contained in:
Carson M.
2023-11-20 21:05:50 -06:00
parent 9b1bc6a54b
commit c69064f443
14 changed files with 124 additions and 442 deletions

View File

@@ -4,7 +4,7 @@ use std::{
};
use ndarray::{array, concatenate, s, Array1, Axis};
use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, Environment, GraphOptimizationLevel, SessionBuilder, Tensor};
use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Tensor};
use rand::Rng;
use tokenizers::Tokenizer;
@@ -23,17 +23,17 @@ fn main() -> ort::Result<()> {
// Initialize tracing to receive debug messages from `ort`
tracing_subscriber::fmt::init();
// Create the ONNX Runtime environment, enabling CUDA execution providers for all sessions created in this process.
ort::init()
.with_name("GPT-2")
.with_execution_providers([CUDAExecutionProvider::default().build()])
.commit()?;
let mut stdout = io::stdout();
let mut rng = rand::thread_rng();
// Create the ONNX Runtime environment and session for the GPT-2 model.
let environment = Environment::builder()
.with_name("GPT-2")
.with_execution_providers([CUDAExecutionProvider::default().build()])
.build()?
.into_arc();
let session = SessionBuilder::new(&environment)?
// Load our model
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_downloaded(GPT2::GPT2LmHead)?;

View File

@@ -4,7 +4,7 @@ use std::path::Path;
use image::{imageops::FilterType, GenericImageView};
use ndarray::{s, Array, Axis};
use ort::{inputs, CUDAExecutionProvider, Environment, SessionBuilder, SessionOutputs};
use ort::{inputs, CUDAExecutionProvider, Session, SessionOutputs};
use raqote::{DrawOptions, DrawTarget, LineJoin, PathBuilder, SolidSource, Source, StrokeStyle};
use show_image::{event, AsImageView, WindowOptions};
@@ -42,6 +42,10 @@ const YOLOV8_CLASS_LABELS: [&str; 80] = [
fn main() -> ort::Result<()> {
tracing_subscriber::fmt::init();
ort::init()
.with_execution_providers([CUDAExecutionProvider::default().build()])
.commit()?;
let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("baseball.jpg")).unwrap();
let (img_width, img_height) = (original_img.width(), original_img.height());
let img = original_img.resize_exact(640, 640, FilterType::CatmullRom);
@@ -55,14 +59,10 @@ fn main() -> ort::Result<()> {
input[[0, 2, y, x]] = (b as f32) / 255.;
}
let env = Environment::builder()
.with_execution_providers([CUDAExecutionProvider::default().build()])
.build()?
.into_arc();
let model = SessionBuilder::new(&env).unwrap().with_model_downloaded(YOLOV8M_URL).unwrap();
let model = Session::builder()?.with_model_downloaded(YOLOV8M_URL)?;
// Run YOLOv8 inference
let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?).unwrap();
let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?)?;
let output = outputs["output0"].extract_tensor::<f32>().unwrap().view().t().into_owned();
let mut boxes = Vec::new();

View File

@@ -1,28 +1,31 @@
use std::{
ffi::CString,
sync::{atomic::AtomicPtr, Arc, Mutex}
sync::{atomic::AtomicPtr, OnceLock}
};
use once_cell::sync::Lazy;
use tracing::{debug, error, warn};
use tracing::debug;
use super::{
custom_logger,
error::{status_to_result, Error, Result},
ort, ortsys, ExecutionProviderDispatch, LoggingLevel
error::{Error, Result},
ortsys, ExecutionProviderDispatch, LoggingLevel
};
static G_ENV: Lazy<Arc<Mutex<EnvironmentSingleton>>> = Lazy::new(|| {
Arc::new(Mutex::new(EnvironmentSingleton {
name: String::from("uninitialized"),
env_ptr: AtomicPtr::new(std::ptr::null_mut())
}))
});
static G_ENV: OnceLock<EnvironmentSingleton> = OnceLock::new();
#[derive(Debug)]
struct EnvironmentSingleton {
name: String,
env_ptr: AtomicPtr<ort_sys::OrtEnv>
pub(crate) struct EnvironmentSingleton {
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
pub(crate) env_ptr: AtomicPtr<ort_sys::OrtEnv>
}
pub(crate) fn get_environment() -> Result<&'static EnvironmentSingleton> {
if G_ENV.get().is_none() {
EnvironmentBuilder::default().commit()?;
Ok(G_ENV.get().unwrap())
} else {
Ok(unsafe { G_ENV.get().unwrap_unchecked() })
}
}
#[derive(Debug, Default, Clone)]
@@ -33,145 +36,6 @@ pub struct EnvironmentGlobalThreadPoolOptions {
pub intra_op_thread_affinity: Option<String>
}
/// An [`Environment`] is the main entry point of the ONNX Runtime.
///
/// Only one ONNX environment can be created per process. A singleton is used to enforce this.
///
/// Once an environment is created, a [`super::Session`] can be obtained from it.
///
/// **NOTE**: While the [`Environment`] constructor takes a `name` parameter to name the environment, only the first
/// name will be considered if many environments are created.
///
/// # Example
///
/// ```no_run
/// # use std::error::Error;
/// # use ort::{Environment, LoggingLevel};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let environment = Environment::builder().with_name("test").with_log_level(LoggingLevel::Verbose).build()?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct Environment {
env: Arc<Mutex<EnvironmentSingleton>>,
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>
}
unsafe impl Send for Environment {}
unsafe impl Sync for Environment {}
impl Environment {
/// Create a new environment builder using default values
/// (name: `default`, log level: [`LoggingLevel::Warning`])
pub fn builder() -> EnvironmentBuilder {
EnvironmentBuilder {
name: "default".into(),
log_level: LoggingLevel::Warning,
execution_providers: Vec::new(),
global_thread_pool_options: None
}
}
/// Return the name of the current environment
pub fn name(&self) -> String {
self.env.lock().unwrap().name.to_string()
}
/// Wraps this environment in an `Arc` for use with `SessionBuilder`.
pub fn into_arc(self) -> Arc<Environment> {
Arc::new(self)
}
pub(crate) fn ptr(&self) -> *const ort_sys::OrtEnv {
*self.env.lock().unwrap().env_ptr.get_mut()
}
}
impl Default for Environment {
fn default() -> Self {
// NOTE: Because 'G_ENV' is `Lazy`, locking it will, initially, create
// a new Arc<Mutex<EnvironmentSingleton>> with a strong count of 1.
// Cloning it to embed it inside the 'Environment' to return
// will thus increase the strong count to 2.
let mut environment_guard = G_ENV.lock().expect("Failed to acquire global environment lock: another thread panicked?");
let g_env_ptr = environment_guard.env_ptr.get_mut();
if g_env_ptr.is_null() {
debug!("Environment not yet initialized, creating a new one");
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new("default".to_string()).unwrap();
status_to_result(
ortsys![unsafe CreateEnvWithCustomLogger(logging_function, logger_param, LoggingLevel::Warning.into(), cname.as_ptr(), &mut env_ptr); nonNull(env_ptr)]
)
.map_err(Error::CreateEnvironment)
.unwrap();
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");
*g_env_ptr = env_ptr;
environment_guard.name = "default".to_string();
// NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one.
// If this 'Environment' is the only one in the process, the strong count
// will be 2:
// * one `Lazy` 'G_ENV'
// * one inside the 'Environment' returned
Environment {
env: G_ENV.clone(),
execution_providers: vec![]
}
} else {
// NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one.
// If this 'Environment' is the only one in the process, the strong count
// will be 2:
// * one `Lazy` 'G_ENV'
// * one inside the 'Environment' returned
Environment {
env: G_ENV.clone(),
execution_providers: vec![]
}
}
}
}
impl Drop for Environment {
#[tracing::instrument]
fn drop(&mut self) {
debug!(global_arc_count = Arc::strong_count(&G_ENV), "Dropping environment");
let mut environment_guard = self.env.lock().expect("Failed to acquire lock: another thread panicked?");
// NOTE: If we drop an 'Environment' we (obviously) have _at least_
// one 'G_ENV' strong count (the one in the 'env' member).
// There is also the "original" 'G_ENV' which is a the `Lazy` global.
// If there is no other environment, the strong count should be two and we
// can properly free the sys::OrtEnv pointer.
if Arc::strong_count(&G_ENV) == 2 {
let release_env = ort().ReleaseEnv.unwrap();
let env_ptr: *mut ort_sys::OrtEnv = *environment_guard.env_ptr.get_mut();
debug!(global_arc_count = Arc::strong_count(&G_ENV), "Releasing environment");
assert_ne!(env_ptr, std::ptr::null_mut());
if env_ptr.is_null() {
error!("Environment pointer is null, not dropping!");
} else {
unsafe { release_env(env_ptr) };
}
environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut());
environment_guard.name = String::from("uninitialized");
}
}
}
/// Struct used to build an environment [`Environment`].
///
/// This is ONNX Runtime's main entry point. An environment _must_ be created as the first step. An [`Environment`] can
@@ -189,6 +53,17 @@ pub struct EnvironmentBuilder {
global_thread_pool_options: Option<EnvironmentGlobalThreadPoolOptions>
}
impl Default for EnvironmentBuilder {
fn default() -> Self {
EnvironmentBuilder {
name: "default".to_string(),
log_level: LoggingLevel::Error,
execution_providers: vec![],
global_thread_pool_options: None
}
}
}
impl EnvironmentBuilder {
/// Configure the environment with a given name
///
@@ -263,14 +138,8 @@ impl EnvironmentBuilder {
}
/// Commit the configuration to a new [`Environment`].
pub fn build(self) -> Result<Environment> {
// NOTE: Because 'G_ENV' is a `Lazy`, locking it will, initially, create
// a new Arc<Mutex<EnvironmentSingleton>> with a strong count of 1.
// Cloning it to embed it inside the 'Environment' to return
// will thus increase the strong count to 2.
let mut environment_guard = G_ENV.lock().expect("Failed to acquire global environment lock: another thread panicked?");
let g_env_ptr = environment_guard.env_ptr.get_mut();
if g_env_ptr.is_null() {
pub fn commit(self) -> Result<()> {
if G_ENV.get().is_none() {
debug!("Environment not yet initialized, creating a new one");
let env_ptr = if let Some(global_thread_pool) = self.global_thread_pool_options {
@@ -309,41 +178,22 @@ impl EnvironmentBuilder {
};
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");
*g_env_ptr = env_ptr;
environment_guard.name = self.name;
// NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one.
// If this 'Environment' is the only one in the process, the strong count
// will be 2:
// * one `Lazy` 'G_ENV'
// * one inside the 'Environment' returned
Ok(Environment {
env: G_ENV.clone(),
execution_providers: self.execution_providers
})
} else {
warn!(
name = environment_guard.name.as_str(),
env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(),
"Environment already initialized for this thread, reusing it",
);
// NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one.
// If this 'Environment' is the only one in the process, the strong count
// will be 2:
// * one `Lazy` 'G_ENV'
// * one inside the 'Environment' returned
Ok(Environment {
env: G_ENV.clone(),
execution_providers: self.execution_providers.clone()
})
let _ = G_ENV.set(EnvironmentSingleton {
execution_providers: self.execution_providers,
env_ptr: AtomicPtr::new(env_ptr)
});
}
Ok(())
}
}
pub fn init() -> EnvironmentBuilder {
EnvironmentBuilder::default()
}
#[cfg(test)]
mod tests {
use std::sync::{RwLock, RwLockWriteGuard};
use std::sync::{atomic::Ordering, Arc, RwLock, RwLockWriteGuard};
use once_cell::sync::Lazy;
use test_log::test;
@@ -351,11 +201,11 @@ mod tests {
use super::*;
fn is_env_initialized() -> bool {
Arc::strong_count(&G_ENV) >= 2
G_ENV.get().is_some() && !G_ENV.get().unwrap().env_ptr.load(Ordering::Relaxed).is_null()
}
fn env_ptr() -> *const ort_sys::OrtEnv {
*G_ENV.lock().unwrap().env_ptr.get_mut()
fn env_ptr() -> Option<*mut ort_sys::OrtEnv> {
G_ENV.get().map(|f| f.env_ptr.load(Ordering::Relaxed))
}
struct ConcurrentTestRun {
@@ -373,76 +223,14 @@ mod tests {
let _run_lock = single_test_run();
assert!(!is_env_initialized());
assert_eq!(env_ptr(), std::ptr::null_mut());
assert_eq!(env_ptr(), None);
let env = Environment::builder()
EnvironmentBuilder::default()
.with_name("env_is_initialized")
.with_log_level(LoggingLevel::Warning)
.build()
.commit()
.unwrap();
assert!(is_env_initialized());
assert_ne!(env_ptr(), std::ptr::null_mut());
drop(env);
assert!(!is_env_initialized());
assert_eq!(env_ptr(), std::ptr::null_mut());
}
#[ignore]
#[test]
fn sequential_environment_creation() {
let _concurrent_run_lock_guard = single_test_run();
let mut prev_env_ptr = env_ptr();
for i in 0..10 {
let name = format!("sequential_environment_creation: {}", i);
let env = Environment::builder()
.with_name(name.clone())
.with_log_level(LoggingLevel::Warning)
.build()
.unwrap();
let next_env_ptr = env_ptr();
assert_ne!(next_env_ptr, prev_env_ptr);
prev_env_ptr = next_env_ptr;
assert_eq!(env.name(), name);
}
}
#[test]
fn concurrent_environment_creations() {
let _concurrent_run_lock_guard = single_test_run();
let initial_name = String::from("concurrent_environment_creation");
let main_env = Environment::builder()
.with_name(&initial_name)
.with_log_level(LoggingLevel::Warning)
.build()
.unwrap();
let main_env_ptr = main_env.ptr() as usize;
assert_eq!(main_env.name(), initial_name);
assert_eq!(main_env.ptr() as usize, main_env_ptr);
assert!(
(0..10)
.map(|t| {
let initial_name_cloned = initial_name.clone();
std::thread::spawn(move || {
let name = format!("concurrent_environment_creation: {}", t);
let env = Environment::builder()
.with_name(name)
.with_log_level(LoggingLevel::Warning)
.build()
.unwrap();
assert_eq!(env.name(), initial_name_cloned);
assert_eq!(env.ptr() as usize, main_env_ptr);
})
})
.map(|child| child.join())
.all(|r| Result::is_ok(&r))
);
assert_ne!(env_ptr(), None);
}
}

View File

@@ -4,7 +4,7 @@ use std::{convert::Infallible, io, path::PathBuf, string};
use thiserror::Error;
use super::{char_p_to_string, ort, tensor::TensorElementDataType};
use super::{char_p_to_string, ortsys, tensor::TensorElementDataType};
/// Type alias for the Result type returned by ORT functions.
pub type Result<T, E = Error> = std::result::Result<T, E>;
@@ -281,7 +281,7 @@ impl From<OrtStatusWrapper> for Result<(), ErrorInternal> {
if status.0.is_null() {
Ok(())
} else {
let raw: *const std::os::raw::c_char = unsafe { ort().GetErrorMessage.unwrap()(status.0) };
let raw: *const std::os::raw::c_char = ortsys![unsafe GetErrorMessage(status.0)];
match char_p_to_string(raw) {
Ok(msg) => Err(ErrorInternal::Msg(msg)),
Err(err) => match err {
@@ -295,7 +295,7 @@ impl From<OrtStatusWrapper> for Result<(), ErrorInternal> {
impl Drop for OrtStatusWrapper {
fn drop(&mut self) {
unsafe { ort().ReleaseStatus.unwrap()(self.0) }
ortsys![unsafe ReleaseStatus(self.0)];
}
}

View File

@@ -15,6 +15,9 @@ pub struct OpenVINOExecutionProvider {
enable_vpu_fast_compile: bool
}
unsafe impl Send for OpenVINOExecutionProvider {}
unsafe impl Sync for OpenVINOExecutionProvider {}
impl Default for OpenVINOExecutionProvider {
fn default() -> Self {
Self {

View File

@@ -17,6 +17,9 @@ pub struct ROCmExecutionProvider {
tunable_op_max_tuning_duration_ms: i32
}
unsafe impl Send for ROCmExecutionProvider {}
unsafe impl Sync for ROCmExecutionProvider {}
impl Default for ROCmExecutionProvider {
fn default() -> Self {
Self {

View File

@@ -31,7 +31,7 @@ use std::{
use once_cell::sync::Lazy;
use tracing::warn;
pub use self::environment::{Environment, EnvironmentBuilder};
pub use self::environment::{init, EnvironmentBuilder};
#[cfg(feature = "fetch-models")]
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
pub use self::error::FetchModelError;

View File

@@ -41,10 +41,9 @@ impl<'i, const N: usize> From<[Value; N]> for SessionInputs<'i, N> {
/// ```no_run
/// # use std::{error::Error, sync::Arc};
/// # use ndarray::Array1;
/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder};
/// # use ort::{GraphOptimizationLevel, Session};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// # let environment = Environment::default().into_arc();
/// let mut session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?;
/// let mut session = Session::builder()?.with_model_from_file("model.onnx")?;
/// let _ = session.run(ort::inputs![Array1::from_vec(vec![1, 2, 3, 4, 5])]?);
/// # Ok(())
/// # }
@@ -55,10 +54,9 @@ impl<'i, const N: usize> From<[Value; N]> for SessionInputs<'i, N> {
/// ```no_run
/// # use std::{error::Error, sync::Arc};
/// # use ndarray::Array1;
/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder};
/// # use ort::{GraphOptimizationLevel, Session};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// # let environment = Environment::default().into_arc();
/// let mut session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?;
/// let mut session = Session::builder()?.with_model_from_file("model.onnx")?;
/// let _ = session.run(ort::inputs! {
/// "tokens" => Array1::from_vec(vec![1, 2, 3, 4, 5])
/// }?);

View File

@@ -6,11 +6,20 @@ use std::os::unix::ffi::OsStrExt;
use std::os::windows::ffi::OsStrExt;
#[cfg(feature = "fetch-models")]
use std::{env, path::PathBuf, time::Duration};
use std::{ffi::CString, fmt, marker::PhantomData, ops::Deref, os::raw::c_char, path::Path, ptr, sync::Arc};
use std::{
ffi::CString,
fmt,
marker::PhantomData,
ops::Deref,
os::raw::c_char,
path::Path,
ptr,
sync::{atomic::Ordering, Arc}
};
use super::{
char_p_to_string,
environment::Environment,
environment::get_environment,
error::{assert_non_null_pointer, assert_null_pointer, status_to_result, Error, ErrorInternal, Result},
execution_providers::{apply_execution_providers, ExecutionProviderDispatch},
extern_system_fn,
@@ -28,8 +37,8 @@ pub(crate) mod input;
pub(crate) mod output;
pub use self::{input::SessionInputs, output::SessionOutputs};
/// Type used to create a session using the _builder pattern_. Once created, you can use the different methods to
/// configure the session.
/// Type used to create a session using the _builder pattern_. Once created with [`Session::builder`], you can use the
/// different methods to configure the session.
///
/// Once configured, use the [`SessionBuilder::with_model_from_file`](crate::SessionBuilder::with_model_from_file)
/// method to "commit" the builder configuration into a [`Session`].
@@ -38,14 +47,9 @@ pub use self::{input::SessionInputs, output::SessionOutputs};
///
/// ```no_run
/// # use std::{error::Error, sync::Arc};
/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder};
/// # use ort::{GraphOptimizationLevel, Session};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let environment = Environment::builder()
/// .with_name("test")
/// .with_log_level(LoggingLevel::Verbose)
/// .build()?
/// .into_arc();
/// let mut session = SessionBuilder::new(&environment)?
/// let mut session = Session::builder()?
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
/// .with_intra_threads(1)?
/// .with_model_from_file("squeezenet.onnx")?;
@@ -59,16 +63,12 @@ pub struct SessionBuilder {
memory_type: MemType,
#[cfg(feature = "custom-ops")]
custom_runtime_handles: Vec<*mut std::os::raw::c_void>,
execution_providers: Vec<ExecutionProviderDispatch>,
// env must be last to drop it after everything else
env: Arc<Environment>
execution_providers: Vec<ExecutionProviderDispatch>
}
impl fmt::Debug for SessionBuilder {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
f.debug_struct("SessionBuilder")
.field("env", &self.env.name())
.field("allocator", &self.allocator)
.field("memory_type", &self.memory_type)
.finish()
@@ -81,7 +81,6 @@ impl Clone for SessionBuilder {
status_to_result(ortsys![unsafe CloneSessionOptions(self.session_options_ptr, &mut session_options_ptr as *mut _)])
.expect("error cloning session options");
Self {
env: Arc::clone(&self.env),
session_options_ptr,
allocator: self.allocator,
memory_type: self.memory_type,
@@ -108,12 +107,11 @@ impl Drop for SessionBuilder {
impl SessionBuilder {
/// Creates a new session builder in the given environment.
pub fn new(env: &Arc<Environment>) -> Result<Self> {
pub fn new() -> Result<Self> {
let mut session_options_ptr: *mut ort_sys::OrtSessionOptions = std::ptr::null_mut();
ortsys![unsafe CreateSessionOptions(&mut session_options_ptr) -> Error::CreateSessionOptions; nonNull(session_options_ptr)];
Ok(Self {
env: Arc::clone(env),
session_options_ptr,
allocator: AllocatorType::Device,
memory_type: MemType::Default,
@@ -384,9 +382,10 @@ impl SessionBuilder {
.map(|b| *b as std::os::raw::c_char)
.collect();
apply_execution_providers(&self, self.execution_providers.iter().chain(&self.env.execution_providers).cloned());
let env = get_environment()?;
apply_execution_providers(&self, self.execution_providers.iter().chain(&env.execution_providers).cloned());
let env_ptr: *const ort_sys::OrtEnv = self.env.ptr();
let env_ptr = env.env_ptr.load(Ordering::Relaxed);
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];
@@ -404,90 +403,7 @@ impl SessionBuilder {
.collect::<Result<Vec<Output>>>()?;
Ok(Session {
inner: Arc::new(SharedSessionInner {
env: Arc::clone(&self.env),
session_ptr,
allocator
}),
inputs,
outputs
})
}
/// Loads an ONNX model from a file, replacing external data with data provided in initializers.
///
/// This will find initialized tensors with external data in the graph with the provided names and replace them with
/// the provided tensors. The replacement will occur before any optimizations take place, and the data will be
/// copied into the graph. Tensors replaced by this function must be using external data. (you cannot replace a
/// non-external tensor)
pub fn with_model_from_file_and_external_initializers<'v, 'i, P>(self, model_filepath_ref: P, initializers: &'i [(String, Value)]) -> Result<Session>
where
'i: 'v,
P: AsRef<Path>
{
let model_filepath = model_filepath_ref.as_ref();
if !model_filepath.exists() {
return Err(Error::FileDoesNotExist {
filename: model_filepath.to_path_buf()
});
}
// Build an OsString, then a vector of bytes to pass to C
let model_path = std::ffi::OsString::from(model_filepath);
#[cfg(target_family = "windows")]
let model_path: Vec<u16> = model_path
.encode_wide()
.chain(std::iter::once(0)) // Make sure we have a null terminated string
.collect();
#[cfg(not(target_family = "windows"))]
let model_path: Vec<std::os::raw::c_char> = model_path
.as_bytes()
.iter()
.chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
.map(|b| *b as std::os::raw::c_char)
.collect();
apply_execution_providers(&self, self.execution_providers.iter().chain(&self.env.execution_providers).cloned());
let env_ptr: *const ort_sys::OrtEnv = self.env.ptr();
let allocator = Allocator::default();
let initializer_names: Vec<CString> = initializers
.iter()
.map(|(name, _)| CString::new(name.as_str()).unwrap())
.map(|n| CString::new(n).unwrap())
.collect();
let initializer_names_ptr: Vec<*const c_char> = initializer_names.iter().map(|n| n.as_ptr() as *const c_char).collect();
let initializers: Vec<*const ort_sys::OrtValue> = initializers.iter().map(|input_array_ort| input_array_ort.1.ptr() as *const _).collect();
if !initializers.is_empty() {
assert_eq!(initializer_names.len(), initializers.len());
ortsys![unsafe AddExternalInitializers(self.session_options_ptr, initializer_names_ptr.as_ptr(), initializers.as_ptr(), initializers.len() as _) -> Error::CreateSession];
}
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];
std::mem::drop(initializer_names);
std::mem::drop(initializers);
// Extract input and output properties
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, allocator.ptr, i))
.collect::<Result<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, allocator.ptr, i))
.collect::<Result<Vec<Output>>>()?;
Ok(Session {
inner: Arc::new(SharedSessionInner {
env: Arc::clone(&self.env),
session_ptr,
allocator
}),
inner: Arc::new(SharedSessionInner { session_ptr, allocator }),
inputs,
outputs
})
@@ -496,8 +412,9 @@ impl SessionBuilder {
/// Load an ONNX graph from memory and commit the session
/// For `.ort` models, we enable `session.use_ort_model_bytes_directly`.
/// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array).
/// If you want to store the model file and the [`InMemorySession`] in same struct,
/// please check crates for creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros).
///
/// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that
/// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros).
pub fn with_model_from_memory_directly(self, model_bytes: &[u8]) -> Result<InMemorySession<'_>> {
let str_to_char = |s: &str| {
s.as_bytes()
@@ -519,9 +436,10 @@ impl SessionBuilder {
pub fn with_model_from_memory(self, model_bytes: &[u8]) -> Result<Session> {
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
let env_ptr: *const ort_sys::OrtEnv = self.env.ptr();
let env = get_environment()?;
apply_execution_providers(&self, self.execution_providers.iter().chain(&env.execution_providers).cloned());
apply_execution_providers(&self, self.execution_providers.iter().chain(&self.env.execution_providers).cloned());
let env_ptr = env.env_ptr.load(Ordering::Relaxed);
let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
let model_data_length = model_bytes.len();
@@ -543,11 +461,7 @@ impl SessionBuilder {
.collect::<Result<Vec<Output>>>()?;
let session = Session {
inner: Arc::new(SharedSessionInner {
env: Arc::clone(&self.env),
session_ptr,
allocator
}),
inner: Arc::new(SharedSessionInner { session_ptr, allocator }),
inputs,
outputs
};
@@ -560,11 +474,7 @@ impl SessionBuilder {
#[derive(Debug)]
pub struct SharedSessionInner {
pub(crate) session_ptr: *mut ort_sys::OrtSession,
allocator: Allocator,
// hold onto an environment arc to ensure the environment also stays alive
// env must be last to drop it after everything else
#[allow(dead_code)]
env: Arc<Environment>
allocator: Allocator
}
unsafe impl Send for SharedSessionInner {}
@@ -624,6 +534,10 @@ pub struct Output {
}
impl Session {
pub fn builder() -> Result<SessionBuilder> {
SessionBuilder::new()
}
/// Returns this session's [`Allocator`].
pub fn allocator(&self) -> &Allocator {
&self.inner.allocator

View File

@@ -330,9 +330,7 @@ impl Value {
})
}
/// Construct a [`Value`] from a Rust-owned [`CowArray`].
///
/// `allocator` is required to be `Some` when converting a String tensor. See [`crate::Session::allocator`].
/// Construct a [`Value`] from a Rust-owned array.
pub fn from_string_array<T: Utf8Data + Debug + Clone + 'static>(allocator: &Allocator, input: impl OrtInput<Item = T>) -> Result<Value> {
let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemType::Default)?;

View File

@@ -1,22 +1,16 @@
use std::path::Path;
use image::{imageops::FilterType, ImageBuffer, Luma, Pixel};
use ort::{
download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor
};
use ort::{download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
use test_log::test;
#[test]
fn mnist_5() -> ort::Result<()> {
const IMAGE_TO_LOAD: &str = "mnist_5.jpg";
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
.into_arc();
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
let session = SessionBuilder::new(&environment)?
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_downloaded(DomainBasedImageClassification::Mnist)

View File

@@ -7,22 +7,16 @@ use std::{
use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb};
use ndarray::s;
use ort::{
download::vision::ImageClassification, inputs, ArrayExtensions, Environment, FetchModelError, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor
};
use ort::{download::vision::ImageClassification, inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
use test_log::test;
#[test]
fn squeezenet_mushroom() -> ort::Result<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
.into_arc();
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
let session = SessionBuilder::new(&environment)?
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_downloaded(ImageClassification::SqueezeNet)

View File

@@ -2,7 +2,7 @@ use std::path::Path;
use image::RgbImage;
use ndarray::{Array, CowArray, Ix4};
use ort::{inputs, Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor};
use ort::{inputs, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
use test_log::test;
fn load_input_image<P: AsRef<Path>>(name: P) -> RgbImage {
@@ -44,15 +44,11 @@ fn convert_image_to_cow_array(img: &RgbImage) -> CowArray<'_, f32, Ix4> {
fn upsample() -> ort::Result<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
.into_arc();
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
let session_data =
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
let session = SessionBuilder::new(&environment)?
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_from_memory(&session_data)
@@ -89,15 +85,11 @@ fn upsample() -> ort::Result<()> {
fn upsample_with_ort_model() -> ort::Result<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
.into_arc();
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
let session_data =
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.ort")).expect("Could not open model from file");
let session = SessionBuilder::new(&environment)?
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_from_memory_directly(&session_data) // Zero-copy.

View File

@@ -1,15 +1,13 @@
use std::path::Path;
use ndarray::{ArrayD, IxDyn};
use ort::{inputs, Environment, GraphOptimizationLevel, SessionBuilder, Value};
use ort::{inputs, GraphOptimizationLevel, Session, Value};
use test_log::test;
#[test]
#[cfg(not(target_arch = "aarch64"))]
fn vectorizer() -> ort::Result<()> {
let environment = Environment::default().into_arc();
let session = SessionBuilder::new(&environment)?
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("vectorizer.onnx"))