mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor: create environment in global OnceLock
This commit is contained in:
@@ -4,7 +4,7 @@ use std::{
|
||||
};
|
||||
|
||||
use ndarray::{array, concatenate, s, Array1, Axis};
|
||||
use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, Environment, GraphOptimizationLevel, SessionBuilder, Tensor};
|
||||
use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Tensor};
|
||||
use rand::Rng;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
@@ -23,17 +23,17 @@ fn main() -> ort::Result<()> {
|
||||
// Initialize tracing to receive debug messages from `ort`
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
// Create the ONNX Runtime environment, enabling CUDA execution providers for all sessions created in this process.
|
||||
ort::init()
|
||||
.with_name("GPT-2")
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
.commit()?;
|
||||
|
||||
let mut stdout = io::stdout();
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// Create the ONNX Runtime environment and session for the GPT-2 model.
|
||||
let environment = Environment::builder()
|
||||
.with_name("GPT-2")
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
.build()?
|
||||
.into_arc();
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
// Load our model
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded(GPT2::GPT2LmHead)?;
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::path::Path;
|
||||
|
||||
use image::{imageops::FilterType, GenericImageView};
|
||||
use ndarray::{s, Array, Axis};
|
||||
use ort::{inputs, CUDAExecutionProvider, Environment, SessionBuilder, SessionOutputs};
|
||||
use ort::{inputs, CUDAExecutionProvider, Session, SessionOutputs};
|
||||
use raqote::{DrawOptions, DrawTarget, LineJoin, PathBuilder, SolidSource, Source, StrokeStyle};
|
||||
use show_image::{event, AsImageView, WindowOptions};
|
||||
|
||||
@@ -42,6 +42,10 @@ const YOLOV8_CLASS_LABELS: [&str; 80] = [
|
||||
fn main() -> ort::Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
ort::init()
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
.commit()?;
|
||||
|
||||
let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("baseball.jpg")).unwrap();
|
||||
let (img_width, img_height) = (original_img.width(), original_img.height());
|
||||
let img = original_img.resize_exact(640, 640, FilterType::CatmullRom);
|
||||
@@ -55,14 +59,10 @@ fn main() -> ort::Result<()> {
|
||||
input[[0, 2, y, x]] = (b as f32) / 255.;
|
||||
}
|
||||
|
||||
let env = Environment::builder()
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
.build()?
|
||||
.into_arc();
|
||||
let model = SessionBuilder::new(&env).unwrap().with_model_downloaded(YOLOV8M_URL).unwrap();
|
||||
let model = Session::builder()?.with_model_downloaded(YOLOV8M_URL)?;
|
||||
|
||||
// Run YOLOv8 inference
|
||||
let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?).unwrap();
|
||||
let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?)?;
|
||||
let output = outputs["output0"].extract_tensor::<f32>().unwrap().view().t().into_owned();
|
||||
|
||||
let mut boxes = Vec::new();
|
||||
|
||||
@@ -1,28 +1,31 @@
|
||||
use std::{
|
||||
ffi::CString,
|
||||
sync::{atomic::AtomicPtr, Arc, Mutex}
|
||||
sync::{atomic::AtomicPtr, OnceLock}
|
||||
};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use tracing::{debug, error, warn};
|
||||
use tracing::debug;
|
||||
|
||||
use super::{
|
||||
custom_logger,
|
||||
error::{status_to_result, Error, Result},
|
||||
ort, ortsys, ExecutionProviderDispatch, LoggingLevel
|
||||
error::{Error, Result},
|
||||
ortsys, ExecutionProviderDispatch, LoggingLevel
|
||||
};
|
||||
|
||||
static G_ENV: Lazy<Arc<Mutex<EnvironmentSingleton>>> = Lazy::new(|| {
|
||||
Arc::new(Mutex::new(EnvironmentSingleton {
|
||||
name: String::from("uninitialized"),
|
||||
env_ptr: AtomicPtr::new(std::ptr::null_mut())
|
||||
}))
|
||||
});
|
||||
static G_ENV: OnceLock<EnvironmentSingleton> = OnceLock::new();
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EnvironmentSingleton {
|
||||
name: String,
|
||||
env_ptr: AtomicPtr<ort_sys::OrtEnv>
|
||||
pub(crate) struct EnvironmentSingleton {
|
||||
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
|
||||
pub(crate) env_ptr: AtomicPtr<ort_sys::OrtEnv>
|
||||
}
|
||||
|
||||
pub(crate) fn get_environment() -> Result<&'static EnvironmentSingleton> {
|
||||
if G_ENV.get().is_none() {
|
||||
EnvironmentBuilder::default().commit()?;
|
||||
Ok(G_ENV.get().unwrap())
|
||||
} else {
|
||||
Ok(unsafe { G_ENV.get().unwrap_unchecked() })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
@@ -33,145 +36,6 @@ pub struct EnvironmentGlobalThreadPoolOptions {
|
||||
pub intra_op_thread_affinity: Option<String>
|
||||
}
|
||||
|
||||
/// An [`Environment`] is the main entry point of the ONNX Runtime.
|
||||
///
|
||||
/// Only one ONNX environment can be created per process. A singleton is used to enforce this.
|
||||
///
|
||||
/// Once an environment is created, a [`super::Session`] can be obtained from it.
|
||||
///
|
||||
/// **NOTE**: While the [`Environment`] constructor takes a `name` parameter to name the environment, only the first
|
||||
/// name will be considered if many environments are created.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use std::error::Error;
|
||||
/// # use ort::{Environment, LoggingLevel};
|
||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||
/// let environment = Environment::builder().with_name("test").with_log_level(LoggingLevel::Verbose).build()?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Environment {
|
||||
env: Arc<Mutex<EnvironmentSingleton>>,
|
||||
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>
|
||||
}
|
||||
|
||||
unsafe impl Send for Environment {}
|
||||
unsafe impl Sync for Environment {}
|
||||
|
||||
impl Environment {
|
||||
/// Create a new environment builder using default values
|
||||
/// (name: `default`, log level: [`LoggingLevel::Warning`])
|
||||
pub fn builder() -> EnvironmentBuilder {
|
||||
EnvironmentBuilder {
|
||||
name: "default".into(),
|
||||
log_level: LoggingLevel::Warning,
|
||||
execution_providers: Vec::new(),
|
||||
global_thread_pool_options: None
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the name of the current environment
|
||||
pub fn name(&self) -> String {
|
||||
self.env.lock().unwrap().name.to_string()
|
||||
}
|
||||
|
||||
/// Wraps this environment in an `Arc` for use with `SessionBuilder`.
|
||||
pub fn into_arc(self) -> Arc<Environment> {
|
||||
Arc::new(self)
|
||||
}
|
||||
|
||||
pub(crate) fn ptr(&self) -> *const ort_sys::OrtEnv {
|
||||
*self.env.lock().unwrap().env_ptr.get_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Environment {
|
||||
fn default() -> Self {
|
||||
// NOTE: Because 'G_ENV' is `Lazy`, locking it will, initially, create
|
||||
// a new Arc<Mutex<EnvironmentSingleton>> with a strong count of 1.
|
||||
// Cloning it to embed it inside the 'Environment' to return
|
||||
// will thus increase the strong count to 2.
|
||||
let mut environment_guard = G_ENV.lock().expect("Failed to acquire global environment lock: another thread panicked?");
|
||||
let g_env_ptr = environment_guard.env_ptr.get_mut();
|
||||
if g_env_ptr.is_null() {
|
||||
debug!("Environment not yet initialized, creating a new one");
|
||||
|
||||
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
|
||||
|
||||
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
|
||||
// FIXME: What should go here?
|
||||
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
|
||||
|
||||
let cname = CString::new("default".to_string()).unwrap();
|
||||
|
||||
status_to_result(
|
||||
ortsys![unsafe CreateEnvWithCustomLogger(logging_function, logger_param, LoggingLevel::Warning.into(), cname.as_ptr(), &mut env_ptr); nonNull(env_ptr)]
|
||||
)
|
||||
.map_err(Error::CreateEnvironment)
|
||||
.unwrap();
|
||||
|
||||
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");
|
||||
|
||||
*g_env_ptr = env_ptr;
|
||||
environment_guard.name = "default".to_string();
|
||||
|
||||
// NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one.
|
||||
// If this 'Environment' is the only one in the process, the strong count
|
||||
// will be 2:
|
||||
// * one `Lazy` 'G_ENV'
|
||||
// * one inside the 'Environment' returned
|
||||
Environment {
|
||||
env: G_ENV.clone(),
|
||||
execution_providers: vec![]
|
||||
}
|
||||
} else {
|
||||
// NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one.
|
||||
// If this 'Environment' is the only one in the process, the strong count
|
||||
// will be 2:
|
||||
// * one `Lazy` 'G_ENV'
|
||||
// * one inside the 'Environment' returned
|
||||
Environment {
|
||||
env: G_ENV.clone(),
|
||||
execution_providers: vec![]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Environment {
|
||||
#[tracing::instrument]
|
||||
fn drop(&mut self) {
|
||||
debug!(global_arc_count = Arc::strong_count(&G_ENV), "Dropping environment");
|
||||
|
||||
let mut environment_guard = self.env.lock().expect("Failed to acquire lock: another thread panicked?");
|
||||
|
||||
// NOTE: If we drop an 'Environment' we (obviously) have _at least_
|
||||
// one 'G_ENV' strong count (the one in the 'env' member).
|
||||
// There is also the "original" 'G_ENV' which is a the `Lazy` global.
|
||||
// If there is no other environment, the strong count should be two and we
|
||||
// can properly free the sys::OrtEnv pointer.
|
||||
if Arc::strong_count(&G_ENV) == 2 {
|
||||
let release_env = ort().ReleaseEnv.unwrap();
|
||||
let env_ptr: *mut ort_sys::OrtEnv = *environment_guard.env_ptr.get_mut();
|
||||
|
||||
debug!(global_arc_count = Arc::strong_count(&G_ENV), "Releasing environment");
|
||||
|
||||
assert_ne!(env_ptr, std::ptr::null_mut());
|
||||
if env_ptr.is_null() {
|
||||
error!("Environment pointer is null, not dropping!");
|
||||
} else {
|
||||
unsafe { release_env(env_ptr) };
|
||||
}
|
||||
|
||||
environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut());
|
||||
environment_guard.name = String::from("uninitialized");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Struct used to build an environment [`Environment`].
|
||||
///
|
||||
/// This is ONNX Runtime's main entry point. An environment _must_ be created as the first step. An [`Environment`] can
|
||||
@@ -189,6 +53,17 @@ pub struct EnvironmentBuilder {
|
||||
global_thread_pool_options: Option<EnvironmentGlobalThreadPoolOptions>
|
||||
}
|
||||
|
||||
impl Default for EnvironmentBuilder {
|
||||
fn default() -> Self {
|
||||
EnvironmentBuilder {
|
||||
name: "default".to_string(),
|
||||
log_level: LoggingLevel::Error,
|
||||
execution_providers: vec![],
|
||||
global_thread_pool_options: None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EnvironmentBuilder {
|
||||
/// Configure the environment with a given name
|
||||
///
|
||||
@@ -263,14 +138,8 @@ impl EnvironmentBuilder {
|
||||
}
|
||||
|
||||
/// Commit the configuration to a new [`Environment`].
|
||||
pub fn build(self) -> Result<Environment> {
|
||||
// NOTE: Because 'G_ENV' is a `Lazy`, locking it will, initially, create
|
||||
// a new Arc<Mutex<EnvironmentSingleton>> with a strong count of 1.
|
||||
// Cloning it to embed it inside the 'Environment' to return
|
||||
// will thus increase the strong count to 2.
|
||||
let mut environment_guard = G_ENV.lock().expect("Failed to acquire global environment lock: another thread panicked?");
|
||||
let g_env_ptr = environment_guard.env_ptr.get_mut();
|
||||
if g_env_ptr.is_null() {
|
||||
pub fn commit(self) -> Result<()> {
|
||||
if G_ENV.get().is_none() {
|
||||
debug!("Environment not yet initialized, creating a new one");
|
||||
|
||||
let env_ptr = if let Some(global_thread_pool) = self.global_thread_pool_options {
|
||||
@@ -309,41 +178,22 @@ impl EnvironmentBuilder {
|
||||
};
|
||||
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");
|
||||
|
||||
*g_env_ptr = env_ptr;
|
||||
environment_guard.name = self.name;
|
||||
|
||||
// NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one.
|
||||
// If this 'Environment' is the only one in the process, the strong count
|
||||
// will be 2:
|
||||
// * one `Lazy` 'G_ENV'
|
||||
// * one inside the 'Environment' returned
|
||||
Ok(Environment {
|
||||
env: G_ENV.clone(),
|
||||
execution_providers: self.execution_providers
|
||||
})
|
||||
} else {
|
||||
warn!(
|
||||
name = environment_guard.name.as_str(),
|
||||
env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(),
|
||||
"Environment already initialized for this thread, reusing it",
|
||||
);
|
||||
|
||||
// NOTE: Cloning the `Lazy` 'G_ENV' will increase its strong count by one.
|
||||
// If this 'Environment' is the only one in the process, the strong count
|
||||
// will be 2:
|
||||
// * one `Lazy` 'G_ENV'
|
||||
// * one inside the 'Environment' returned
|
||||
Ok(Environment {
|
||||
env: G_ENV.clone(),
|
||||
execution_providers: self.execution_providers.clone()
|
||||
})
|
||||
let _ = G_ENV.set(EnvironmentSingleton {
|
||||
execution_providers: self.execution_providers,
|
||||
env_ptr: AtomicPtr::new(env_ptr)
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init() -> EnvironmentBuilder {
|
||||
EnvironmentBuilder::default()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{RwLock, RwLockWriteGuard};
|
||||
use std::sync::{atomic::Ordering, Arc, RwLock, RwLockWriteGuard};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use test_log::test;
|
||||
@@ -351,11 +201,11 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
fn is_env_initialized() -> bool {
|
||||
Arc::strong_count(&G_ENV) >= 2
|
||||
G_ENV.get().is_some() && !G_ENV.get().unwrap().env_ptr.load(Ordering::Relaxed).is_null()
|
||||
}
|
||||
|
||||
fn env_ptr() -> *const ort_sys::OrtEnv {
|
||||
*G_ENV.lock().unwrap().env_ptr.get_mut()
|
||||
fn env_ptr() -> Option<*mut ort_sys::OrtEnv> {
|
||||
G_ENV.get().map(|f| f.env_ptr.load(Ordering::Relaxed))
|
||||
}
|
||||
|
||||
struct ConcurrentTestRun {
|
||||
@@ -373,76 +223,14 @@ mod tests {
|
||||
let _run_lock = single_test_run();
|
||||
|
||||
assert!(!is_env_initialized());
|
||||
assert_eq!(env_ptr(), std::ptr::null_mut());
|
||||
assert_eq!(env_ptr(), None);
|
||||
|
||||
let env = Environment::builder()
|
||||
EnvironmentBuilder::default()
|
||||
.with_name("env_is_initialized")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()
|
||||
.commit()
|
||||
.unwrap();
|
||||
assert!(is_env_initialized());
|
||||
assert_ne!(env_ptr(), std::ptr::null_mut());
|
||||
|
||||
drop(env);
|
||||
assert!(!is_env_initialized());
|
||||
assert_eq!(env_ptr(), std::ptr::null_mut());
|
||||
}
|
||||
|
||||
#[ignore]
|
||||
#[test]
|
||||
fn sequential_environment_creation() {
|
||||
let _concurrent_run_lock_guard = single_test_run();
|
||||
|
||||
let mut prev_env_ptr = env_ptr();
|
||||
|
||||
for i in 0..10 {
|
||||
let name = format!("sequential_environment_creation: {}", i);
|
||||
let env = Environment::builder()
|
||||
.with_name(name.clone())
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()
|
||||
.unwrap();
|
||||
let next_env_ptr = env_ptr();
|
||||
assert_ne!(next_env_ptr, prev_env_ptr);
|
||||
prev_env_ptr = next_env_ptr;
|
||||
|
||||
assert_eq!(env.name(), name);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concurrent_environment_creations() {
|
||||
let _concurrent_run_lock_guard = single_test_run();
|
||||
|
||||
let initial_name = String::from("concurrent_environment_creation");
|
||||
let main_env = Environment::builder()
|
||||
.with_name(&initial_name)
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()
|
||||
.unwrap();
|
||||
let main_env_ptr = main_env.ptr() as usize;
|
||||
|
||||
assert_eq!(main_env.name(), initial_name);
|
||||
assert_eq!(main_env.ptr() as usize, main_env_ptr);
|
||||
|
||||
assert!(
|
||||
(0..10)
|
||||
.map(|t| {
|
||||
let initial_name_cloned = initial_name.clone();
|
||||
std::thread::spawn(move || {
|
||||
let name = format!("concurrent_environment_creation: {}", t);
|
||||
let env = Environment::builder()
|
||||
.with_name(name)
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(env.name(), initial_name_cloned);
|
||||
assert_eq!(env.ptr() as usize, main_env_ptr);
|
||||
})
|
||||
})
|
||||
.map(|child| child.join())
|
||||
.all(|r| Result::is_ok(&r))
|
||||
);
|
||||
assert_ne!(env_ptr(), None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::{convert::Infallible, io, path::PathBuf, string};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use super::{char_p_to_string, ort, tensor::TensorElementDataType};
|
||||
use super::{char_p_to_string, ortsys, tensor::TensorElementDataType};
|
||||
|
||||
/// Type alias for the Result type returned by ORT functions.
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
@@ -281,7 +281,7 @@ impl From<OrtStatusWrapper> for Result<(), ErrorInternal> {
|
||||
if status.0.is_null() {
|
||||
Ok(())
|
||||
} else {
|
||||
let raw: *const std::os::raw::c_char = unsafe { ort().GetErrorMessage.unwrap()(status.0) };
|
||||
let raw: *const std::os::raw::c_char = ortsys![unsafe GetErrorMessage(status.0)];
|
||||
match char_p_to_string(raw) {
|
||||
Ok(msg) => Err(ErrorInternal::Msg(msg)),
|
||||
Err(err) => match err {
|
||||
@@ -295,7 +295,7 @@ impl From<OrtStatusWrapper> for Result<(), ErrorInternal> {
|
||||
|
||||
impl Drop for OrtStatusWrapper {
|
||||
fn drop(&mut self) {
|
||||
unsafe { ort().ReleaseStatus.unwrap()(self.0) }
|
||||
ortsys![unsafe ReleaseStatus(self.0)];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,9 @@ pub struct OpenVINOExecutionProvider {
|
||||
enable_vpu_fast_compile: bool
|
||||
}
|
||||
|
||||
unsafe impl Send for OpenVINOExecutionProvider {}
|
||||
unsafe impl Sync for OpenVINOExecutionProvider {}
|
||||
|
||||
impl Default for OpenVINOExecutionProvider {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
||||
@@ -17,6 +17,9 @@ pub struct ROCmExecutionProvider {
|
||||
tunable_op_max_tuning_duration_ms: i32
|
||||
}
|
||||
|
||||
unsafe impl Send for ROCmExecutionProvider {}
|
||||
unsafe impl Sync for ROCmExecutionProvider {}
|
||||
|
||||
impl Default for ROCmExecutionProvider {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
||||
@@ -31,7 +31,7 @@ use std::{
|
||||
use once_cell::sync::Lazy;
|
||||
use tracing::warn;
|
||||
|
||||
pub use self::environment::{Environment, EnvironmentBuilder};
|
||||
pub use self::environment::{init, EnvironmentBuilder};
|
||||
#[cfg(feature = "fetch-models")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
|
||||
pub use self::error::FetchModelError;
|
||||
|
||||
@@ -41,10 +41,9 @@ impl<'i, const N: usize> From<[Value; N]> for SessionInputs<'i, N> {
|
||||
/// ```no_run
|
||||
/// # use std::{error::Error, sync::Arc};
|
||||
/// # use ndarray::Array1;
|
||||
/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder};
|
||||
/// # use ort::{GraphOptimizationLevel, Session};
|
||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||
/// # let environment = Environment::default().into_arc();
|
||||
/// let mut session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?;
|
||||
/// let mut session = Session::builder()?.with_model_from_file("model.onnx")?;
|
||||
/// let _ = session.run(ort::inputs![Array1::from_vec(vec![1, 2, 3, 4, 5])]?);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
@@ -55,10 +54,9 @@ impl<'i, const N: usize> From<[Value; N]> for SessionInputs<'i, N> {
|
||||
/// ```no_run
|
||||
/// # use std::{error::Error, sync::Arc};
|
||||
/// # use ndarray::Array1;
|
||||
/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder};
|
||||
/// # use ort::{GraphOptimizationLevel, Session};
|
||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||
/// # let environment = Environment::default().into_arc();
|
||||
/// let mut session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?;
|
||||
/// let mut session = Session::builder()?.with_model_from_file("model.onnx")?;
|
||||
/// let _ = session.run(ort::inputs! {
|
||||
/// "tokens" => Array1::from_vec(vec![1, 2, 3, 4, 5])
|
||||
/// }?);
|
||||
|
||||
@@ -6,11 +6,20 @@ use std::os::unix::ffi::OsStrExt;
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
#[cfg(feature = "fetch-models")]
|
||||
use std::{env, path::PathBuf, time::Duration};
|
||||
use std::{ffi::CString, fmt, marker::PhantomData, ops::Deref, os::raw::c_char, path::Path, ptr, sync::Arc};
|
||||
use std::{
|
||||
ffi::CString,
|
||||
fmt,
|
||||
marker::PhantomData,
|
||||
ops::Deref,
|
||||
os::raw::c_char,
|
||||
path::Path,
|
||||
ptr,
|
||||
sync::{atomic::Ordering, Arc}
|
||||
};
|
||||
|
||||
use super::{
|
||||
char_p_to_string,
|
||||
environment::Environment,
|
||||
environment::get_environment,
|
||||
error::{assert_non_null_pointer, assert_null_pointer, status_to_result, Error, ErrorInternal, Result},
|
||||
execution_providers::{apply_execution_providers, ExecutionProviderDispatch},
|
||||
extern_system_fn,
|
||||
@@ -28,8 +37,8 @@ pub(crate) mod input;
|
||||
pub(crate) mod output;
|
||||
pub use self::{input::SessionInputs, output::SessionOutputs};
|
||||
|
||||
/// Type used to create a session using the _builder pattern_. Once created, you can use the different methods to
|
||||
/// configure the session.
|
||||
/// Type used to create a session using the _builder pattern_. Once created with [`Session::builder`], you can use the
|
||||
/// different methods to configure the session.
|
||||
///
|
||||
/// Once configured, use the [`SessionBuilder::with_model_from_file`](crate::SessionBuilder::with_model_from_file)
|
||||
/// method to "commit" the builder configuration into a [`Session`].
|
||||
@@ -38,14 +47,9 @@ pub use self::{input::SessionInputs, output::SessionOutputs};
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use std::{error::Error, sync::Arc};
|
||||
/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder};
|
||||
/// # use ort::{GraphOptimizationLevel, Session};
|
||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||
/// let environment = Environment::builder()
|
||||
/// .with_name("test")
|
||||
/// .with_log_level(LoggingLevel::Verbose)
|
||||
/// .build()?
|
||||
/// .into_arc();
|
||||
/// let mut session = SessionBuilder::new(&environment)?
|
||||
/// let mut session = Session::builder()?
|
||||
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
/// .with_intra_threads(1)?
|
||||
/// .with_model_from_file("squeezenet.onnx")?;
|
||||
@@ -59,16 +63,12 @@ pub struct SessionBuilder {
|
||||
memory_type: MemType,
|
||||
#[cfg(feature = "custom-ops")]
|
||||
custom_runtime_handles: Vec<*mut std::os::raw::c_void>,
|
||||
execution_providers: Vec<ExecutionProviderDispatch>,
|
||||
|
||||
// env must be last to drop it after everything else
|
||||
env: Arc<Environment>
|
||||
execution_providers: Vec<ExecutionProviderDispatch>
|
||||
}
|
||||
|
||||
impl fmt::Debug for SessionBuilder {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||
f.debug_struct("SessionBuilder")
|
||||
.field("env", &self.env.name())
|
||||
.field("allocator", &self.allocator)
|
||||
.field("memory_type", &self.memory_type)
|
||||
.finish()
|
||||
@@ -81,7 +81,6 @@ impl Clone for SessionBuilder {
|
||||
status_to_result(ortsys![unsafe CloneSessionOptions(self.session_options_ptr, &mut session_options_ptr as *mut _)])
|
||||
.expect("error cloning session options");
|
||||
Self {
|
||||
env: Arc::clone(&self.env),
|
||||
session_options_ptr,
|
||||
allocator: self.allocator,
|
||||
memory_type: self.memory_type,
|
||||
@@ -108,12 +107,11 @@ impl Drop for SessionBuilder {
|
||||
|
||||
impl SessionBuilder {
|
||||
/// Creates a new session builder in the given environment.
|
||||
pub fn new(env: &Arc<Environment>) -> Result<Self> {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut session_options_ptr: *mut ort_sys::OrtSessionOptions = std::ptr::null_mut();
|
||||
ortsys![unsafe CreateSessionOptions(&mut session_options_ptr) -> Error::CreateSessionOptions; nonNull(session_options_ptr)];
|
||||
|
||||
Ok(Self {
|
||||
env: Arc::clone(env),
|
||||
session_options_ptr,
|
||||
allocator: AllocatorType::Device,
|
||||
memory_type: MemType::Default,
|
||||
@@ -384,9 +382,10 @@ impl SessionBuilder {
|
||||
.map(|b| *b as std::os::raw::c_char)
|
||||
.collect();
|
||||
|
||||
apply_execution_providers(&self, self.execution_providers.iter().chain(&self.env.execution_providers).cloned());
|
||||
let env = get_environment()?;
|
||||
apply_execution_providers(&self, self.execution_providers.iter().chain(&env.execution_providers).cloned());
|
||||
|
||||
let env_ptr: *const ort_sys::OrtEnv = self.env.ptr();
|
||||
let env_ptr = env.env_ptr.load(Ordering::Relaxed);
|
||||
|
||||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
|
||||
ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];
|
||||
@@ -404,90 +403,7 @@ impl SessionBuilder {
|
||||
.collect::<Result<Vec<Output>>>()?;
|
||||
|
||||
Ok(Session {
|
||||
inner: Arc::new(SharedSessionInner {
|
||||
env: Arc::clone(&self.env),
|
||||
session_ptr,
|
||||
allocator
|
||||
}),
|
||||
inputs,
|
||||
outputs
|
||||
})
|
||||
}
|
||||
|
||||
/// Loads an ONNX model from a file, replacing external data with data provided in initializers.
|
||||
///
|
||||
/// This will find initialized tensors with external data in the graph with the provided names and replace them with
|
||||
/// the provided tensors. The replacement will occur before any optimizations take place, and the data will be
|
||||
/// copied into the graph. Tensors replaced by this function must be using external data. (you cannot replace a
|
||||
/// non-external tensor)
|
||||
pub fn with_model_from_file_and_external_initializers<'v, 'i, P>(self, model_filepath_ref: P, initializers: &'i [(String, Value)]) -> Result<Session>
|
||||
where
|
||||
'i: 'v,
|
||||
P: AsRef<Path>
|
||||
{
|
||||
let model_filepath = model_filepath_ref.as_ref();
|
||||
if !model_filepath.exists() {
|
||||
return Err(Error::FileDoesNotExist {
|
||||
filename: model_filepath.to_path_buf()
|
||||
});
|
||||
}
|
||||
|
||||
// Build an OsString, then a vector of bytes to pass to C
|
||||
let model_path = std::ffi::OsString::from(model_filepath);
|
||||
#[cfg(target_family = "windows")]
|
||||
let model_path: Vec<u16> = model_path
|
||||
.encode_wide()
|
||||
.chain(std::iter::once(0)) // Make sure we have a null terminated string
|
||||
.collect();
|
||||
#[cfg(not(target_family = "windows"))]
|
||||
let model_path: Vec<std::os::raw::c_char> = model_path
|
||||
.as_bytes()
|
||||
.iter()
|
||||
.chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
|
||||
.map(|b| *b as std::os::raw::c_char)
|
||||
.collect();
|
||||
|
||||
apply_execution_providers(&self, self.execution_providers.iter().chain(&self.env.execution_providers).cloned());
|
||||
|
||||
let env_ptr: *const ort_sys::OrtEnv = self.env.ptr();
|
||||
|
||||
let allocator = Allocator::default();
|
||||
|
||||
let initializer_names: Vec<CString> = initializers
|
||||
.iter()
|
||||
.map(|(name, _)| CString::new(name.as_str()).unwrap())
|
||||
.map(|n| CString::new(n).unwrap())
|
||||
.collect();
|
||||
let initializer_names_ptr: Vec<*const c_char> = initializer_names.iter().map(|n| n.as_ptr() as *const c_char).collect();
|
||||
|
||||
let initializers: Vec<*const ort_sys::OrtValue> = initializers.iter().map(|input_array_ort| input_array_ort.1.ptr() as *const _).collect();
|
||||
if !initializers.is_empty() {
|
||||
assert_eq!(initializer_names.len(), initializers.len());
|
||||
ortsys![unsafe AddExternalInitializers(self.session_options_ptr, initializer_names_ptr.as_ptr(), initializers.as_ptr(), initializers.len() as _) -> Error::CreateSession];
|
||||
}
|
||||
|
||||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
|
||||
ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];
|
||||
|
||||
std::mem::drop(initializer_names);
|
||||
std::mem::drop(initializers);
|
||||
|
||||
// Extract input and output properties
|
||||
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
|
||||
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
|
||||
let inputs = (0..num_input_nodes)
|
||||
.map(|i| dangerous::extract_input(session_ptr, allocator.ptr, i))
|
||||
.collect::<Result<Vec<Input>>>()?;
|
||||
let outputs = (0..num_output_nodes)
|
||||
.map(|i| dangerous::extract_output(session_ptr, allocator.ptr, i))
|
||||
.collect::<Result<Vec<Output>>>()?;
|
||||
|
||||
Ok(Session {
|
||||
inner: Arc::new(SharedSessionInner {
|
||||
env: Arc::clone(&self.env),
|
||||
session_ptr,
|
||||
allocator
|
||||
}),
|
||||
inner: Arc::new(SharedSessionInner { session_ptr, allocator }),
|
||||
inputs,
|
||||
outputs
|
||||
})
|
||||
@@ -496,8 +412,9 @@ impl SessionBuilder {
|
||||
/// Load an ONNX graph from memory and commit the session
|
||||
/// For `.ort` models, we enable `session.use_ort_model_bytes_directly`.
|
||||
/// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array).
|
||||
/// If you want to store the model file and the [`InMemorySession`] in same struct,
|
||||
/// please check crates for creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros).
|
||||
///
|
||||
/// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that
|
||||
/// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros).
|
||||
pub fn with_model_from_memory_directly(self, model_bytes: &[u8]) -> Result<InMemorySession<'_>> {
|
||||
let str_to_char = |s: &str| {
|
||||
s.as_bytes()
|
||||
@@ -519,9 +436,10 @@ impl SessionBuilder {
|
||||
pub fn with_model_from_memory(self, model_bytes: &[u8]) -> Result<Session> {
|
||||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
|
||||
|
||||
let env_ptr: *const ort_sys::OrtEnv = self.env.ptr();
|
||||
let env = get_environment()?;
|
||||
apply_execution_providers(&self, self.execution_providers.iter().chain(&env.execution_providers).cloned());
|
||||
|
||||
apply_execution_providers(&self, self.execution_providers.iter().chain(&self.env.execution_providers).cloned());
|
||||
let env_ptr = env.env_ptr.load(Ordering::Relaxed);
|
||||
|
||||
let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
|
||||
let model_data_length = model_bytes.len();
|
||||
@@ -543,11 +461,7 @@ impl SessionBuilder {
|
||||
.collect::<Result<Vec<Output>>>()?;
|
||||
|
||||
let session = Session {
|
||||
inner: Arc::new(SharedSessionInner {
|
||||
env: Arc::clone(&self.env),
|
||||
session_ptr,
|
||||
allocator
|
||||
}),
|
||||
inner: Arc::new(SharedSessionInner { session_ptr, allocator }),
|
||||
inputs,
|
||||
outputs
|
||||
};
|
||||
@@ -560,11 +474,7 @@ impl SessionBuilder {
|
||||
#[derive(Debug)]
|
||||
pub struct SharedSessionInner {
|
||||
pub(crate) session_ptr: *mut ort_sys::OrtSession,
|
||||
allocator: Allocator,
|
||||
// hold onto an environment arc to ensure the environment also stays alive
|
||||
// env must be last to drop it after everything else
|
||||
#[allow(dead_code)]
|
||||
env: Arc<Environment>
|
||||
allocator: Allocator
|
||||
}
|
||||
|
||||
unsafe impl Send for SharedSessionInner {}
|
||||
@@ -624,6 +534,10 @@ pub struct Output {
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn builder() -> Result<SessionBuilder> {
|
||||
SessionBuilder::new()
|
||||
}
|
||||
|
||||
/// Returns this session's [`Allocator`].
|
||||
pub fn allocator(&self) -> &Allocator {
|
||||
&self.inner.allocator
|
||||
|
||||
@@ -330,9 +330,7 @@ impl Value {
|
||||
})
|
||||
}
|
||||
|
||||
/// Construct a [`Value`] from a Rust-owned [`CowArray`].
|
||||
///
|
||||
/// `allocator` is required to be `Some` when converting a String tensor. See [`crate::Session::allocator`].
|
||||
/// Construct a [`Value`] from a Rust-owned array.
|
||||
pub fn from_string_array<T: Utf8Data + Debug + Clone + 'static>(allocator: &Allocator, input: impl OrtInput<Item = T>) -> Result<Value> {
|
||||
let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemType::Default)?;
|
||||
|
||||
|
||||
@@ -1,22 +1,16 @@
|
||||
use std::path::Path;
|
||||
|
||||
use image::{imageops::FilterType, ImageBuffer, Luma, Pixel};
|
||||
use ort::{
|
||||
download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor
|
||||
};
|
||||
use ort::{download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
fn mnist_5() -> ort::Result<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mnist_5.jpg";
|
||||
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded(DomainBasedImageClassification::Mnist)
|
||||
|
||||
@@ -7,22 +7,16 @@ use std::{
|
||||
|
||||
use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb};
|
||||
use ndarray::s;
|
||||
use ort::{
|
||||
download::vision::ImageClassification, inputs, ArrayExtensions, Environment, FetchModelError, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor
|
||||
};
|
||||
use ort::{download::vision::ImageClassification, inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
fn squeezenet_mushroom() -> ort::Result<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mushroom.png";
|
||||
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded(ImageClassification::SqueezeNet)
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::path::Path;
|
||||
|
||||
use image::RgbImage;
|
||||
use ndarray::{Array, CowArray, Ix4};
|
||||
use ort::{inputs, Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder, Tensor};
|
||||
use ort::{inputs, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
|
||||
use test_log::test;
|
||||
|
||||
fn load_input_image<P: AsRef<Path>>(name: P) -> RgbImage {
|
||||
@@ -44,15 +44,11 @@ fn convert_image_to_cow_array(img: &RgbImage) -> CowArray<'_, f32, Ix4> {
|
||||
fn upsample() -> ort::Result<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mushroom.png";
|
||||
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
|
||||
|
||||
let session_data =
|
||||
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_memory(&session_data)
|
||||
@@ -89,15 +85,11 @@ fn upsample() -> ort::Result<()> {
|
||||
fn upsample_with_ort_model() -> ort::Result<()> {
|
||||
const IMAGE_TO_LOAD: &str = "mushroom.png";
|
||||
|
||||
let environment = Environment::builder()
|
||||
.with_name("integration_test")
|
||||
.with_log_level(LoggingLevel::Warning)
|
||||
.build()?
|
||||
.into_arc();
|
||||
ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
|
||||
|
||||
let session_data =
|
||||
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.ort")).expect("Could not open model from file");
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_memory_directly(&session_data) // Zero-copy.
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
use std::path::Path;
|
||||
|
||||
use ndarray::{ArrayD, IxDyn};
|
||||
use ort::{inputs, Environment, GraphOptimizationLevel, SessionBuilder, Value};
|
||||
use ort::{inputs, GraphOptimizationLevel, Session, Value};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
#[cfg(not(target_arch = "aarch64"))]
|
||||
fn vectorizer() -> ort::Result<()> {
|
||||
let environment = Environment::default().into_arc();
|
||||
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("vectorizer.onnx"))
|
||||
|
||||
Reference in New Issue
Block a user