use alloc::{rc::Rc, sync::Arc};
use core::{any::Any, ffi::c_void, ptr};
#[cfg(feature = "std")]
use std::{borrow::Cow, path::Path};

use super::SessionBuilder;
#[cfg(feature = "std")]
use crate::util::path_to_os_char;
use crate::{
	AsPointer,
	environment::{self, ThreadManager},
	error::Result,
	execution_providers::{ExecutionProviderDispatch, apply_execution_providers},
	memory::MemoryInfo,
	operator::OperatorDomain,
	ortsys,
	util::with_cstr,
	value::DynValue
};

impl SessionBuilder {
	/// Registers a list of execution providers for this session. Execution providers are registered in the order they
	/// are provided.
	///
	/// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built
	/// with support for the corresponding execution provider. Execution providers that do not have their corresponding
	/// feature enabled will emit a warning.
	///
	/// ## Notes
	///
	/// - **Indiscriminate use of [`SessionBuilder::with_execution_providers`] in a library** (e.g. always enabling
	///   `CUDAExecutionProvider`) **is discouraged** unless you allow the user to configure the execution providers by
	///   providing a `Vec` of [`ExecutionProviderDispatch`]es.
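	///
	/// ## Example
	///
	/// A minimal sketch (not compiled as a doctest), assuming the `CUDAExecutionProvider` type from this crate's
	/// `execution_providers` module and a `Session::builder()` entry point:
	/// ```ignore
	/// let session = Session::builder()?
	/// 	.with_execution_providers([CUDAExecutionProvider::default().build()])?
	/// 	.commit_from_file("model.onnx")?;
	/// ```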
	pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Result<Self> {
		apply_execution_providers(&mut self, execution_providers.as_ref(), "session options")?;
		Ok(self)
	}

	/// Configure the session to use a number of threads to parallelize the execution within nodes. If ONNX Runtime was
	/// built with OpenMP (as is the case with Microsoft's prebuilt binaries), this will have no effect on the number of
	/// threads used. Instead, you can configure the number of threads OpenMP uses via the `OMP_NUM_THREADS` environment
	/// variable.
	///
	/// For configuring the number of threads used when the session execution mode is set to `Parallel`, see
	/// [`SessionBuilder::with_inter_threads()`].
	pub fn with_intra_threads(mut self, num_threads: usize) -> Result<Self> {
		ortsys![unsafe SetIntraOpNumThreads(self.ptr_mut(), num_threads as _)?];
		Ok(self)
	}

	/// Configure the session to use a number of threads to parallelize the execution of the graph. If nodes can be run
	/// in parallel, this sets the maximum number of threads to use to run them in parallel.
	///
	/// This has no effect when the session execution mode is set to `Sequential`.
	///
	/// For configuring the number of threads used to parallelize the execution within nodes, see
	/// [`SessionBuilder::with_intra_threads()`].
	pub fn with_inter_threads(mut self, num_threads: usize) -> Result<Self> {
		ortsys![unsafe SetInterOpNumThreads(self.ptr_mut(), num_threads as _)?];
		Ok(self)
	}

	/// Enable/disable the parallel execution mode for this session. By default, this is disabled.
	///
	/// Parallel execution can improve performance for models with many branches, at the cost of higher memory usage.
	/// You can configure the number of threads used to parallelize the execution of the graph via
	/// [`SessionBuilder::with_inter_threads()`].
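	///
	/// ## Example
	///
	/// A rough sketch (not compiled as a doctest), assuming `Session::builder()` and `commit_from_file`:
	/// ```ignore
	/// let session = Session::builder()?
	/// 	.with_parallel_execution(true)?
	/// 	.with_inter_threads(4)?
	/// 	.commit_from_file("model.onnx")?;
	/// ```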
	pub fn with_parallel_execution(mut self, parallel_execution: bool) -> Result<Self> {
		let execution_mode = if parallel_execution {
			ort_sys::ExecutionMode::ORT_PARALLEL
		} else {
			ort_sys::ExecutionMode::ORT_SEQUENTIAL
		};
		ortsys![unsafe SetSessionExecutionMode(self.ptr_mut(), execution_mode)?];
		Ok(self)
	}

	/// Set the session's optimization level. See [`GraphOptimizationLevel`] for more information on the different
	/// optimization levels.
	pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> Result<Self> {
		ortsys![unsafe SetSessionGraphOptimizationLevel(self.ptr_mut(), opt_level.into())?];
		Ok(self)
	}

	/// After performing optimization (configurable with [`SessionBuilder::with_optimization_level`]), serializes the
	/// newly optimized model to the given path (for 'offline' graph optimization).
	///
	/// Note that the file will only be created after the model is committed.
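	///
	/// ## Example
	///
	/// A rough sketch of 'offline' optimization (not compiled as a doctest), assuming `Session::builder()` and
	/// `commit_from_file`; `model.optimized.onnx` is an arbitrary output path:
	/// ```ignore
	/// let session = Session::builder()?
	/// 	.with_optimization_level(GraphOptimizationLevel::Level3)?
	/// 	.with_optimized_model_path("model.optimized.onnx")?
	/// 	.commit_from_file("model.onnx")?;
	/// ```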
	#[cfg(feature = "std")]
	#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
	pub fn with_optimized_model_path<S: AsRef<Path>>(mut self, path: S) -> Result<Self> {
		let path = crate::util::path_to_os_char(path);
		ortsys![unsafe SetOptimizedModelFilePath(self.ptr_mut(), path.as_ptr())?];
		Ok(self)
	}

	/// Enables profiling. Profile information will be written to `profiling_file` after profiling completes.
	/// See [`Session::end_profiling`].
	///
	/// [`Session::end_profiling`]: crate::session::Session::end_profiling
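	///
	/// ## Example
	///
	/// A rough sketch (not compiled as a doctest), assuming `Session::builder()`, `commit_from_file`, and that
	/// [`Session::end_profiling`] returns the path of the written profile:
	/// ```ignore
	/// let mut session = Session::builder()?
	/// 	.with_profiling("ort-profile")?
	/// 	.commit_from_file("model.onnx")?;
	/// // ... run inference ...
	/// let profile_path = session.end_profiling()?;
	/// ```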
	#[cfg(feature = "std")]
	#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
	pub fn with_profiling<S: AsRef<Path>>(mut self, profiling_file: S) -> Result<Self> {
		let profiling_file = crate::util::path_to_os_char(profiling_file);
		ortsys![unsafe EnableProfiling(self.ptr_mut(), profiling_file.as_ptr())?];
		Ok(self)
	}

	/// Enables/disables memory pattern optimization. Disable it if the input size varies, i.e. with a dynamic batch size.
	pub fn with_memory_pattern(mut self, enable: bool) -> Result<Self> {
		if enable {
			ortsys![unsafe EnableMemPattern(self.ptr_mut())?];
		} else {
			ortsys![unsafe DisableMemPattern(self.ptr_mut())?];
		}
		Ok(self)
	}

	/// Set the session's allocator options from a [`MemoryInfo`].
	///
	/// If not provided, the session is created using ONNX Runtime's default device allocator.
	pub fn with_allocator(mut self, info: MemoryInfo) -> Result<Self> {
		self.memory_info = Some(Rc::new(info));
		Ok(self)
	}

	/// Registers a custom operator library at the given library path.
	#[cfg(feature = "std")]
	#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
	pub fn with_operator_library(mut self, lib_path: impl AsRef<Path>) -> Result<Self> {
		let path_cstr = path_to_os_char(lib_path);
		ortsys![unsafe RegisterCustomOpsLibrary_V2(self.ptr_mut(), path_cstr.as_ptr())?];
		Ok(self)
	}

	/// Enables [`onnxruntime-extensions`](https://github.com/microsoft/onnxruntime-extensions) custom operators.
	pub fn with_extensions(mut self) -> Result<Self> {
		ortsys![unsafe EnableOrtCustomOps(self.ptr_mut())?];
		Ok(self)
	}

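	/// Registers a custom [`OperatorDomain`] with this session, making the operators it defines available to the
	/// loaded model.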
	pub fn with_operators(mut self, domain: impl Into<Arc<OperatorDomain>>) -> Result<Self> {
		let domain = domain.into();
		ortsys![unsafe AddCustomOpDomain(self.ptr_mut(), domain.ptr().cast_mut())?];
		self.operator_domains.push(domain);
		Ok(self)
	}

	/// Enables/disables deterministic computation.
	///
	/// The default (non-deterministic) kernels will typically use faster algorithms that may introduce slight variance.
	/// Enabling deterministic compute will output reproducible results, but may come at a performance penalty.
	pub fn with_deterministic_compute(mut self, enable: bool) -> Result<Self> {
		ortsys![unsafe SetDeterministicCompute(self.ptr_mut(), enable)?];
		Ok(self)
	}

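	/// Provides the value of an external initializer (identified by `name`) directly, so it does not have to be
	/// loaded from the external data file referenced by the model.
	///
	/// A rough sketch (not compiled as a doctest), assuming `Tensor::from_array` & `into_dyn` from this crate's
	/// `value` module and a model with an external initializer named `weights`:
	/// ```ignore
	/// let weights = Tensor::from_array(([2_usize, 2], vec![1.0_f32, 2.0, 3.0, 4.0]))?.into_dyn();
	/// let session = Session::builder()?
	/// 	.with_external_initializer("weights", weights)?
	/// 	.commit_from_file("model.onnx")?;
	/// ```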
	pub fn with_external_initializer(mut self, name: impl AsRef<str>, value: DynValue) -> Result<Self> {
		let ptr = self.ptr_mut();
		let value = Rc::new(value);
		with_cstr(name.as_ref().as_bytes(), &|name| {
			ortsys![unsafe AddExternalInitializers(ptr, &name.as_ptr(), &value.ptr(), 1)?];
			Ok(())
		})?;
		self.external_initializers.push(value);
		Ok(self)
	}

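	/// Provides the contents of an external data file (identified by `file_name`) from an in-memory buffer, so it
	/// does not have to be read from disk.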
	#[cfg(feature = "std")]
	#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
	pub fn with_external_initializer_file_in_memory(mut self, file_name: impl AsRef<Path>, buffer: Cow<'static, [u8]>) -> Result<Self> {
		let file_name = path_to_os_char(file_name);
		let sizes = [buffer.len()];
		ortsys![unsafe AddExternalInitializersFromMemory(self.ptr_mut(), &file_name.as_ptr(), &buffer.as_ptr().cast::<core::ffi::c_char>().cast_mut(), sizes.as_ptr(), 1)?];
		self.external_initializer_buffers.push(buffer);
		Ok(self)
	}

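	/// Sets an identifier used to tag the log messages emitted by this session.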
	pub fn with_log_id(mut self, id: impl AsRef<str>) -> Result<Self> {
		let ptr = self.ptr_mut();
		with_cstr(id.as_ref().as_bytes(), &|id| {
			ortsys![unsafe SetSessionLogId(ptr, id.as_ptr())?];
			Ok(())
		})?;
		Ok(self)
	}

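	/// Overrides the free (symbolic) dimension with the given `name` to a fixed `size`.
	///
	/// A rough sketch (not compiled as a doctest); `batch_size` stands in for whatever symbolic dimension name the
	/// model actually uses:
	/// ```ignore
	/// let session = Session::builder()?
	/// 	.with_dimension_override("batch_size", 1)?
	/// 	.commit_from_file("model.onnx")?;
	/// ```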
	pub fn with_dimension_override(mut self, name: impl AsRef<str>, size: i64) -> Result<Self> {
		let ptr = self.ptr_mut();
		with_cstr(name.as_ref().as_bytes(), &|name| {
			ortsys![unsafe AddFreeDimensionOverrideByName(ptr, name.as_ptr(), size)?];
			Ok(())
		})?;
		Ok(self)
	}

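	/// Overrides any free dimension carrying the given `denotation` (e.g. `DATA_BATCH`) to a fixed `size`.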
	pub fn with_dimension_override_by_denotation(mut self, denotation: impl AsRef<str>, size: i64) -> Result<Self> {
		let ptr = self.ptr_mut();
		with_cstr(denotation.as_ref().as_bytes(), &|denotation| {
			ortsys![unsafe AddFreeDimensionOverride(ptr, denotation.as_ptr(), size)?];
			Ok(())
		})?;
		Ok(self)
	}

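	/// Uses the given [`PrepackedWeights`] container for this session, allowing pre-packed weights to be shared
	/// between multiple sessions created from the same model.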
	pub fn with_prepacked_weights(mut self, weights: &PrepackedWeights) -> Result<Self> {
		self.prepacked_weights = Some(weights.clone());
		Ok(self)
	}

	/// Configures this session to use its own thread pool instead of defaulting to the
	/// [`Environment`](crate::environment::Environment)'s global thread pool if one was defined.
	pub fn with_independent_thread_pool(mut self) -> Result<Self> {
		self.no_global_thread_pool = true;
		Ok(self)
	}

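	/// Uses a custom [`ThreadManager`] to create & join this session's threads instead of ONNX Runtime's default
	/// threading implementation.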
	pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
		let manager = Rc::new(manager);
		ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?];
		ortsys![unsafe SessionOptionsSetCustomCreateThreadFn(self.ptr_mut(), Some(environment::thread_create::<T>))?];
		ortsys![unsafe SessionOptionsSetCustomJoinThreadFn(self.ptr_mut(), Some(environment::thread_join::<T>))?];
		self.thread_manager = Some(manager as Rc<dyn Any>);
		Ok(self)
	}
}

/// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially
/// graph-level transformations, ranging from small graph simplifications and node eliminations to more complex node
/// fusions and layout optimizations.
///
/// Graph optimizations are divided into several categories (or levels) based on their complexity and functionality. They
/// can be performed either online or offline. In online mode, the optimizations are done before performing the
/// inference, while in offline mode, the runtime saves the optimized graph to disk (most commonly used when converting
/// an ONNX model to an ONNX Runtime model).
///
/// The optimizations belonging to one level are performed after the optimizations of the previous level have been
/// applied (e.g., extended optimizations are applied after basic optimizations have been applied).
///
/// **All optimizations (i.e. [`GraphOptimizationLevel::Level3`]) are enabled by default.**
///
/// # Online/offline mode
/// All optimizations can be performed either online or offline. In online mode, when initializing an inference session,
/// we also apply all enabled graph optimizations before performing model inference. Applying all optimizations each
/// time we initiate a session can add overhead to the model startup time (especially for complex models), which can be
/// critical in production scenarios. This is where the offline mode can bring a lot of benefit. In offline mode, after
/// performing graph optimizations, ONNX Runtime serializes the resulting model to disk. Subsequently, we can reduce
/// startup time by using the already optimized model and disabling all optimizations.
///
/// ## Notes
/// - When running in offline mode, make sure to use the exact same options (e.g., execution providers, optimization
///   level) and hardware as the target machine that the model inference will run on (e.g., you cannot run a model
///   pre-optimized for a GPU execution provider on a machine that is equipped only with a CPU).
/// - When layout optimizations are enabled, the offline mode can only be used on hardware compatible with the
///   environment in which the offline model was saved. For example, if the model has a layout optimized for AVX2, the
///   offline model would require CPUs that support AVX2.
#[derive(Debug)]
pub enum GraphOptimizationLevel {
	/// Disables all graph optimizations.
	Disable,
	/// Level 1 includes semantics-preserving graph rewrites which remove redundant nodes and redundant computation.
	/// They run before graph partitioning and thus apply to all the execution providers. Available basic/level 1 graph
	/// optimizations are as follows:
	///
	/// - Constant Folding: Statically computes parts of the graph that rely only on constant initializers. This
	///   eliminates the need to compute them during runtime.
	/// - Redundant node eliminations: Remove all redundant nodes without changing the graph structure. The following
	///   such optimizations are currently supported:
	///   * Identity Elimination
	///   * Slice Elimination
	///   * Unsqueeze Elimination
	///   * Dropout Elimination
	/// - Semantics-preserving node fusions: Fuse/fold multiple nodes into a single node. For example, Conv Add fusion
	///   folds the Add operator as the bias of the Conv operator. The following such optimizations are currently
	///   supported:
	///   * Conv Add Fusion
	///   * Conv Mul Fusion
	///   * Conv BatchNorm Fusion
	///   * Relu Clip Fusion
	///   * Reshape Fusion
	Level1,
	#[rustfmt::skip]
	/// Level 2 optimizations include complex node fusions. They are run after graph partitioning and are only applied to
	/// the nodes assigned to the CPU or CUDA execution provider. Available extended/level 2 graph optimizations are as follows:
	///
	/// | Optimization                    | EPs       | Comments                                                                        |
	/// |:------------------------------- |:--------- |:------------------------------------------------------------------------------ |
	/// | GEMM Activation Fusion          | CPU       |                                                                                 |
	/// | Matmul Add Fusion               | CPU       |                                                                                 |
	/// | Conv Activation Fusion          | CPU       |                                                                                 |
	/// | GELU Fusion                     | CPU, CUDA |                                                                                 |
	/// | Layer Normalization Fusion      | CPU, CUDA |                                                                                 |
	/// | BERT Embedding Layer Fusion     | CPU, CUDA | Fuses BERT embedding layers, layer normalization, & attention mask length      |
	/// | Attention Fusion*               | CPU, CUDA |                                                                                 |
	/// | Skip Layer Normalization Fusion | CPU, CUDA | Fuse bias of fully connected layers, skip connections, and layer normalization |
	/// | Bias GELU Fusion                | CPU, CUDA | Fuse bias of fully connected layers & GELU activation                          |
	/// | GELU Approximation*             | CUDA      | Disabled by default; enable with `OrtSessionOptions::EnableGeluApproximation`  |
	///
	/// > **NOTE**: To optimize performance of the BERT model, approximation is used in GELU Approximation and Attention
	/// > Fusion for the CUDA execution provider. The impact on accuracy is negligible based on our evaluation; F1 score
	/// > for a BERT model on SQuAD v1.1 is almost the same (87.05 vs 87.03).
	Level2,
	/// Level 3 optimizations include memory layout optimizations, which may optimize the graph to use the NCHWc memory
	/// layout rather than NCHW to improve spatial locality for some targets.
	Level3
}

impl From<GraphOptimizationLevel> for ort_sys::GraphOptimizationLevel {
	fn from(val: GraphOptimizationLevel) -> Self {
		match val {
			GraphOptimizationLevel::Disable => ort_sys::GraphOptimizationLevel::ORT_DISABLE_ALL,
			GraphOptimizationLevel::Level1 => ort_sys::GraphOptimizationLevel::ORT_ENABLE_BASIC,
			GraphOptimizationLevel::Level2 => ort_sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED,
			GraphOptimizationLevel::Level3 => ort_sys::GraphOptimizationLevel::ORT_ENABLE_ALL
		}
	}
}

#[derive(Debug)]
struct PrepackedWeightsInner(*mut ort_sys::OrtPrepackedWeightsContainer);

impl Drop for PrepackedWeightsInner {
	fn drop(&mut self) {
		ortsys![unsafe ReleasePrepackedWeightsContainer(self.0)];
	}
}

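/// A container that allows pre-packed versions of weights to be shared between multiple sessions created from the
/// same model, reducing memory usage.
///
/// A rough sketch (not compiled as a doctest), assuming `Session::builder()` and `commit_from_file`:
/// ```ignore
/// let weights = PrepackedWeights::new();
/// let session_a = Session::builder()?.with_prepacked_weights(&weights)?.commit_from_file("model.onnx")?;
/// let session_b = Session::builder()?.with_prepacked_weights(&weights)?.commit_from_file("model.onnx")?;
/// ```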
#[derive(Debug, Clone)]
pub struct PrepackedWeights {
	inner: Arc<PrepackedWeightsInner>
}

impl PrepackedWeights {
	#[allow(clippy::new_without_default)]
	pub fn new() -> Self {
		let mut ptr: *mut ort_sys::OrtPrepackedWeightsContainer = ptr::null_mut();
		ortsys![unsafe CreatePrepackedWeightsContainer(&mut ptr).expect("")];
		Self {
			inner: Arc::new(PrepackedWeightsInner(ptr))
		}
	}
}

impl AsPointer for PrepackedWeights {
	type Sys = ort_sys::OrtPrepackedWeightsContainer;

	fn ptr(&self) -> *const Self::Sys {
		self.inner.0
	}

	fn ptr_mut(&mut self) -> *mut Self::Sys {
		self.inner.0
	}
}