feat: recover from SessionBuilder errors

This commit is contained in:
Carson M.
2026-02-28 00:33:24 -06:00
parent 831422c9d4
commit fb29790415
19 changed files with 520 additions and 320 deletions

View File

@@ -22,7 +22,7 @@
//! # use std::path::PathBuf;
//! # use ort::{compiler::ModelCompiler, session::Session, ep};
//! # fn main() -> ort::Result<()> {
//! let session_options = Session::builder()?.with_execution_providers([ep::CoreML::default()
//! let mut session_options = Session::builder()?.with_execution_providers([ep::CoreML::default()
//! .with_model_format(ep::coreml::ModelFormat::MLProgram)
//! .build()])?;
//!

View File

@@ -263,7 +263,7 @@ impl Model {
Ok(())
}
pub fn into_session(self, mut options: SessionBuilder) -> Result<Session> {
pub fn into_session(self, options: &SessionBuilder) -> Result<Session> {
let mut session_ptr = ptr::null_mut();
ortsys![@editor:
unsafe CreateSessionFromModel(

View File

@@ -32,7 +32,7 @@ fn test_identity_graph() -> Result<()> {
let mut model = Model::new([Opset::new(ONNX_DOMAIN, 22)?])?;
model.add_graph(graph)?;
let mut session = model.into_session(SessionBuilder::new()?)?;
let mut session = model.into_session(&SessionBuilder::new()?)?;
let output = session
.run(inputs![Tensor::<f32>::from_array((Shape::new([5]), vec![1.0f32; 5]))?])?
.remove("output")
@@ -76,7 +76,7 @@ fn test_mul_graph() -> Result<()> {
let mut model = Model::new([Opset::new(ONNX_DOMAIN, 22)?])?;
model.add_graph(graph)?;
let mut session = model.into_session(SessionBuilder::new()?)?;
let mut session = model.into_session(&SessionBuilder::new()?)?;
let output = session
.run(inputs![Tensor::<f32>::from_array((Shape::new([5]), vec![2.0f32; 5]))?])?
.remove("output")

View File

@@ -101,9 +101,9 @@ impl Environment {
/// Sets the global log level.
///
/// ```
/// # use ort::{environment::Environment, logging::LogLevel};
/// # fn main() -> ort::Result<()> {
/// # use ort::logging::LogLevel;
/// let env = ort::environment::get_environment()?;
/// let env = Environment::current()?;
///
/// env.set_log_level(LogLevel::Warning);
/// # Ok(())
@@ -130,8 +130,9 @@ impl Environment {
/// no longer be needed.
///
/// ```
/// # use ort::environment::Environment;
/// # fn main() -> ort::Result<()> {
/// let env = ort::environment::get_environment()?;
/// let env = Environment::current()?;
///
/// let _ = env.register_ep_library("CUDA", "/path/to/onnxruntime_providers_cuda.dll");
/// # Ok(())

View File

@@ -45,7 +45,7 @@ impl ExecutionProvider for ACL {
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, enable_fast_math: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
return Ok(unsafe {
crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.ptr_mut(), self.fast_math.into()))
crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.ptr_mut(), self.fast_math.into()))
}?);
}

View File

@@ -47,7 +47,7 @@ impl ExecutionProvider for ArmNN {
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
return Ok(unsafe {
crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.ptr_mut(), self.use_arena.into()))
crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.ptr_mut(), self.use_arena.into()))
}?);
}

View File

@@ -102,11 +102,11 @@ impl ExecutionProvider for DirectML {
fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
#[cfg(any(feature = "load-dynamic", feature = "directml"))]
{
use crate::AsPointer;
use crate::{AsPointer, Error};
let api = get_dml_api()?;
if let Some(device_id) = self.device_id {
unsafe { crate::error::status_to_result((api.SessionOptionsAppendExecutionProvider_DML)(session_builder.ptr_mut(), device_id as _)) }?;
unsafe { Error::result_from_status((api.SessionOptionsAppendExecutionProvider_DML)(session_builder.ptr_mut(), device_id as _)) }?;
} else {
let device_options = ort_sys::OrtDmlDeviceOptions {
Filter: match self.device_filter {
@@ -120,7 +120,7 @@ impl ExecutionProvider for DirectML {
PerformancePreference::MinimumPower => ort_sys::OrtDmlPerformancePreference::MinimumPower
}
};
unsafe { crate::error::status_to_result((api.SessionOptionsAppendExecutionProvider_DML2)(session_builder.ptr_mut(), &device_options)) }?;
unsafe { Error::result_from_status((api.SessionOptionsAppendExecutionProvider_DML2)(session_builder.ptr_mut(), &device_options)) }?;
}
return Ok(());

View File

@@ -78,7 +78,7 @@ impl ExecutionProvider for NNAPI {
if self.cpu_only {
flags |= 0x008;
}
return Ok(unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_builder.ptr_mut(), flags)) }?);
return Ok(unsafe { crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_builder.ptr_mut(), flags)) }?);
}
Err(RegisterError::MissingFeature)

View File

@@ -22,7 +22,7 @@ impl ExecutionProvider for RKNPU {
use crate::AsPointer;
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_RKNPU(options: *mut ort_sys::OrtSessionOptions) -> ort_sys::OrtStatusPtr);
return Ok(unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_RKNPU(session_builder.ptr_mut())) }?);
return Ok(unsafe { crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_RKNPU(session_builder.ptr_mut())) }?);
}
Err(RegisterError::MissingFeature)

View File

@@ -100,7 +100,7 @@ impl ExecutionProvider for TVM {
}
let options_string = alloc::ffi::CString::new(option_string.join(",")).expect("invalid option string");
return Ok(unsafe {
crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_Tvm(session_builder.ptr_mut(), options_string.as_ptr()))
crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_Tvm(session_builder.ptr_mut(), options_string.as_ptr()))
}?);
}

View File

@@ -3,11 +3,17 @@ use alloc::{
format,
string::{String, ToString}
};
use core::{convert::Infallible, error::Error as CoreError, ffi::c_char, fmt, ptr};
use core::{
convert::Infallible,
error::Error as CoreError,
ffi::c_char,
fmt,
ptr::{self, NonNull}
};
use crate::{
ortsys,
util::{char_p_to_string, with_cstr}
util::{char_p_to_string, cold, with_cstr}
};
/// Type alias for the `Result` type returned by `ort` functions.
@@ -27,61 +33,164 @@ impl<T> IntoStatus for Result<T, Error> {
}
}
/// An error returned by any `ort` API.
#[derive(Debug)]
pub struct Error {
struct ErrorInternal {
code: ErrorCode,
msg: String
message: String,
cause: Option<Box<dyn CoreError + Send + Sync + 'static>>,
status_ptr: NonNull<ort_sys::OrtStatus>
}
impl Error {
/// Wrap a custom, user-provided error in an [`ort::Error`](Error)..
// NOTE(review): `ErrorInternal` stores a raw `NonNull<ort_sys::OrtStatus>`, so `Send`/`Sync` are not implemented
// automatically and are asserted by hand here. This presumably relies on the status object being safe to read and
// release from any thread - TODO confirm against the ONNX Runtime API contract.
unsafe impl Send for ErrorInternal {}
unsafe impl Sync for ErrorInternal {}
impl ErrorInternal {
    /// Builds an `ErrorInternal` from a non-null status pointer.
    ///
    /// The error code and message are read from the status via `GetErrorCode` and `GetErrorMessage`; the pointer
    /// itself is stored in `status_ptr`. If the message cannot be converted to a `String`, a placeholder message
    /// describing the conversion failure is stored instead.
    ///
    /// # Safety
    /// `ptr` is passed directly to `GetErrorCode` and `GetErrorMessage`, so it must be valid for those calls.
    #[cold]
    pub(crate) unsafe fn from_ptr(ptr: NonNull<ort_sys::OrtStatus>) -> Self {
        let status = ptr.as_ptr();
        let code = ErrorCode::from(ortsys![unsafe GetErrorCode(status)]);
        let raw: *const c_char = ortsys![unsafe GetErrorMessage(status)];
        // Both outcomes build the same struct; only the message text differs.
        let message = char_p_to_string(raw).unwrap_or_else(|err| format!("(failed to convert UTF-8: {err})"));
        Self {
            code,
            message,
            cause: None,
            status_ptr: ptr
        }
    }
}
/// Releases the stored status when the error is dropped.
impl Drop for ErrorInternal {
fn drop(&mut self) {
// Hand the status pointer held in `status_ptr` back to `ReleaseStatus`.
ortsys![unsafe ReleaseStatus(self.status_ptr.as_ptr())];
}
}
/// An error returned by any `ort` API.
pub struct Error<R = ()> {
recover: R,
inner: Box<ErrorInternal>
}
/// Debug formatting for [`Error`].
///
/// Prints the error code, the message, and the raw status pointer. The `recover` value is deliberately not
/// printed, so this impl places no `Debug` bound on `R`.
impl<R> fmt::Debug for Error<R> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Error")
.field("code", &self.inner.code)
.field("message", &self.message())
.field("ptr", &self.inner.status_ptr.as_ptr())
.finish()
}
}
impl Error<()> {
/// Converts an [`ort_sys::OrtStatusPtr`] to a [`Result`].
///
/// This takes ownership of the status pointer.
///
/// # Safety
/// `ptr` must be a valid `OrtStatusPtr` returned from an `ort-sys` API.
#[inline]
pub unsafe fn result_from_status(ptr: ort_sys::OrtStatusPtr) -> Result<(), Self> {
match NonNull::new(ptr.0) {
None => Ok(()),
Some(ptr) => {
cold();
Err(Self {
recover: (),
inner: Box::new(unsafe { ErrorInternal::from_ptr(ptr) })
})
}
}
}
/// Wrap a custom, user-provided error in an [`ort::Error`](Error).
///
/// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort`
/// related operation fails.
pub fn wrap<T: CoreError + Send + Sync + 'static>(err: T) -> Self {
Error {
code: ErrorCode::GenericFailure,
msg: err.to_string()
}
Self::new_internal(ErrorCode::GenericFailure, err.to_string(), Some(Box::new(err)))
}
/// Creates a custom [`Error`] with the given message.
pub fn new(msg: impl Into<String>) -> Self {
Error {
code: ErrorCode::GenericFailure,
msg: msg.into()
}
Self::new_internal(ErrorCode::GenericFailure, msg, None)
}
/// Creates a custom [`Error`] with the given [`ErrorCode`] and message.
pub fn new_with_code(code: ErrorCode, msg: impl Into<String>) -> Self {
Error { code, msg: msg.into() }
Self::new_internal(code, msg, None)
}
fn new_internal(code: ErrorCode, message: impl Into<String>, cause: Option<Box<dyn CoreError + Send + Sync + 'static>>) -> Self {
let message = message.into();
let ptr = with_cstr(message.as_bytes(), &|message| Ok(ortsys![unsafe CreateStatus(code.into(), message.as_ptr())])).expect("invalid error message");
Self {
recover: (),
inner: Box::new(ErrorInternal {
code,
message,
cause,
status_ptr: unsafe { NonNull::new_unchecked(ptr.0) }
})
}
}
/// Attaches a recovery value to this error.
///
/// The inner error (code, message, cause and status pointer) is moved into the returned `Error<R>` unchanged;
/// only the `recover` field is replaced with the given value.
pub(crate) fn with_recover<R>(self, recover: R) -> Error<R> {
Error { recover, inner: self.inner }
}
}
impl<R> Error<R> {
pub fn code(&self) -> ErrorCode {
self.code
self.inner.code
}
pub fn message(&self) -> &str {
self.msg.as_str()
self.inner.message.as_str()
}
}
impl fmt::Display for Error {
impl<R: Sized> Error<R> {
/// Recovers from this error.
///
/// ```
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?
/// .with_optimization_level(GraphOptimizationLevel::All)
/// // Optimization isn't enabled in minimal builds of ONNX Runtime, so throws an error. We can just ignore it.
/// .unwrap_or_else(|e| e.recover())
/// .commit_from_file("tests/data/upsample.onnx")?;
/// # Ok(())
/// # }
/// ```
#[inline]
pub fn recover(self) -> R {
self.recover
}
}
impl<R> fmt::Display for Error<R> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.msg)
f.write_str(&self.inner.message)
}
}
impl CoreError for Error {}
impl<R> CoreError for Error<R> {
    /// Returns the wrapped cause of this error, if one was recorded.
    fn source(&self) -> Option<&(dyn CoreError + 'static)> {
        // Borrow the boxed cause and drop the `Send + Sync` markers from the trait object type.
        self.inner.cause.as_deref().map(|cause| cause as &(dyn CoreError + 'static))
    }
}
impl From<Box<dyn CoreError + Send + Sync + 'static>> for Error {
fn from(err: Box<dyn CoreError + Send + Sync + 'static>) -> Self {
Error {
code: ErrorCode::GenericFailure,
msg: err.to_string()
}
Error::new_internal(ErrorCode::GenericFailure, err.to_string(), Some(err))
}
}
@@ -121,6 +230,12 @@ impl From<alloc::ffi::IntoStringError> for Error {
}
}
/// Converts a recoverable session builder error into a plain [`Error`].
///
/// Only the inner error is carried over; the `SessionBuilder` held in `recover` is not kept.
impl From<Error<crate::session::builder::SessionBuilder>> for Error<()> {
fn from(err: Error<crate::session::builder::SessionBuilder>) -> Self {
Self { recover: (), inner: err.inner }
}
}
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[non_exhaustive]
pub enum ErrorCode {
@@ -177,39 +292,3 @@ impl From<ErrorCode> for ort_sys::OrtErrorCode {
}
}
}
/// Converts an [`ort_sys::OrtStatusPtr`] to a [`Result`].
///
/// **Note that this frees `status`!**
///
/// # Safety
/// The value contained in `status` must be a valid [`ort_sys::OrtStatus`] pointer, or a null pointer (in which case the
/// result will be `Ok`).
#[inline]
pub unsafe fn status_to_result(status: ort_sys::OrtStatusPtr) -> Result<(), Error> {
let status = status.0;
if status.is_null() {
Ok(())
} else {
#[cold]
fn status_to_error(status: *mut ort_sys::OrtStatus) -> Error {
let code = ErrorCode::from(ortsys![unsafe GetErrorCode(status)]);
let raw: *const c_char = ortsys![unsafe GetErrorMessage(status)];
match char_p_to_string(raw) {
Ok(msg) => {
ortsys![unsafe ReleaseStatus(status)];
Error { code, msg }
}
Err(err) => {
ortsys![unsafe ReleaseStatus(status)];
Error {
code,
msg: format!("(failed to convert UTF-8: {err})")
}
}
}
}
Err(status_to_error(status))
}
}

View File

@@ -288,19 +288,19 @@ macro_rules! ortsys {
unsafe { ($crate::api().$method)($($n),*) }
};
(@ort: unsafe $method:ident($($n:expr),*) as Result) => {
unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }
unsafe { $crate::error::Error::result_from_status(($crate::api().$method)($($n),+)) }
};
(@$api:ident: unsafe $method:ident($($n:expr),*)) => {
unsafe { ($crate::api::$api().unwrap().$method)($($n),+) }
};
(@$api:ident: unsafe $method:ident($($n:expr),*)?) => {
$crate::api::$api().and_then(|api| unsafe { $crate::error::status_to_result((api.$method)($($n),+)) })?
$crate::api::$api().and_then(|api| unsafe { $crate::error::Error::result_from_status((api.$method)($($n),+)) })?
};
(@$api:ident: unsafe $method:ident($($n:expr),*)?; nonNull($($check:ident),+)$(;)?) => {
$crate::api::$api().and_then(|api| unsafe { $crate::error::status_to_result((api.$method)($($n),+)) })?;
$crate::api::$api().and_then(|api| unsafe { $crate::error::Error::result_from_status((api.$method)($($n),+)) })?;
ortsys![@nonNull?; $($check),+];
};
(@$api:ident: unsafe $method:ident($($n:expr),*) as Result) => {
$crate::api::$api().and_then(|api| unsafe { $crate::error::status_to_result((api.$method)($($n),+)) })
$crate::api::$api().and_then(|api| unsafe { $crate::error::Error::result_from_status((api.$method)($($n),+)) })
};
}

View File

@@ -13,6 +13,7 @@ use std::sync::Mutex;
use smallvec::SmallVec;
use crate::{
Error,
error::Result,
session::{SessionOutputs, SharedSessionInner, UntypedRunOptions},
util::{STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS},
@@ -121,7 +122,7 @@ pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mu
crate::logging::drop!(AsyncInferenceContext, user_data);
if let Err(e) = unsafe { crate::error::status_to_result(status) } {
if let Err(e) = unsafe { Error::result_from_status(status) } {
ctx.inner.emplace_value(Err(e));
ctx.inner.wake();
return;

View File

@@ -35,90 +35,10 @@ impl SessionBuilder {
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "fetch-models", feature = "std"))))]
pub fn commit_from_url(&mut self, model_url: impl AsRef<str>) -> Result<Session> {
let downloaded_path = SessionBuilder::download(model_url.as_ref())?;
let downloaded_path = download_model(model_url.as_ref())?;
self.commit_from_file(downloaded_path)
}
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
fn download(url: &str) -> Result<PathBuf> {
use ureq::{
config::Config,
tls::{RootCerts, TlsConfig, TlsProvider}
};
let mut download_dir = ort_sys::internal::dirs::cache_dir()
.expect("could not determine cache directory")
.join("models");
if std::fs::create_dir_all(&download_dir).is_err() {
download_dir = std::env::current_dir().expect("Failed to obtain current working directory");
}
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
let _ = write!(&mut s, "{:02x}", b);
s
});
let model_filepath = download_dir.join(&model_filename);
if model_filepath.exists() {
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
Ok(model_filepath)
} else {
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");
let agent = Config::builder()
.tls_config(
TlsConfig::builder()
.root_certs(RootCerts::WebPki)
.provider(if cfg!(any(feature = "tls-rustls", feature = "tls-rustls-no-provider")) {
TlsProvider::Rustls
} else if cfg!(any(feature = "tls-native", feature = "tls-native-vendored")) {
TlsProvider::NativeTls
} else {
return Err(Error::new(
"No TLS provider configured. When using `fetch-models` with HTTPS URLs, a `tls-*` feature must be enabled."
));
})
.build()
)
.build()
.new_agent();
let resp = agent.get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?;
let len = resp
.headers()
.get("Content-Length")
.and_then(|h| h.to_str().ok())
.and_then(|s| s.parse::<usize>().ok())
.expect("Missing Content-Length header");
crate::info!(len, "Downloading {} bytes", len);
let mut reader = resp.into_body().into_with_config().limit(u64::MAX).reader();
let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier()));
let f = std::fs::File::create(&temp_filepath).expect("Failed to create model file");
let mut writer = std::io::BufWriter::new(f);
let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?;
if bytes_io_count != len as u64 {
return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}")));
}
drop(writer);
match std::fs::rename(&temp_filepath, &model_filepath) {
Ok(()) => Ok(model_filepath),
Err(e) => {
if model_filepath.exists() {
let _ = std::fs::remove_file(temp_filepath);
Ok(model_filepath)
} else {
Err(Error::new(format!("Failed to download model: {e}")))
}
}
}
}
}
/// Loads an ONNX model from a file and builds the session.
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
@@ -331,3 +251,81 @@ impl SessionBuilder {
EditableSession::new(session_ptr, self)
}
}
/// Downloads the model at `url` into the local model cache and returns the path of the cached file.
///
/// The cached file is named after the SHA-256 digest of `url`. If a file with that name already exists it is
/// returned without making a request. Otherwise the response body is streamed into a temporary file in the same
/// directory, which is renamed into place once the download has completed.
///
/// # Errors
/// Returns an error if:
/// - no `tls-*` feature is enabled;
/// - the request fails;
/// - the response has no parseable `Content-Length` header;
/// - the temporary file cannot be created, written or flushed;
/// - the number of bytes written does not match `Content-Length`;
/// - the temporary file cannot be renamed and no file exists at the destination.
///
/// # Panics
/// Panics if the cache directory cannot be determined.
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
fn download_model(url: &str) -> Result<PathBuf> {
    // Needed for `BufWriter::flush`; imported anonymously so it cannot clash with any other `Write` in scope.
    use std::io::Write as _;

    use ureq::{
        config::Config,
        tls::{RootCerts, TlsConfig, TlsProvider}
    };

    let mut download_dir = ort_sys::internal::dirs::cache_dir()
        .expect("could not determine cache directory")
        .join("models");
    if std::fs::create_dir_all(&download_dir).is_err() {
        // Fall back to the working directory when the cache directory cannot be created.
        download_dir = std::env::current_dir().map_err(|e| Error::new(format!("Failed to obtain current working directory: {e}")))?;
    }

    let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
        let _ = write!(&mut s, "{:02x}", b);
        s
    });
    let model_filepath = download_dir.join(&model_filename);
    if model_filepath.exists() {
        crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
        return Ok(model_filepath);
    }

    crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");

    let agent = Config::builder()
        .tls_config(
            TlsConfig::builder()
                .root_certs(RootCerts::WebPki)
                .provider(if cfg!(any(feature = "tls-rustls", feature = "tls-rustls-no-provider")) {
                    TlsProvider::Rustls
                } else if cfg!(any(feature = "tls-native", feature = "tls-native-vendored")) {
                    TlsProvider::NativeTls
                } else {
                    return Err(Error::new("No TLS provider configured. When using `fetch-models` with HTTPS URLs, a `tls-*` feature must be enabled."));
                })
                .build()
        )
        .build()
        .new_agent();

    let resp = agent.get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?;

    // The header comes from the remote server, so its absence is a recoverable error rather than a bug.
    let len = resp
        .headers()
        .get("Content-Length")
        .and_then(|h| h.to_str().ok())
        .and_then(|s| s.parse::<usize>().ok())
        .ok_or_else(|| Error::new("Missing Content-Length header"))?;
    crate::info!(len, "Downloading {} bytes", len);

    let mut reader = resp.into_body().into_with_config().limit(u64::MAX).reader();
    let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier()));

    // Stream the body to the temporary file. The explicit flush surfaces write errors that dropping the
    // `BufWriter` would silently discard.
    let copy_result = (|| -> Result<u64> {
        let f = std::fs::File::create(&temp_filepath).map_err(|e| Error::new(format!("Failed to create model file: {e}")))?;
        let mut writer = std::io::BufWriter::new(f);
        let copied = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?;
        writer.flush().map_err(Error::wrap)?;
        Ok(copied)
    })();
    let bytes_io_count = match copy_result {
        Ok(copied) => copied,
        Err(e) => {
            // Best-effort cleanup so failed downloads do not accumulate in the cache directory.
            let _ = std::fs::remove_file(&temp_filepath);
            return Err(e);
        }
    };
    if bytes_io_count != len as u64 {
        let _ = std::fs::remove_file(&temp_filepath);
        return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}")));
    }

    match std::fs::rename(&temp_filepath, &model_filepath) {
        Ok(()) => Ok(model_filepath),
        Err(e) => {
            // If the destination exists despite the failed rename, use it and discard our temporary copy.
            if model_filepath.exists() {
                let _ = std::fs::remove_file(temp_filepath);
                Ok(model_filepath)
            } else {
                Err(Error::new(format!("Failed to download model: {e}")))
            }
        }
    }
}

View File

@@ -1,5 +1,4 @@
use super::SessionBuilder;
use crate::Result;
use super::{BuilderResult, SessionBuilder};
// https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -7,17 +6,21 @@ impl SessionBuilder {
/// Enable/disable the usage of prepacking.
///
/// This option is **enabled** by default.
pub fn with_prepacking(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" })?;
Ok(self)
pub fn with_prepacking(mut self, enable: bool) -> BuilderResult {
match self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" }) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Use allocators from the registered environment.
///
/// This option is **disabled** by default.
pub fn with_env_allocators(mut self) -> Result<Self> {
self.add_config_entry("session.use_env_allocators", "1")?;
Ok(self)
pub fn with_env_allocators(mut self) -> BuilderResult {
match self.add_config_entry("session.use_env_allocators", "1") {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Disables subnormal floats by enabling the denormals-are-zero and flush-to-zero flags for all threads in the
@@ -28,94 +31,137 @@ impl SessionBuilder {
/// giving faster & more consistent performance, but lower accuracy (in cases where subnormals are involved).
///
/// This option is **disabled** by default, as it may hurt model accuracy.
pub fn with_flush_to_zero(mut self) -> Result<Self> {
self.add_config_entry("session.set_denormal_as_zero", "1")?;
Ok(self)
pub fn with_flush_to_zero(mut self) -> BuilderResult {
match self.add_config_entry("session.set_denormal_as_zero", "1") {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable/disable fusion for quantized models in QDQ (`QuantizeLinear`/`DequantizeLinear`) format.
///
/// This option is **enabled** by default for all EPs except DirectML.
pub fn with_quant_qdq(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" })?;
Ok(self)
pub fn with_quant_qdq(mut self, enable: bool) -> BuilderResult {
match self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" }) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable/disable the optimization step removing double QDQ nodes.
///
/// This option is **enabled** by default.
pub fn with_double_qdq_remover(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" })?;
Ok(self)
pub fn with_double_qdq_remover(mut self, enable: bool) -> BuilderResult {
match self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" }) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable the removal of Q/DQ node pairs once all QDQ handling has been completed.
///
/// This option is **disabled** by default.
pub fn with_qdq_cleanup(mut self) -> Result<Self> {
self.add_config_entry("session.enable_quant_qdq_cleanup", "1")?;
Ok(self)
pub fn with_qdq_cleanup(mut self) -> BuilderResult {
match self.add_config_entry("session.enable_quant_qdq_cleanup", "1") {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable fast tanh-based GELU approximation (like PyTorch's `nn.GELU(approximate='tanh')`).
///
/// This option is **disabled** by default, as it may impact results.
pub fn with_approximate_gelu(mut self) -> Result<Self> {
self.add_config_entry("optimization.enable_gelu_approximation", "1")?;
Ok(self)
pub fn with_approximate_gelu(mut self) -> BuilderResult {
match self.add_config_entry("optimization.enable_gelu_approximation", "1") {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable the `Cast` chain elimination optimization.
///
/// This option is **disabled** by default, as it may impact results.
pub fn with_cast_chain_elimination(mut self) -> Result<Self> {
self.add_config_entry("optimization.enable_cast_chain_elimination", "1")?;
Ok(self)
pub fn with_cast_chain_elimination(mut self) -> BuilderResult {
match self.add_config_entry("optimization.enable_cast_chain_elimination", "1") {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable/disable ahead-of-time function inlining.
///
/// This option is **enabled** by default.
pub fn with_aot_inlining(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" })?;
Ok(self)
pub fn with_aot_inlining(mut self, enable: bool) -> BuilderResult {
match self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" }) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Accepts a comma-separated list of optimizers to disable.
pub fn with_disabled_optimizers(mut self, optimizers: impl AsRef<str>) -> Result<Self> {
self.add_config_entry("optimization.disable_specified_optimizers", optimizers)?;
Ok(self)
pub fn with_disabled_optimizers(mut self, optimizers: impl AsRef<str>) -> BuilderResult {
match self.add_config_entry("optimization.disable_specified_optimizers", optimizers) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable using the device allocator for allocating initialized tensor memory, potentially bypassing arena
/// allocators.
///
/// This option is **disabled** by default.
pub fn with_device_allocated_initializers(mut self) -> Result<Self> {
self.add_config_entry("session.use_device_allocator_for_initializers", "1")?;
Ok(self)
pub fn with_device_allocated_initializers(mut self) -> BuilderResult {
match self.add_config_entry("session.use_device_allocator_for_initializers", "1") {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable/disable allowing the inter-op threads to spin for a short period before blocking.
///
/// This option is **enabled** by default.
pub fn with_inter_op_spinning(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" })?;
Ok(self)
pub fn with_inter_op_spinning(mut self, enable: bool) -> BuilderResult {
match self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" }) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable/disable allowing the intra-op threads to spin for a short period before blocking.
///
/// This option is **enabled** by default.
pub fn with_intra_op_spinning(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" })?;
Ok(self)
pub fn with_intra_op_spinning(mut self, enable: bool) -> BuilderResult {
match self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" }) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Disables falling back to the CPU for operations not supported by any other EP.
/// Models with graphs that cannot be placed entirely on the EP(s) will fail to commit.
pub fn with_disable_cpu_fallback(mut self) -> Result<Self> {
self.add_config_entry("session.disable_cpu_ep_fallback", "1")?;
Ok(self)
pub fn with_disable_cpu_fallback(mut self) -> BuilderResult {
match self.add_config_entry("session.disable_cpu_ep_fallback", "1") {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Uses slower U8U8 matrix multiplication in place of U8S8 matrix multiplication that could potentially overflow on
/// x86-64 platforms without the VNNI extension.
///
/// This should only be enabled if you encounter overflow issues with quantized models.
///
/// Sets the `session.x64quantprecision` config entry to `"1"`.
pub fn with_precise_qmm(mut self) -> BuilderResult {
match self.add_config_entry("session.x64quantprecision", "1") {
Ok(()) => Ok(self),
// On failure, hand the builder back inside the error so the caller can recover it.
Err(e) => Err(e.with_recover(self))
}
}
/// Enables dynamic thread block sizing with the given base block size.
///
/// Sets the `session.dynamic_block_base` config entry to the decimal string form of `size`.
pub fn with_dynamic_block_base(mut self, size: u32) -> BuilderResult {
match self.add_config_entry("session.dynamic_block_base", size.to_string()) {
Ok(()) => Ok(self),
// On failure, hand the builder back inside the error so the caller can recover it.
Err(e) => Err(e.with_recover(self))
}
}
}

View File

@@ -7,14 +7,13 @@ use core::{
#[cfg(feature = "std")]
use std::path::Path;
use super::SessionBuilder;
use super::{BuilderResult, SessionBuilder};
#[cfg(feature = "std")]
use crate::util::path_to_os_char;
use crate::{
AsPointer, Error, ErrorCode,
environment::{self, ThreadManager},
ep::{ExecutionProviderDispatch, apply_execution_providers},
error::Result,
logging::{LogLevel, LoggerFunction},
memory::MemoryInfo,
operator::OperatorDomain,
@@ -36,9 +35,11 @@ impl SessionBuilder {
/// - **Indiscriminate use of [`SessionBuilder::with_execution_providers`] in a library** (e.g. always enabling
/// CUDA) **is discouraged** unless you allow the user to configure the execution providers by providing a `Vec`
/// of [`ExecutionProviderDispatch`]es.
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Result<Self> {
apply_execution_providers(&mut self, execution_providers.as_ref(), "session options")?;
Ok(self)
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> BuilderResult {
match apply_execution_providers(&mut self, execution_providers.as_ref(), "session options") {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Configure the session to use a number of threads to parallelize the execution within nodes. If ONNX Runtime was
@@ -48,9 +49,11 @@ impl SessionBuilder {
///
/// For configuring the number of threads used when the session execution mode is set to `Parallel`, see
/// [`SessionBuilder::with_inter_threads()`].
pub fn with_intra_threads(mut self, num_threads: usize) -> Result<Self> {
ortsys![unsafe SetIntraOpNumThreads(self.ptr_mut(), num_threads as _)?];
Ok(self)
pub fn with_intra_threads(mut self, num_threads: usize) -> BuilderResult {
match ortsys![@ort: unsafe SetIntraOpNumThreads(self.ptr_mut(), num_threads as _) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Configure the session to use a number of threads to parallelize the execution of the graph. If nodes can be run
@@ -60,9 +63,11 @@ impl SessionBuilder {
///
/// For configuring the number of threads used to parallelize the execution within nodes, see
/// [`SessionBuilder::with_intra_threads()`].
pub fn with_inter_threads(mut self, num_threads: usize) -> Result<Self> {
ortsys![unsafe SetInterOpNumThreads(self.ptr_mut(), num_threads as _)?];
Ok(self)
pub fn with_inter_threads(mut self, num_threads: usize) -> BuilderResult {
match ortsys![@ort: unsafe SetInterOpNumThreads(self.ptr_mut(), num_threads as _) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enable/disable the parallel execution mode for this session. By default, this is disabled.
@@ -70,21 +75,25 @@ impl SessionBuilder {
/// Parallel execution can improve performance for models with many branches, at the cost of higher memory usage.
/// You can configure the amount of threads used to parallelize the execution of the graph via
/// [`SessionBuilder::with_inter_threads()`].
pub fn with_parallel_execution(mut self, parallel_execution: bool) -> Result<Self> {
pub fn with_parallel_execution(mut self, parallel_execution: bool) -> BuilderResult {
let execution_mode = if parallel_execution {
ort_sys::ExecutionMode::ORT_PARALLEL
} else {
ort_sys::ExecutionMode::ORT_SEQUENTIAL
};
ortsys![unsafe SetSessionExecutionMode(self.ptr_mut(), execution_mode)?];
Ok(self)
match ortsys![@ort: unsafe SetSessionExecutionMode(self.ptr_mut(), execution_mode) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Set the session's optimization level. See [`GraphOptimizationLevel`] for more information on the different
/// optimization levels.
///
/// # Errors
/// Returns the error reported by `SetSessionGraphOptimizationLevel`. In the `BuilderResult` form, the builder is
/// attached to the error via `with_recover` so the caller can retrieve it and keep going.
pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> Result<Self> {
ortsys![unsafe SetSessionGraphOptimizationLevel(self.ptr_mut(), opt_level.into())?];
Ok(self)
pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> BuilderResult {
match ortsys![@ort: unsafe SetSessionGraphOptimizationLevel(self.ptr_mut(), opt_level.into()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// After performing optimization (configurable with [`SessionBuilder::with_optimization_level`]), serializes the
@@ -93,10 +102,12 @@ impl SessionBuilder {
/// Note that the file will only be created after the model is committed.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn with_optimized_model_path<S: AsRef<Path>>(mut self, path: S) -> Result<Self> {
pub fn with_optimized_model_path<S: AsRef<Path>>(mut self, path: S) -> BuilderResult {
// `path_to_os_char` produces the buffer whose pointer is handed to the C API below; it must stay alive until
// the call returns, hence the local binding.
let path = crate::util::path_to_os_char(path);
ortsys![unsafe SetOptimizedModelFilePath(self.ptr_mut(), path.as_ptr())?];
Ok(self)
// On failure, attach the builder to the error (`with_recover`) so the caller can get it back.
match ortsys![@ort: unsafe SetOptimizedModelFilePath(self.ptr_mut(), path.as_ptr()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enables profiling. Profile information will be written to `profiling_file` after profiling completes.
@@ -105,29 +116,37 @@ impl SessionBuilder {
/// [`Session::end_profiling`]: crate::session::Session::end_profiling
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn with_profiling<S: AsRef<Path>>(mut self, profiling_file: S) -> Result<Self> {
pub fn with_profiling<S: AsRef<Path>>(mut self, profiling_file: S) -> BuilderResult {
// `path_to_os_char` produces the buffer whose pointer is handed to `EnableProfiling`; keep it bound locally so
// it outlives the call.
let profiling_file = crate::util::path_to_os_char(profiling_file);
ortsys![unsafe EnableProfiling(self.ptr_mut(), profiling_file.as_ptr())?];
Ok(self)
// On failure, attach the builder to the error (`with_recover`) so the caller can get it back.
match ortsys![@ort: unsafe EnableProfiling(self.ptr_mut(), profiling_file.as_ptr()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enables/disables memory pattern optimization. Disable it if the input size varies, e.g. with dynamic batch
/// sizes.
///
/// Calls `EnableMemPattern` when `enable` is `true` and `DisableMemPattern` otherwise.
pub fn with_memory_pattern(mut self, enable: bool) -> Result<Self> {
if enable {
ortsys![unsafe EnableMemPattern(self.ptr_mut())?];
pub fn with_memory_pattern(mut self, enable: bool) -> BuilderResult {
// Evaluate whichever call applies to a `Result` first, then handle success/failure in one place.
let result = if enable {
ortsys![@ort: unsafe EnableMemPattern(self.ptr_mut()) as Result]
} else {
ortsys![unsafe DisableMemPattern(self.ptr_mut())?];
ortsys![@ort: unsafe DisableMemPattern(self.ptr_mut()) as Result]
};
// On failure, attach the builder to the error (`with_recover`) so the caller can get it back.
match result {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
Ok(self)
}
/// Configure this session to use a custom allocator, rather than the global default. This allocator is responsible
/// for allocating the *metadata* associated with values -- not the contents of the values themselves; that's
/// handled by the active execution providers. As such, only CPU-accessible allocators are allowed.
///
/// # Errors
/// Returns an [`ErrorCode::InvalidArgument`] error if `info` is not CPU-accessible. In the `BuilderResult` form,
/// the builder is attached to that error via `with_recover`.
pub fn with_allocator(mut self, info: MemoryInfo) -> Result<Self> {
pub fn with_allocator(mut self, info: MemoryInfo) -> BuilderResult {
// Validate up front: nothing has been modified yet when this check fails.
if !info.is_cpu_accessible() {
return Err(Error::new_with_code(ErrorCode::InvalidArgument, "SessionBuilder::with_allocator may only use a CPU-accessible allocator"));
return Err(
Error::new_with_code(ErrorCode::InvalidArgument, "SessionBuilder::with_allocator may only use a CPU-accessible allocator").with_recover(self)
);
}
// Store the memory info on the builder; no ONNX Runtime call is made here.
self.memory_info = Some(Arc::new(info));
Ok(self)
}
@@ -135,117 +154,145 @@ impl SessionBuilder {
/// Registers a custom operator library at the given library path.
///
/// # Errors
/// Returns the error reported by `RegisterCustomOpsLibrary_V2`. In the `BuilderResult` form, the builder is
/// attached to the error via `with_recover` so the caller can retrieve it.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn with_operator_library(mut self, lib_path: impl AsRef<Path>) -> Result<Self> {
pub fn with_operator_library(mut self, lib_path: impl AsRef<Path>) -> BuilderResult {
// `path_to_os_char` produces the buffer whose pointer is handed to the C API below.
let path_cstr = path_to_os_char(lib_path);
ortsys![unsafe RegisterCustomOpsLibrary_V2(self.ptr_mut(), path_cstr.as_ptr())?];
Ok(self)
match ortsys![@ort: unsafe RegisterCustomOpsLibrary_V2(self.ptr_mut(), path_cstr.as_ptr()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Enables [`onnxruntime-extensions`](https://github.com/microsoft/onnxruntime-extensions) custom operators.
///
/// # Errors
/// Returns the error reported by `EnableOrtCustomOps`. In the `BuilderResult` form, the builder is attached to the
/// error via `with_recover` so the caller can retrieve it.
pub fn with_extensions(mut self) -> Result<Self> {
ortsys![unsafe EnableOrtCustomOps(self.ptr_mut())?];
Ok(self)
pub fn with_extensions(mut self) -> BuilderResult {
match ortsys![@ort: unsafe EnableOrtCustomOps(self.ptr_mut()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Adds a custom operator domain to the session options via `AddCustomOpDomain`.
///
/// The domain is converted into an `Arc<OperatorDomain>` and, once the call succeeds, pushed onto the builder's
/// `operator_domains` list (presumably so it stays alive while the options reference it -- TODO confirm).
///
/// # Errors
/// Returns the error reported by `AddCustomOpDomain`. In the `BuilderResult` form, the builder is attached to the
/// error via `with_recover` so the caller can retrieve it; the domain is not recorded in that case.
pub fn with_operators(mut self, domain: impl Into<Arc<OperatorDomain>>) -> Result<Self> {
pub fn with_operators(mut self, domain: impl Into<Arc<OperatorDomain>>) -> BuilderResult {
let domain: Arc<OperatorDomain> = domain.into();
ortsys![unsafe AddCustomOpDomain(self.ptr_mut(), domain.ptr().cast_mut())?];
self.operator_domains.push(domain);
Ok(self)
match ortsys![@ort: unsafe AddCustomOpDomain(self.ptr_mut(), domain.ptr().cast_mut()) as Result] {
Ok(()) => {
// Only record the domain after ONNX Runtime has accepted it.
self.operator_domains.push(domain);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
/// Enables/disables deterministic computation.
///
/// The default (non-deterministic) kernels will typically use faster algorithms that may introduce slight variance.
/// Enabling deterministic compute will output reproducible results, but may come at a performance penalty.
///
/// # Errors
/// Returns the error reported by `SetDeterministicCompute`. In the `BuilderResult` form, the builder is attached to
/// the error via `with_recover` so the caller can retrieve it.
pub fn with_deterministic_compute(mut self, enable: bool) -> Result<Self> {
ortsys![unsafe SetDeterministicCompute(self.ptr_mut(), enable)?];
Ok(self)
pub fn with_deterministic_compute(mut self, enable: bool) -> BuilderResult {
match ortsys![@ort: unsafe SetDeterministicCompute(self.ptr_mut(), enable) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Adds `value` as an initializer named `name` via `AddInitializer`.
///
/// On success the value is pushed onto the builder's `initializers` list (presumably to keep it alive while the
/// options reference it -- TODO confirm).
///
/// # Errors
/// Returns whatever error `with_cstr` or `AddInitializer` reports. In the `BuilderResult` form, the builder is
/// attached to the error via `with_recover` and the value is not recorded.
pub fn with_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> Result<Self> {
pub fn with_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> BuilderResult {
// Take the raw options pointer before the closure so the closure doesn't need to borrow `self`.
let ptr = self.ptr_mut();
let value: Arc<DynValue> = value.into();
with_cstr(name.as_ref().as_bytes(), &|name| {
ortsys![unsafe AddInitializer(ptr, name.as_ptr(), value.ptr())?];
Ok(())
})?;
self.initializers.push(value);
Ok(self)
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddInitializer(ptr, name.as_ptr(), value.ptr()) as Result]) {
Ok(()) => {
self.initializers.push(value);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
/// Adds `value` as an external initializer named `name` via `AddExternalInitializers`.
///
/// The name and value are passed to ONNX Runtime as single-element lists. On success the value is pushed onto the
/// builder's `initializers` list (presumably to keep it alive while the options reference it -- TODO confirm).
///
/// # Errors
/// Returns whatever error `with_cstr` or `AddExternalInitializers` reports. In the `BuilderResult` form, the
/// builder is attached to the error via `with_recover` and the value is not recorded.
pub fn with_external_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> Result<Self> {
pub fn with_external_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> BuilderResult {
// Take the raw options pointer before the closure so the closure doesn't need to borrow `self`.
let ptr = self.ptr_mut();
let value: Arc<DynValue> = value.into();
with_cstr(name.as_ref().as_bytes(), &|name| {
ortsys![unsafe AddExternalInitializers(ptr, &name.as_ptr(), &value.ptr(), 1)?];
Ok(())
})?;
self.initializers.push(value);
Ok(self)
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddExternalInitializers(ptr, &name.as_ptr(), &value.ptr(), 1) as Result]) {
Ok(()) => {
self.initializers.push(value);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
/// Registers `buffer` as the in-memory contents of the external initializer file `file_name` via
/// `AddExternalInitializersFromMemory`.
///
/// On success the buffer is pushed onto the builder's `external_initializer_buffers` list (presumably because
/// ONNX Runtime keeps referring to the memory -- TODO confirm).
///
/// # Errors
/// Returns the error reported by `AddExternalInitializersFromMemory`. In the `BuilderResult` form, the builder is
/// attached to the error via `with_recover` and the buffer is not recorded.
#[cfg(all(feature = "std", feature = "api-18"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "std", feature = "api-18"))))]
pub fn with_external_initializer_file_in_memory(mut self, file_name: impl AsRef<Path>, buffer: alloc::borrow::Cow<'static, [u8]>) -> Result<Self> {
pub fn with_external_initializer_file_in_memory(mut self, file_name: impl AsRef<Path>, buffer: alloc::borrow::Cow<'static, [u8]>) -> BuilderResult {
let file_name = path_to_os_char(file_name);
// The API takes parallel arrays of names, buffers and sizes; exactly one entry of each is passed here.
let sizes = [buffer.len()];
ortsys![unsafe AddExternalInitializersFromMemory(self.ptr_mut(), &file_name.as_ptr(), &buffer.as_ptr().cast::<core::ffi::c_char>().cast_mut(), sizes.as_ptr(), 1)?];
self.external_initializer_buffers.push(buffer);
Ok(self)
match ortsys![@ort:
unsafe AddExternalInitializersFromMemory(
self.ptr_mut(),
&file_name.as_ptr(),
&buffer.as_ptr().cast::<core::ffi::c_char>().cast_mut(),
sizes.as_ptr(),
1
) as Result
] {
Ok(()) => {
self.external_initializer_buffers.push(buffer);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
/// Sets the session's log ID to `id` via `SetSessionLogId`.
///
/// # Errors
/// Returns whatever error `with_cstr` or `SetSessionLogId` reports. In the `BuilderResult` form, the builder is
/// attached to the error via `with_recover` so the caller can retrieve it.
pub fn with_log_id(mut self, id: impl AsRef<str>) -> Result<Self> {
pub fn with_log_id(mut self, id: impl AsRef<str>) -> BuilderResult {
// Take the raw options pointer before the closure so the closure doesn't need to borrow `self`.
let ptr = self.ptr_mut();
with_cstr(id.as_ref().as_bytes(), &|id| {
ortsys![unsafe SetSessionLogId(ptr, id.as_ptr())?];
Ok(())
})?;
Ok(self)
match with_cstr(id.as_ref().as_bytes(), &|id| ortsys![@ort: unsafe SetSessionLogId(ptr, id.as_ptr()) as Result]) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Overrides the free dimension called `name` with `size` via `AddFreeDimensionOverrideByName`.
///
/// # Errors
/// Returns whatever error `with_cstr` or `AddFreeDimensionOverrideByName` reports. In the `BuilderResult` form,
/// the builder is attached to the error via `with_recover` so the caller can retrieve it.
pub fn with_dimension_override(mut self, name: impl AsRef<str>, size: i64) -> Result<Self> {
pub fn with_dimension_override(mut self, name: impl AsRef<str>, size: i64) -> BuilderResult {
// Take the raw options pointer before the closure so the closure doesn't need to borrow `self`.
let ptr = self.ptr_mut();
with_cstr(name.as_ref().as_bytes(), &|name| {
ortsys![unsafe AddFreeDimensionOverrideByName(ptr, name.as_ptr(), size)?];
Ok(())
})?;
Ok(self)
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddFreeDimensionOverrideByName(ptr, name.as_ptr(), size) as Result]) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Overrides free dimensions carrying the given `denotation` with `size` via `AddFreeDimensionOverride`.
///
/// See [`SessionBuilder::with_dimension_override`] for overriding a dimension by name instead.
///
/// # Errors
/// Returns whatever error `with_cstr` or `AddFreeDimensionOverride` reports. In the `BuilderResult` form, the
/// builder is attached to the error via `with_recover` so the caller can retrieve it.
pub fn with_dimension_override_by_denotation(mut self, denotation: impl AsRef<str>, size: i64) -> Result<Self> {
pub fn with_dimension_override_by_denotation(mut self, denotation: impl AsRef<str>, size: i64) -> BuilderResult {
// Take the raw options pointer before the closure so the closure doesn't need to borrow `self`.
let ptr = self.ptr_mut();
with_cstr(denotation.as_ref().as_bytes(), &|denotation| {
ortsys![unsafe AddFreeDimensionOverride(ptr, denotation.as_ptr(), size)?];
Ok(())
})?;
Ok(self)
match with_cstr(denotation.as_ref().as_bytes(), &|denotation| ortsys![@ort: unsafe AddFreeDimensionOverride(ptr, denotation.as_ptr(), size) as Result])
{
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Configures the builder to use the given [`PrepackedWeights`].
///
/// A clone of `weights` is stored on the builder. No ONNX Runtime call is made here, so this always returns
/// `Ok`; the `Result` return type matches the other builder methods.
pub fn with_prepacked_weights(mut self, weights: &PrepackedWeights) -> Result<Self> {
pub fn with_prepacked_weights(mut self, weights: &PrepackedWeights) -> BuilderResult {
self.prepacked_weights = Some(weights.clone());
Ok(self)
}
/// Configures this session to use its own thread pool instead of defaulting to the
/// [`Environment`](crate::environment::Environment)'s global thread pool if one was defined.
///
/// This only sets a flag on the builder, so it always returns `Ok`.
pub fn with_independent_thread_pool(mut self) -> Result<Self> {
pub fn with_independent_thread_pool(mut self) -> BuilderResult {
self.no_global_thread_pool = true;
Ok(self)
}
/// Sets the builder's `no_env_eps` flag.
///
/// NOTE(review): going by the name, this presumably stops the session from picking up execution providers
/// configured on the environment -- confirm against where `no_env_eps` is read.
///
/// This only sets a flag on the builder, so it always returns `Ok`.
pub fn with_no_environment_execution_providers(mut self) -> Result<Self> {
pub fn with_no_environment_execution_providers(mut self) -> BuilderResult {
self.no_env_eps = true;
Ok(self)
}
/// Configures the session to create and join its threads through a custom [`ThreadManager`].
///
/// A raw pointer to the manager is registered as the custom thread creation options, and
/// `environment::thread_create::<T>` / `environment::thread_join::<T>` are installed as the create/join
/// callbacks. On success the manager is stored on the builder as an `Arc<dyn Any>`.
///
/// # Errors
/// Returns the first error reported by any of the three ONNX Runtime calls. In the `BuilderResult` form, the
/// builder is attached to the error via `with_recover` and the manager is not stored.
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> BuilderResult {
let manager = Arc::new(manager);
ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?];
ortsys![unsafe SessionOptionsSetCustomCreateThreadFn(self.ptr_mut(), Some(environment::thread_create::<T>))?];
ortsys![unsafe SessionOptionsSetCustomJoinThreadFn(self.ptr_mut(), Some(environment::thread_join::<T>))?];
self.thread_manager = Some(manager as Arc<dyn Any>);
Ok(self)
// NOTE(review): if an earlier call succeeds and a later one fails, `manager` is dropped at the end of this
// function while the session options may still hold its raw pointer, and the builder stays usable through
// `with_recover` -- confirm this cannot leave a dangling pointer behind.
let ptr = self.ptr_mut();
// The three calls are chained with `and_then`, so the first failure short-circuits the rest.
match ortsys![@ort: unsafe SessionOptionsSetCustomThreadCreationOptions(ptr, (&*manager as *const T) as *mut c_void) as Result]
.and_then(|()| ortsys![@ort: unsafe SessionOptionsSetCustomCreateThreadFn(ptr, Some(environment::thread_create::<T>)) as Result])
.and_then(|()| ortsys![@ort: unsafe SessionOptionsSetCustomJoinThreadFn(ptr, Some(environment::thread_join::<T>)) as Result])
{
Ok(()) => {
self.thread_manager = Some(manager as Arc<dyn Any>);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
/// Configures this session to use a custom logger function.
@@ -268,11 +315,15 @@ impl SessionBuilder {
/// # Ok(())
/// # }
/// ```
pub fn with_logger(mut self, logger: LoggerFunction) -> Result<Self> {
pub fn with_logger(mut self, logger: LoggerFunction) -> BuilderResult {
// The logger is placed in an `Arc`; its raw pointer is passed to ONNX Runtime alongside
// `crate::logging::custom_logger`.
let logger = Arc::new(logger);
ortsys![unsafe SetUserLoggingFunction(self.ptr_mut(), crate::logging::custom_logger, Arc::as_ptr(&logger) as *mut c_void)?];
self.logger = Some(logger);
Ok(self)
match ortsys![@ort: unsafe SetUserLoggingFunction(self.ptr_mut(), crate::logging::custom_logger, Arc::as_ptr(&logger) as *mut c_void) as Result] {
Ok(()) => {
// Store the `Arc` on the builder only once registration has succeeded.
self.logger = Some(logger);
Ok(self)
}
// On failure, attach the builder to the error (`with_recover`) so the caller can get it back.
Err(e) => Err(e.with_recover(self))
}
}
/// Sets the severity level for messages logged by this session.
@@ -281,15 +332,19 @@ impl SessionBuilder {
/// precedence, i.e. if the application was initialized with `ort`'s log level set to `warn` via the `RUST_LOG`
/// environment variable or similar, setting a session's log severity level to `verbose` will still have it only
/// log `warn` messages or higher.
pub fn with_log_level(mut self, level: LogLevel) -> Result<Self> {
ortsys![unsafe SetSessionLogSeverityLevel(self.ptr_mut(), ort_sys::OrtLoggingLevel::from(level) as _)?];
Ok(self)
pub fn with_log_level(mut self, level: LogLevel) -> BuilderResult {
// `level` is converted to `ort_sys::OrtLoggingLevel` and then cast to the integer type the C API expects.
// On failure, the builder is attached to the error (`with_recover`) so the caller can get it back.
match ortsys![@ort: unsafe SetSessionLogSeverityLevel(self.ptr_mut(), ort_sys::OrtLoggingLevel::from(level) as _) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Controls the level of verbosity for messages logged under [`LogLevel::Verbose`]; higher values = more verbose.
///
/// # Errors
/// Returns the error reported by `SetSessionLogVerbosityLevel`. In the `BuilderResult` form, the builder is
/// attached to the error via `with_recover` so the caller can retrieve it.
pub fn with_log_verbosity(mut self, verbosity: c_int) -> Result<Self> {
ortsys![unsafe SetSessionLogVerbosityLevel(self.ptr_mut(), verbosity)?];
Ok(self)
pub fn with_log_verbosity(mut self, verbosity: c_int) -> BuilderResult {
match ortsys![@ort: unsafe SetSessionLogVerbosityLevel(self.ptr_mut(), verbosity) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
/// Automatically select & register an execution provider according to the given [`policy`](AutoDevicePolicy) based
@@ -309,9 +364,11 @@ impl SessionBuilder {
/// ```
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub fn with_auto_device(mut self, policy: AutoDevicePolicy) -> Result<Self> {
ortsys![unsafe SessionOptionsSetEpSelectionPolicy(self.ptr_mut(), policy.into())?];
Ok(self)
pub fn with_auto_device(mut self, policy: AutoDevicePolicy) -> BuilderResult {
// The status of `SessionOptionsSetEpSelectionPolicy` is matched on rather than propagated with `?`: on
// failure the builder is attached to the error via `with_recover` so the caller can retrieve it.
match ortsys![@ort: unsafe SessionOptionsSetEpSelectionPolicy(self.ptr_mut(), policy.into()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
}

View File

@@ -34,6 +34,22 @@ mod impl_options;
pub use self::editable::*;
pub use self::impl_options::*;
/// `Result` type returned by [`SessionBuilder`] methods.
///
/// The error variant is an `Error<SessionBuilder>`: builder methods attach the builder to the error they return,
/// so this type supports [error recovery](Error::recover):
/// ```
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?
/// .with_optimization_level(GraphOptimizationLevel::All)
/// // Optimization isn't enabled in minimal builds of ONNX Runtime, so this returns an error. We can just
/// // ignore it and carry on with the recovered builder.
/// .unwrap_or_else(|e| e.recover())
/// .commit_from_file("tests/data/upsample.onnx")?;
/// # Ok(())
/// # }
/// ```
pub type BuilderResult = Result<SessionBuilder, Error<SessionBuilder>>;
/// Creates a session using the builder pattern.
///
/// Once configured, use the
@@ -146,7 +162,7 @@ impl SessionBuilder {
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
/// # use std::{thread, time::Duration};
/// # fn main() -> ort::Result<()> {
/// let builder = Session::builder()?
/// let mut builder = Session::builder()?
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
/// .with_intra_threads(1)?;
///
@@ -168,9 +184,11 @@ impl SessionBuilder {
}
/// Adds a custom configuration entry to the session.
///
/// This is the by-value, chainable wrapper around `add_config_entry`.
///
/// # Errors
/// Returns the error reported by `add_config_entry`. In the `BuilderResult` form, the builder is attached to the
/// error via `with_recover` so the caller can retrieve it.
pub fn with_config_entry(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
self.add_config_entry(key.as_ref(), value.as_ref())?;
Ok(self)
pub fn with_config_entry(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> BuilderResult {
match self.add_config_entry(key.as_ref(), value.as_ref()) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
}
@@ -203,7 +221,7 @@ impl LoadCanceler {
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
/// # use std::{thread, time::Duration};
/// # fn main() -> ort::Result<()> {
/// let builder = Session::builder()?
/// let mut builder = Session::builder()?
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
/// .with_intra_threads(1)?;
///

View File

@@ -34,7 +34,7 @@ use smallvec::SmallVec;
use crate::{
AsPointer,
environment::Environment,
error::{Error, ErrorCode, Result, status_to_result},
error::{Error, ErrorCode, Result},
memory::Allocator,
ortsys,
util::{AllocatedString, STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS, with_cstr, with_cstr_ptr_array},
@@ -708,7 +708,7 @@ pub(crate) mod io {
) -> Result<usize> {
let mut num_nodes = 0;
let status = unsafe { f(session_ptr.as_ptr(), &mut num_nodes) };
unsafe { status_to_result(status) }?;
unsafe { Error::result_from_status(status) }?;
Ok(num_nodes)
}
@@ -721,7 +721,7 @@ pub(crate) mod io {
let mut name_ptr: *mut c_char = ptr::null_mut();
let status = unsafe { f(session_ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_ptr) };
unsafe { status_to_result(status) }?;
unsafe { Error::result_from_status(status) }?;
if name_ptr.is_null() {
crate::util::cold();
return Err(crate::Error::new("expected `name_ptr` to not be null"));
@@ -738,7 +738,7 @@ pub(crate) mod io {
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
let status = unsafe { f(session_ptr.as_ptr(), i, &mut typeinfo_ptr) };
unsafe { status_to_result(status) }?;
unsafe { Error::result_from_status(status) }?;
let Some(typeinfo_ptr) = NonNull::new(typeinfo_ptr) else {
crate::util::cold();
return Err(crate::Error::new("expected `typeinfo_ptr` to not be null"));

View File

@@ -11,7 +11,7 @@ use super::{Checkpoint, Optimizer, training_api};
use crate::{
AsPointer,
environment::Environment,
error::{Result, status_to_result},
error::{Error, Result},
memory::Allocator,
ortsys,
session::{RunOptions, SessionInputValue, SessionInputs, SessionOutputs, builder::SessionBuilder},
@@ -329,11 +329,11 @@ fn extract_io_names(
) -> ort_sys::OrtStatusPtr
) -> Result<Vec<String>> {
let mut count = 0;
unsafe { status_to_result(get_count(ptr.as_ptr(), &mut count)) }?;
unsafe { Error::result_from_status(get_count(ptr.as_ptr(), &mut count)) }?;
(0..count)
.map(|i| {
let mut name_bytes: *const c_char = ptr::null();
unsafe { status_to_result(get_name(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)) }?;
unsafe { Error::result_from_status(get_name(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)) }?;
let name = match char_p_to_string(name_bytes) {
Ok(name) => name,
Err(e) => {