mirror of
https://github.com/pykeio/ort
synced 2026-04-26 00:44:56 +02:00
424 lines
15 KiB
Rust
424 lines
15 KiB
Rust
//! An [`Environment`] is a process-global structure, under which [`Session`](crate::session::Session)s are created.
|
|
//!
|
|
//! Environments can be configured via [`ort::init`](init):
|
|
//! ```
|
|
//! # use ort::execution_providers::CUDAExecutionProvider;
|
|
//! # fn main() -> ort::Result<()> {
|
|
//! ort::init()
|
|
//! .with_execution_providers([CUDAExecutionProvider::default().build()])
|
|
//! .commit()?;
|
|
//! # Ok(())
|
|
//! # }
|
|
//! ```
|
|
|
|
use std::{
|
|
any::Any,
|
|
ffi::{self, CStr, CString},
|
|
os::raw::c_void,
|
|
ptr::{self, NonNull},
|
|
sync::{Arc, RwLock}
|
|
};
|
|
|
|
use ort_sys::c_char;
|
|
use tracing::{Level, debug};
|
|
|
|
#[cfg(feature = "load-dynamic")]
|
|
use crate::G_ORT_DYLIB_PATH;
|
|
use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, ortsys};
|
|
|
|
struct EnvironmentSingleton {
|
|
lock: RwLock<Option<Arc<Environment>>>
|
|
}
|
|
|
|
unsafe impl Sync for EnvironmentSingleton {}
|
|
|
|
static G_ENV: EnvironmentSingleton = EnvironmentSingleton { lock: RwLock::new(None) };
|
|
|
|
/// An `Environment` is a process-global structure, under which [`Session`](crate::session::Session)s are created.
///
/// Environments can be used to [configure global thread pools](EnvironmentBuilder::with_global_thread_pool), in
/// which all sessions share threads from the environment's pool, and configuring [default execution
/// providers](EnvironmentBuilder::with_execution_providers) for all sessions. In the context of `ort` specifically,
/// environments are also used to configure ONNX Runtime to send log messages through the [`tracing`] crate in Rust.
///
/// For ease of use, and since sessions require an environment to be created, `ort` will automatically create an
/// environment if one is not configured via [`init`] (or [`init_from`]). [`init`] can be called at any point in the
/// program (even after an environment has been automatically created), though every session created before the
/// re-configuration would need to be re-created in order to use the config from the new environment.
#[derive(Debug)]
pub struct Environment {
	/// Execution providers that every session created in this environment registers (see
	/// [`EnvironmentBuilder::with_execution_providers`]).
	pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
	/// Owned handle to the underlying `OrtEnv`; released in this type's `Drop` impl.
	ptr: NonNull<ort_sys::OrtEnv>,
	/// Whether the environment was created with a global thread pool
	/// ([`EnvironmentBuilder::with_global_thread_pool`]).
	pub(crate) has_global_threadpool: bool,
	/// Keeps a custom [`ThreadManager`] alive, because ONNX Runtime holds a raw pointer to it.
	/// It is stored as a field rather than released in `Drop::drop`, so it is only dropped after `ReleaseEnv` has run.
	_thread_manager: Option<Box<dyn Any>>
}
|
|
|
|
// SAFETY: these impls assert that the `OrtEnv` handle may be shared and released from any thread.
// They also override the auto traits of `_thread_manager` (`Box<dyn Any>`, which is neither `Send` nor `Sync`).
// NOTE(review): `GlobalThreadPoolOptions::with_thread_manager` accepts managers without `Send`/`Sync` bounds.
// Presumably ONNX Runtime calls back into the manager (`thread_join`) during `ReleaseEnv`, which runs on whichever
// thread drops the last `Arc<Environment>`. Confirm this is sound.
unsafe impl Send for Environment {}

unsafe impl Sync for Environment {}
|
|
|
|
impl AsPointer for Environment {
|
|
type Sys = ort_sys::OrtEnv;
|
|
|
|
fn ptr(&self) -> *const Self::Sys {
|
|
self.ptr.as_ptr()
|
|
}
|
|
}
|
|
|
|
impl Drop for Environment {
|
|
fn drop(&mut self) {
|
|
debug!(ptr = ?self.ptr(), "Releasing environment");
|
|
ortsys![unsafe ReleaseEnv(self.ptr_mut())];
|
|
}
|
|
}
|
|
|
|
/// Gets a reference to the global environment, creating one if an environment has not been
|
|
/// [`commit`](EnvironmentBuilder::commit)ted yet.
|
|
pub fn get_environment() -> Result<Arc<Environment>> {
|
|
let env = G_ENV.lock.read().expect("poisoned lock");
|
|
if let Some(env) = env.as_ref() {
|
|
Ok(Arc::clone(env))
|
|
} else {
|
|
// drop our read lock so we dont deadlock when `commit` takes a write lock
|
|
drop(env);
|
|
|
|
debug!("Environment not yet initialized, creating a new one");
|
|
Ok(EnvironmentBuilder::new().commit()?)
|
|
}
|
|
}
|
|
|
|
/// Options for the global thread pool whose threads are shared by all sessions in an [`Environment`].
/// Pass them to [`EnvironmentBuilder::with_global_thread_pool`].
#[derive(Debug)]
pub struct GlobalThreadPoolOptions {
	// Owned `OrtThreadingOptions` handle; released in this type's `Drop` impl.
	ptr: *mut ort_sys::OrtThreadingOptions,
	// Custom thread manager installed via `with_thread_manager`. ONNX Runtime is given a raw pointer to it,
	// so the box must outlive any use of that pointer.
	thread_manager: Option<Box<dyn Any>>
}
|
|
|
|
impl Default for GlobalThreadPoolOptions {
|
|
fn default() -> Self {
|
|
let mut ptr = ptr::null_mut();
|
|
ortsys![unsafe CreateThreadingOptions(&mut ptr)];
|
|
Self { ptr, thread_manager: None }
|
|
}
|
|
}
|
|
|
|
impl GlobalThreadPoolOptions {
|
|
pub fn with_inter_threads(mut self, num_threads: usize) -> Result<Self> {
|
|
ortsys![unsafe SetGlobalInterOpNumThreads(self.ptr_mut(), num_threads as _)?];
|
|
Ok(self)
|
|
}
|
|
|
|
pub fn with_intra_threads(mut self, num_threads: usize) -> Result<Self> {
|
|
ortsys![unsafe SetGlobalIntraOpNumThreads(self.ptr_mut(), num_threads as _)?];
|
|
Ok(self)
|
|
}
|
|
|
|
pub fn with_spin_control(mut self, spin_control: bool) -> Result<Self> {
|
|
ortsys![unsafe SetGlobalSpinControl(self.ptr_mut(), if spin_control { 1 } else { 0 })?];
|
|
Ok(self)
|
|
}
|
|
|
|
pub fn with_intra_affinity(mut self, affinity: impl AsRef<str>) -> Result<Self> {
|
|
let affinity = CString::new(affinity.as_ref())?;
|
|
ortsys![unsafe SetGlobalIntraOpThreadAffinity(self.ptr_mut(), affinity.as_ptr())?];
|
|
Ok(self)
|
|
}
|
|
|
|
pub fn with_flush_to_zero(mut self) -> Result<Self> {
|
|
ortsys![unsafe SetGlobalDenormalAsZero(self.ptr_mut())?];
|
|
Ok(self)
|
|
}
|
|
|
|
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
|
|
let mut manager = Box::new(manager);
|
|
ortsys![unsafe SetGlobalCustomThreadCreationOptions(self.ptr_mut(), (&mut *manager as *mut T).cast())?];
|
|
ortsys![unsafe SetGlobalCustomCreateThreadFn(self.ptr_mut(), Some(thread_create::<T>))?];
|
|
ortsys![unsafe SetGlobalCustomJoinThreadFn(self.ptr_mut(), Some(thread_join::<T>))?];
|
|
self.thread_manager = Some(manager as Box<dyn Any>);
|
|
Ok(self)
|
|
}
|
|
}
|
|
|
|
impl AsPointer for GlobalThreadPoolOptions {
|
|
type Sys = ort_sys::OrtThreadingOptions;
|
|
|
|
fn ptr(&self) -> *const Self::Sys {
|
|
self.ptr
|
|
}
|
|
}
|
|
|
|
impl Drop for GlobalThreadPoolOptions {
|
|
fn drop(&mut self) {
|
|
ortsys![unsafe ReleaseThreadingOptions(self.ptr)];
|
|
}
|
|
}
|
|
|
|
/// Work that ONNX Runtime wants run on a newly created thread, handed to [`ThreadManager::create`].
///
/// Running it means calling [`ThreadWorker::work`], which invokes ONNX Runtime's worker function with its argument.
pub struct ThreadWorker {
	// Opaque argument that ONNX Runtime passes back to `worker`.
	data: *mut c_void,
	// Worker entry point supplied by ONNX Runtime.
	worker: ort_sys::OrtThreadWorkerFn
}

// SAFETY: NOTE(review): `data` is an opaque pointer owned by ONNX Runtime. This impl assumes ONNX Runtime permits
// the worker to run on a thread other than the one that received it. Confirm.
unsafe impl Send for ThreadWorker {}
|
|
|
|
impl ThreadWorker {
|
|
pub fn work(self) {
|
|
unsafe { self.worker.unwrap_unchecked()(self.data) }
|
|
}
|
|
}
|
|
|
|
/// A user-provided strategy for creating and joining the threads of ONNX Runtime's global thread pool.
///
/// Install it with [`GlobalThreadPoolOptions::with_thread_manager`].
pub trait ThreadManager {
	/// Handle to a created thread, later passed to [`ThreadManager::join`].
	type Thread;

	/// Creates a thread for `worker`; the worker is executed by calling [`ThreadWorker::work`].
	///
	/// An error (or a panic) is logged and reported to ONNX Runtime as a null thread handle.
	fn create(&mut self, worker: ThreadWorker) -> crate::Result<Self::Thread>;

	/// Joins a thread previously returned by [`ThreadManager::create`]. An error is logged.
	fn join(thread: Self::Thread) -> crate::Result<()>;
}
|
|
|
|
pub(crate) unsafe extern "system" fn thread_create<T: ThreadManager + Any>(
|
|
ort_custom_thread_creation_options: *mut c_void,
|
|
ort_thread_worker_fn: ort_sys::OrtThreadWorkerFn,
|
|
ort_worker_fn_param: *mut c_void
|
|
) -> ort_sys::OrtCustomThreadHandle {
|
|
let thread_worker = ThreadWorker {
|
|
data: ort_worker_fn_param,
|
|
worker: ort_thread_worker_fn
|
|
};
|
|
|
|
let res = std::panic::catch_unwind(|| {
|
|
let manager = unsafe { &mut *ort_custom_thread_creation_options.cast::<T>() };
|
|
<T as ThreadManager>::create(manager, thread_worker)
|
|
});
|
|
match res {
|
|
Ok(Ok(thread)) => (Box::leak(Box::new(thread)) as *mut <T as ThreadManager>::Thread)
|
|
.cast_const()
|
|
.cast::<ort_sys::OrtCustomHandleType>(),
|
|
Ok(Err(e)) => {
|
|
tracing::error!("Failed to create thread using manager: {e}");
|
|
ptr::null()
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Thread manager panicked: {e:?}");
|
|
ptr::null()
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(crate) unsafe extern "system" fn thread_join<T: ThreadManager + Any>(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) {
|
|
let handle = Box::from_raw(ort_custom_thread_handle.cast_mut().cast::<<T as ThreadManager>::Thread>());
|
|
if let Err(e) = <T as ThreadManager>::join(*handle) {
|
|
tracing::error!("Failed to join thread using manager: {e}");
|
|
}
|
|
}
|
|
|
|
/// Struct used to build an [`Environment`]; see [`crate::init`].
pub struct EnvironmentBuilder {
	// Name passed to ONNX Runtime when the environment is created; set via `with_name`.
	name: String,
	// Whether telemetry events are enabled (`true`) or disabled on commit; set via `with_telemetry`.
	telemetry: bool,
	// Execution providers stored on the committed `Environment`; set via `with_execution_providers`.
	execution_providers: Vec<ExecutionProviderDispatch>,
	// When `Some`, the environment is created with a global thread pool configured by these options.
	global_thread_pool_options: Option<GlobalThreadPoolOptions>
}
|
|
|
|
impl EnvironmentBuilder {
|
|
pub(crate) fn new() -> Self {
|
|
EnvironmentBuilder {
|
|
name: "default".to_string(),
|
|
telemetry: true,
|
|
execution_providers: vec![],
|
|
global_thread_pool_options: None
|
|
}
|
|
}
|
|
|
|
/// Configure the environment with a given name for logging purposes.
|
|
#[must_use = "commit() must be called in order for the environment to take effect"]
|
|
pub fn with_name<S>(mut self, name: S) -> Self
|
|
where
|
|
S: Into<String>
|
|
{
|
|
self.name = name.into();
|
|
self
|
|
}
|
|
|
|
/// Enable or disable sending telemetry events to Microsoft.
|
|
///
|
|
/// Typically, only Windows builds of ONNX Runtime provided by Microsoft will have telemetry enabled.
|
|
/// Pre-built binaries provided by pyke, or binaries compiled from source, won't have telemetry enabled.
|
|
///
|
|
/// The exact kind of telemetry data sent can be found [here](https://github.com/microsoft/onnxruntime/blob/v1.20.1/onnxruntime/core/platform/windows/telemetry.cc).
|
|
/// Currently, this includes (but is not limited to): ONNX graph version, model producer name & version, whether or
|
|
/// not FP16 is used, operator domains & versions, model graph name & custom metadata, execution provider names,
|
|
/// error messages, and the total number & time of session inference runs. The ONNX Runtime team uses this data to
|
|
/// better understand how customers use ONNX Runtime and where performance can be improved.
|
|
#[must_use = "commit() must be called in order for the environment to take effect"]
|
|
pub fn with_telemetry(mut self, enable: bool) -> Self {
|
|
self.telemetry = enable;
|
|
self
|
|
}
|
|
|
|
/// Sets a list of execution providers which all sessions created in this environment will register.
|
|
///
|
|
/// If a session is created in this environment with [`SessionBuilder::with_execution_providers`], those EPs
|
|
/// will take precedence over the environment's EPs.
|
|
///
|
|
/// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built
|
|
/// with support for the corresponding execution provider. Execution providers that do not have their corresponding
|
|
/// feature enabled will emit a warning.
|
|
///
|
|
/// [`SessionBuilder::with_execution_providers`]: crate::session::builder::SessionBuilder::with_execution_providers
|
|
#[must_use = "commit() must be called in order for the environment to take effect"]
|
|
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Self {
|
|
self.execution_providers = execution_providers.as_ref().to_vec();
|
|
self
|
|
}
|
|
|
|
/// Enables the global thread pool for this environment.
|
|
#[must_use = "commit() must be called in order for the environment to take effect"]
|
|
pub fn with_global_thread_pool(mut self, options: GlobalThreadPoolOptions) -> Self {
|
|
self.global_thread_pool_options = Some(options);
|
|
self
|
|
}
|
|
|
|
/// Commit the environment configuration and set the global environment.
|
|
pub fn commit(self) -> Result<Arc<Environment>> {
|
|
let (env_ptr, thread_manager, has_global_threadpool) = if let Some(mut thread_pool_options) = self.global_thread_pool_options {
|
|
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
|
|
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
|
|
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
|
|
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());
|
|
|
|
ortsys![
|
|
unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
|
|
logging_function,
|
|
logger_param,
|
|
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
|
|
cname.as_ptr(),
|
|
thread_pool_options.ptr(),
|
|
&mut env_ptr
|
|
)?;
|
|
nonNull(env_ptr)
|
|
];
|
|
|
|
let thread_manager = thread_pool_options.thread_manager.take();
|
|
(env_ptr, thread_manager, true)
|
|
} else {
|
|
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
|
|
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
|
|
// FIXME: What should go here?
|
|
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
|
|
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());
|
|
ortsys![
|
|
unsafe CreateEnvWithCustomLogger(
|
|
logging_function,
|
|
logger_param,
|
|
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
|
|
cname.as_ptr(),
|
|
&mut env_ptr
|
|
)?;
|
|
nonNull(env_ptr)
|
|
];
|
|
(env_ptr, None, false)
|
|
};
|
|
debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");
|
|
|
|
if self.telemetry {
|
|
ortsys![unsafe EnableTelemetryEvents(env_ptr)?];
|
|
} else {
|
|
ortsys![unsafe DisableTelemetryEvents(env_ptr)?];
|
|
}
|
|
|
|
let mut env_lock = G_ENV.lock.write().expect("poisoned lock");
|
|
// drop global reference to previous environment
|
|
if let Some(env_arc) = env_lock.take() {
|
|
drop(env_arc);
|
|
}
|
|
let env = Arc::new(Environment {
|
|
execution_providers: self.execution_providers,
|
|
// we already asserted the env pointer is non-null in the `CreateEnvWithCustomLogger` call
|
|
ptr: unsafe { NonNull::new_unchecked(env_ptr) },
|
|
has_global_threadpool,
|
|
_thread_manager: thread_manager
|
|
});
|
|
env_lock.replace(Arc::clone(&env));
|
|
|
|
Ok(env)
|
|
}
|
|
}
|
|
|
|
/// Creates an ONNX Runtime environment.
///
/// Returns an [`EnvironmentBuilder`] with default settings. Nothing takes effect until `.commit()` is called on it.
///
/// ```
/// # use ort::execution_providers::CUDAExecutionProvider;
/// # fn main() -> ort::Result<()> {
/// ort::init()
/// 	.with_execution_providers([CUDAExecutionProvider::default().build()])
/// 	.commit()?;
/// # Ok(())
/// # }
/// ```
///
/// # Notes
/// - It is not required to call this function. If this is not called by the time any other `ort` APIs are used, a
///   default environment will be created.
/// - **Library crates that use `ort` shouldn't create their own environment.** Let downstream applications create it.
/// - In order for environment settings to apply, this must be called **before** you use other APIs like [`Session`],
///   and you *must* call `.commit()` on the builder returned by this function.
///
/// [`Session`]: crate::session::Session
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn init() -> EnvironmentBuilder {
	EnvironmentBuilder::new()
}
|
|
|
|
/// Creates an ONNX Runtime environment, dynamically loading ONNX Runtime from the library file (`.dll`/`.so`/`.dylib`)
/// specified by `path`.
///
/// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded.
/// The library path can only be set once. If it has already been set, `path` is ignored and a warning is logged.
///
/// ```no_run
/// # use ort::execution_providers::CUDAExecutionProvider;
/// # fn main() -> ort::Result<()> {
/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib");
/// ort::init_from(lib_path.join("onnxruntime.dll"))
/// 	.with_execution_providers([CUDAExecutionProvider::default().build()])
/// 	.commit()?;
/// # Ok(())
/// # }
/// ```
///
/// # Notes
/// - In order for environment settings to apply, this must be called **before** you use other APIs like [`Session`],
///   and you *must* call `.commit()` on the builder returned by this function.
///
/// [`Session`]: crate::session::Session
#[cfg(feature = "load-dynamic")]
#[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
	let path = Arc::new(path.to_string());
	// `set` fails if the path was already set (e.g. by an earlier `init_from` call). In that case the requested
	// library will not be used, so warn instead of discarding the failure silently.
	if G_ORT_DYLIB_PATH.set(Arc::clone(&path)).is_err() {
		tracing::warn!(path = path.as_str(), "ONNX Runtime library path was already set; ignoring path passed to `init_from`");
	}
	EnvironmentBuilder::new()
}
|
|
|
|
/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
|
|
pub(crate) extern "system" fn custom_logger(
|
|
_params: *mut ffi::c_void,
|
|
severity: ort_sys::OrtLoggingLevel,
|
|
_: *const c_char,
|
|
id: *const c_char,
|
|
code_location: *const c_char,
|
|
message: *const c_char
|
|
) {
|
|
assert_ne!(code_location, ptr::null());
|
|
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("<decode error>");
|
|
assert_ne!(message, ptr::null());
|
|
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<decode error>");
|
|
assert_ne!(id, ptr::null());
|
|
let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or("<decode error>");
|
|
|
|
let span = tracing::span!(Level::TRACE, "ort", id = id, location = code_location);
|
|
|
|
match severity {
|
|
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, Level::TRACE, "{message}"),
|
|
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, Level::DEBUG, "{message}"),
|
|
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, Level::INFO, "{message}"),
|
|
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, Level::WARN, "{message}"),
|
|
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => tracing::event!(parent: &span, Level::ERROR, "{message}")
|
|
}
|
|
}
|