mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: recover from SessionBuilder errors
This commit is contained in:
@@ -22,7 +22,7 @@
|
|||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! # use ort::{compiler::ModelCompiler, session::Session, ep};
|
//! # use ort::{compiler::ModelCompiler, session::Session, ep};
|
||||||
//! # fn main() -> ort::Result<()> {
|
//! # fn main() -> ort::Result<()> {
|
||||||
//! let session_options = Session::builder()?.with_execution_providers([ep::CoreML::default()
|
//! let mut session_options = Session::builder()?.with_execution_providers([ep::CoreML::default()
|
||||||
//! .with_model_format(ep::coreml::ModelFormat::MLProgram)
|
//! .with_model_format(ep::coreml::ModelFormat::MLProgram)
|
||||||
//! .build()])?;
|
//! .build()])?;
|
||||||
//!
|
//!
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ impl Model {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn into_session(self, mut options: SessionBuilder) -> Result<Session> {
|
pub fn into_session(self, options: &SessionBuilder) -> Result<Session> {
|
||||||
let mut session_ptr = ptr::null_mut();
|
let mut session_ptr = ptr::null_mut();
|
||||||
ortsys![@editor:
|
ortsys![@editor:
|
||||||
unsafe CreateSessionFromModel(
|
unsafe CreateSessionFromModel(
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ fn test_identity_graph() -> Result<()> {
|
|||||||
let mut model = Model::new([Opset::new(ONNX_DOMAIN, 22)?])?;
|
let mut model = Model::new([Opset::new(ONNX_DOMAIN, 22)?])?;
|
||||||
model.add_graph(graph)?;
|
model.add_graph(graph)?;
|
||||||
|
|
||||||
let mut session = model.into_session(SessionBuilder::new()?)?;
|
let mut session = model.into_session(&SessionBuilder::new()?)?;
|
||||||
let output = session
|
let output = session
|
||||||
.run(inputs![Tensor::<f32>::from_array((Shape::new([5]), vec![1.0f32; 5]))?])?
|
.run(inputs![Tensor::<f32>::from_array((Shape::new([5]), vec![1.0f32; 5]))?])?
|
||||||
.remove("output")
|
.remove("output")
|
||||||
@@ -76,7 +76,7 @@ fn test_mul_graph() -> Result<()> {
|
|||||||
let mut model = Model::new([Opset::new(ONNX_DOMAIN, 22)?])?;
|
let mut model = Model::new([Opset::new(ONNX_DOMAIN, 22)?])?;
|
||||||
model.add_graph(graph)?;
|
model.add_graph(graph)?;
|
||||||
|
|
||||||
let mut session = model.into_session(SessionBuilder::new()?)?;
|
let mut session = model.into_session(&SessionBuilder::new()?)?;
|
||||||
let output = session
|
let output = session
|
||||||
.run(inputs![Tensor::<f32>::from_array((Shape::new([5]), vec![2.0f32; 5]))?])?
|
.run(inputs![Tensor::<f32>::from_array((Shape::new([5]), vec![2.0f32; 5]))?])?
|
||||||
.remove("output")
|
.remove("output")
|
||||||
|
|||||||
@@ -101,9 +101,9 @@ impl Environment {
|
|||||||
/// Sets the global log level.
|
/// Sets the global log level.
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
|
/// # use ort::{environment::Environment, logging::LogLevel};
|
||||||
/// # fn main() -> ort::Result<()> {
|
/// # fn main() -> ort::Result<()> {
|
||||||
/// # use ort::logging::LogLevel;
|
/// let env = Environment::current()?;
|
||||||
/// let env = ort::environment::get_environment()?;
|
|
||||||
///
|
///
|
||||||
/// env.set_log_level(LogLevel::Warning);
|
/// env.set_log_level(LogLevel::Warning);
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
@@ -130,8 +130,9 @@ impl Environment {
|
|||||||
/// no longer be needed.
|
/// no longer be needed.
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
|
/// # use ort::environment::Environment;
|
||||||
/// # fn main() -> ort::Result<()> {
|
/// # fn main() -> ort::Result<()> {
|
||||||
/// let env = ort::environment::get_environment()?;
|
/// let env = Environment::current()?;
|
||||||
///
|
///
|
||||||
/// let _ = env.register_ep_library("CUDA", "/path/to/onnxruntime_providers_cuda.dll");
|
/// let _ = env.register_ep_library("CUDA", "/path/to/onnxruntime_providers_cuda.dll");
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ impl ExecutionProvider for ACL {
|
|||||||
|
|
||||||
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, enable_fast_math: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
|
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, enable_fast_math: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
|
||||||
return Ok(unsafe {
|
return Ok(unsafe {
|
||||||
crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.ptr_mut(), self.fast_math.into()))
|
crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.ptr_mut(), self.fast_math.into()))
|
||||||
}?);
|
}?);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ impl ExecutionProvider for ArmNN {
|
|||||||
|
|
||||||
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
|
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
|
||||||
return Ok(unsafe {
|
return Ok(unsafe {
|
||||||
crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.ptr_mut(), self.use_arena.into()))
|
crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.ptr_mut(), self.use_arena.into()))
|
||||||
}?);
|
}?);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -102,11 +102,11 @@ impl ExecutionProvider for DirectML {
|
|||||||
fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
|
fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
|
||||||
#[cfg(any(feature = "load-dynamic", feature = "directml"))]
|
#[cfg(any(feature = "load-dynamic", feature = "directml"))]
|
||||||
{
|
{
|
||||||
use crate::AsPointer;
|
use crate::{AsPointer, Error};
|
||||||
|
|
||||||
let api = get_dml_api()?;
|
let api = get_dml_api()?;
|
||||||
if let Some(device_id) = self.device_id {
|
if let Some(device_id) = self.device_id {
|
||||||
unsafe { crate::error::status_to_result((api.SessionOptionsAppendExecutionProvider_DML)(session_builder.ptr_mut(), device_id as _)) }?;
|
unsafe { Error::result_from_status((api.SessionOptionsAppendExecutionProvider_DML)(session_builder.ptr_mut(), device_id as _)) }?;
|
||||||
} else {
|
} else {
|
||||||
let device_options = ort_sys::OrtDmlDeviceOptions {
|
let device_options = ort_sys::OrtDmlDeviceOptions {
|
||||||
Filter: match self.device_filter {
|
Filter: match self.device_filter {
|
||||||
@@ -120,7 +120,7 @@ impl ExecutionProvider for DirectML {
|
|||||||
PerformancePreference::MinimumPower => ort_sys::OrtDmlPerformancePreference::MinimumPower
|
PerformancePreference::MinimumPower => ort_sys::OrtDmlPerformancePreference::MinimumPower
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
unsafe { crate::error::status_to_result((api.SessionOptionsAppendExecutionProvider_DML2)(session_builder.ptr_mut(), &device_options)) }?;
|
unsafe { Error::result_from_status((api.SessionOptionsAppendExecutionProvider_DML2)(session_builder.ptr_mut(), &device_options)) }?;
|
||||||
}
|
}
|
||||||
|
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ impl ExecutionProvider for NNAPI {
|
|||||||
if self.cpu_only {
|
if self.cpu_only {
|
||||||
flags |= 0x008;
|
flags |= 0x008;
|
||||||
}
|
}
|
||||||
return Ok(unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_builder.ptr_mut(), flags)) }?);
|
return Ok(unsafe { crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_builder.ptr_mut(), flags)) }?);
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(RegisterError::MissingFeature)
|
Err(RegisterError::MissingFeature)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ impl ExecutionProvider for RKNPU {
|
|||||||
use crate::AsPointer;
|
use crate::AsPointer;
|
||||||
|
|
||||||
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_RKNPU(options: *mut ort_sys::OrtSessionOptions) -> ort_sys::OrtStatusPtr);
|
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_RKNPU(options: *mut ort_sys::OrtSessionOptions) -> ort_sys::OrtStatusPtr);
|
||||||
return Ok(unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_RKNPU(session_builder.ptr_mut())) }?);
|
return Ok(unsafe { crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_RKNPU(session_builder.ptr_mut())) }?);
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(RegisterError::MissingFeature)
|
Err(RegisterError::MissingFeature)
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ impl ExecutionProvider for TVM {
|
|||||||
}
|
}
|
||||||
let options_string = alloc::ffi::CString::new(option_string.join(",")).expect("invalid option string");
|
let options_string = alloc::ffi::CString::new(option_string.join(",")).expect("invalid option string");
|
||||||
return Ok(unsafe {
|
return Ok(unsafe {
|
||||||
crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_Tvm(session_builder.ptr_mut(), options_string.as_ptr()))
|
crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_Tvm(session_builder.ptr_mut(), options_string.as_ptr()))
|
||||||
}?);
|
}?);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
203
src/error.rs
203
src/error.rs
@@ -3,11 +3,17 @@ use alloc::{
|
|||||||
format,
|
format,
|
||||||
string::{String, ToString}
|
string::{String, ToString}
|
||||||
};
|
};
|
||||||
use core::{convert::Infallible, error::Error as CoreError, ffi::c_char, fmt, ptr};
|
use core::{
|
||||||
|
convert::Infallible,
|
||||||
|
error::Error as CoreError,
|
||||||
|
ffi::c_char,
|
||||||
|
fmt,
|
||||||
|
ptr::{self, NonNull}
|
||||||
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
ortsys,
|
ortsys,
|
||||||
util::{char_p_to_string, with_cstr}
|
util::{char_p_to_string, cold, with_cstr}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Type alias for the `Result` type returned by `ort` functions.
|
/// Type alias for the `Result` type returned by `ort` functions.
|
||||||
@@ -27,61 +33,164 @@ impl<T> IntoStatus for Result<T, Error> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An error returned by any `ort` API.
|
struct ErrorInternal {
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Error {
|
|
||||||
code: ErrorCode,
|
code: ErrorCode,
|
||||||
msg: String
|
message: String,
|
||||||
|
cause: Option<Box<dyn CoreError + Send + Sync + 'static>>,
|
||||||
|
status_ptr: NonNull<ort_sys::OrtStatus>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error {
|
unsafe impl Send for ErrorInternal {}
|
||||||
/// Wrap a custom, user-provided error in an [`ort::Error`](Error)..
|
unsafe impl Sync for ErrorInternal {}
|
||||||
|
|
||||||
|
impl ErrorInternal {
|
||||||
|
#[cold]
|
||||||
|
pub(crate) unsafe fn from_ptr(ptr: NonNull<ort_sys::OrtStatus>) -> Self {
|
||||||
|
let code = ErrorCode::from(ortsys![unsafe GetErrorCode(ptr.as_ptr())]);
|
||||||
|
let raw: *const c_char = ortsys![unsafe GetErrorMessage(ptr.as_ptr())];
|
||||||
|
match char_p_to_string(raw) {
|
||||||
|
Ok(message) => ErrorInternal {
|
||||||
|
code,
|
||||||
|
message,
|
||||||
|
cause: None,
|
||||||
|
status_ptr: ptr
|
||||||
|
},
|
||||||
|
Err(err) => ErrorInternal {
|
||||||
|
code,
|
||||||
|
message: format!("(failed to convert UTF-8: {err})"),
|
||||||
|
cause: None,
|
||||||
|
status_ptr: ptr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ErrorInternal {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
ortsys![unsafe ReleaseStatus(self.status_ptr.as_ptr())];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error returned by any `ort` API.
|
||||||
|
pub struct Error<R = ()> {
|
||||||
|
recover: R,
|
||||||
|
inner: Box<ErrorInternal>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> fmt::Debug for Error<R> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("Error")
|
||||||
|
.field("code", &self.inner.code)
|
||||||
|
.field("message", &self.message())
|
||||||
|
.field("ptr", &self.inner.status_ptr.as_ptr())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error<()> {
|
||||||
|
/// Converts an [`ort_sys::OrtStatusPtr`] to a [`Result`].
|
||||||
|
///
|
||||||
|
/// This takes ownership of the status pointer.
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
/// `ptr` must be a valid `OrtStatusPtr` returned from an `ort-sys` API.
|
||||||
|
#[inline]
|
||||||
|
pub unsafe fn result_from_status(ptr: ort_sys::OrtStatusPtr) -> Result<(), Self> {
|
||||||
|
match NonNull::new(ptr.0) {
|
||||||
|
None => Ok(()),
|
||||||
|
Some(ptr) => {
|
||||||
|
cold();
|
||||||
|
|
||||||
|
Err(Self {
|
||||||
|
recover: (),
|
||||||
|
inner: Box::new(unsafe { ErrorInternal::from_ptr(ptr) })
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrap a custom, user-provided error in an [`ort::Error`](Error).
|
||||||
///
|
///
|
||||||
/// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort`
|
/// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort`
|
||||||
/// related operation fails.
|
/// related operation fails.
|
||||||
pub fn wrap<T: CoreError + Send + Sync + 'static>(err: T) -> Self {
|
pub fn wrap<T: CoreError + Send + Sync + 'static>(err: T) -> Self {
|
||||||
Error {
|
Self::new_internal(ErrorCode::GenericFailure, err.to_string(), Some(Box::new(err)))
|
||||||
code: ErrorCode::GenericFailure,
|
|
||||||
msg: err.to_string()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a custom [`Error`] with the given message.
|
/// Creates a custom [`Error`] with the given message.
|
||||||
pub fn new(msg: impl Into<String>) -> Self {
|
pub fn new(msg: impl Into<String>) -> Self {
|
||||||
Error {
|
Self::new_internal(ErrorCode::GenericFailure, msg, None)
|
||||||
code: ErrorCode::GenericFailure,
|
|
||||||
msg: msg.into()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a custom [`Error`] with the given [`ErrorCode`] and message.
|
/// Creates a custom [`Error`] with the given [`ErrorCode`] and message.
|
||||||
pub fn new_with_code(code: ErrorCode, msg: impl Into<String>) -> Self {
|
pub fn new_with_code(code: ErrorCode, msg: impl Into<String>) -> Self {
|
||||||
Error { code, msg: msg.into() }
|
Self::new_internal(code, msg, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn new_internal(code: ErrorCode, message: impl Into<String>, cause: Option<Box<dyn CoreError + Send + Sync + 'static>>) -> Self {
|
||||||
|
let message = message.into();
|
||||||
|
let ptr = with_cstr(message.as_bytes(), &|message| Ok(ortsys![unsafe CreateStatus(code.into(), message.as_ptr())])).expect("invalid error message");
|
||||||
|
Self {
|
||||||
|
recover: (),
|
||||||
|
inner: Box::new(ErrorInternal {
|
||||||
|
code,
|
||||||
|
message,
|
||||||
|
cause,
|
||||||
|
status_ptr: unsafe { NonNull::new_unchecked(ptr.0) }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn with_recover<R>(self, recover: R) -> Error<R> {
|
||||||
|
Error { recover, inner: self.inner }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> Error<R> {
|
||||||
pub fn code(&self) -> ErrorCode {
|
pub fn code(&self) -> ErrorCode {
|
||||||
self.code
|
self.inner.code
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn message(&self) -> &str {
|
pub fn message(&self) -> &str {
|
||||||
self.msg.as_str()
|
self.inner.message.as_str()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for Error {
|
impl<R: Sized> Error<R> {
|
||||||
|
/// Recovers from this error.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||||
|
/// # fn main() -> ort::Result<()> {
|
||||||
|
/// let session = Session::builder()?
|
||||||
|
/// .with_optimization_level(GraphOptimizationLevel::All)
|
||||||
|
/// // Optimization isn't enabled in minimal builds of ONNX Runtime, so throws an error. We can just ignore it.
|
||||||
|
/// .unwrap_or_else(|e| e.recover())
|
||||||
|
/// .commit_from_file("tests/data/upsample.onnx")?;
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
#[inline]
|
||||||
|
pub fn recover(self) -> R {
|
||||||
|
self.recover
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> fmt::Display for Error<R> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.write_str(&self.msg)
|
f.write_str(&self.inner.message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CoreError for Error {}
|
impl<R> CoreError for Error<R> {
|
||||||
|
fn source(&self) -> Option<&(dyn CoreError + 'static)> {
|
||||||
|
self.inner.cause.as_ref().map(|x| &**x as &dyn CoreError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<Box<dyn CoreError + Send + Sync + 'static>> for Error {
|
impl From<Box<dyn CoreError + Send + Sync + 'static>> for Error {
|
||||||
fn from(err: Box<dyn CoreError + Send + Sync + 'static>) -> Self {
|
fn from(err: Box<dyn CoreError + Send + Sync + 'static>) -> Self {
|
||||||
Error {
|
Error::new_internal(ErrorCode::GenericFailure, err.to_string(), Some(err))
|
||||||
code: ErrorCode::GenericFailure,
|
|
||||||
msg: err.to_string()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,6 +230,12 @@ impl From<alloc::ffi::IntoStringError> for Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<Error<crate::session::builder::SessionBuilder>> for Error<()> {
|
||||||
|
fn from(err: Error<crate::session::builder::SessionBuilder>) -> Self {
|
||||||
|
Self { recover: (), inner: err.inner }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
pub enum ErrorCode {
|
pub enum ErrorCode {
|
||||||
@@ -177,39 +292,3 @@ impl From<ErrorCode> for ort_sys::OrtErrorCode {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts an [`ort_sys::OrtStatusPtr`] to a [`Result`].
|
|
||||||
///
|
|
||||||
/// **Note that this frees `status`!**
|
|
||||||
///
|
|
||||||
/// # Safety
|
|
||||||
/// The value contained in `status` must be a valid [`ort_sys::OrtStatus`] pointer, or a null pointer (in which case the
|
|
||||||
/// result will be `Ok`).
|
|
||||||
#[inline]
|
|
||||||
pub unsafe fn status_to_result(status: ort_sys::OrtStatusPtr) -> Result<(), Error> {
|
|
||||||
let status = status.0;
|
|
||||||
if status.is_null() {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
#[cold]
|
|
||||||
fn status_to_error(status: *mut ort_sys::OrtStatus) -> Error {
|
|
||||||
let code = ErrorCode::from(ortsys![unsafe GetErrorCode(status)]);
|
|
||||||
let raw: *const c_char = ortsys![unsafe GetErrorMessage(status)];
|
|
||||||
match char_p_to_string(raw) {
|
|
||||||
Ok(msg) => {
|
|
||||||
ortsys![unsafe ReleaseStatus(status)];
|
|
||||||
Error { code, msg }
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
ortsys![unsafe ReleaseStatus(status)];
|
|
||||||
Error {
|
|
||||||
code,
|
|
||||||
msg: format!("(failed to convert UTF-8: {err})")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Err(status_to_error(status))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -288,19 +288,19 @@ macro_rules! ortsys {
|
|||||||
unsafe { ($crate::api().$method)($($n),*) }
|
unsafe { ($crate::api().$method)($($n),*) }
|
||||||
};
|
};
|
||||||
(@ort: unsafe $method:ident($($n:expr),*) as Result) => {
|
(@ort: unsafe $method:ident($($n:expr),*) as Result) => {
|
||||||
unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }
|
unsafe { $crate::error::Error::result_from_status(($crate::api().$method)($($n),+)) }
|
||||||
};
|
};
|
||||||
(@$api:ident: unsafe $method:ident($($n:expr),*)) => {
|
(@$api:ident: unsafe $method:ident($($n:expr),*)) => {
|
||||||
unsafe { ($crate::api::$api().unwrap().$method)($($n),+) }
|
unsafe { ($crate::api::$api().unwrap().$method)($($n),+) }
|
||||||
};
|
};
|
||||||
(@$api:ident: unsafe $method:ident($($n:expr),*)?) => {
|
(@$api:ident: unsafe $method:ident($($n:expr),*)?) => {
|
||||||
$crate::api::$api().and_then(|api| unsafe { $crate::error::status_to_result((api.$method)($($n),+)) })?
|
$crate::api::$api().and_then(|api| unsafe { $crate::error::Error::result_from_status((api.$method)($($n),+)) })?
|
||||||
};
|
};
|
||||||
(@$api:ident: unsafe $method:ident($($n:expr),*)?; nonNull($($check:ident),+)$(;)?) => {
|
(@$api:ident: unsafe $method:ident($($n:expr),*)?; nonNull($($check:ident),+)$(;)?) => {
|
||||||
$crate::api::$api().and_then(|api| unsafe { $crate::error::status_to_result((api.$method)($($n),+)) })?;
|
$crate::api::$api().and_then(|api| unsafe { $crate::error::Error::result_from_status((api.$method)($($n),+)) })?;
|
||||||
ortsys![@nonNull?; $($check),+];
|
ortsys![@nonNull?; $($check),+];
|
||||||
};
|
};
|
||||||
(@$api:ident: unsafe $method:ident($($n:expr),*) as Result) => {
|
(@$api:ident: unsafe $method:ident($($n:expr),*) as Result) => {
|
||||||
$crate::api::$api().and_then(|api| unsafe { $crate::error::status_to_result((api.$method)($($n),+)) })
|
$crate::api::$api().and_then(|api| unsafe { $crate::error::Error::result_from_status((api.$method)($($n),+)) })
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ use std::sync::Mutex;
|
|||||||
use smallvec::SmallVec;
|
use smallvec::SmallVec;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
Error,
|
||||||
error::Result,
|
error::Result,
|
||||||
session::{SessionOutputs, SharedSessionInner, UntypedRunOptions},
|
session::{SessionOutputs, SharedSessionInner, UntypedRunOptions},
|
||||||
util::{STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS},
|
util::{STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS},
|
||||||
@@ -121,7 +122,7 @@ pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mu
|
|||||||
|
|
||||||
crate::logging::drop!(AsyncInferenceContext, user_data);
|
crate::logging::drop!(AsyncInferenceContext, user_data);
|
||||||
|
|
||||||
if let Err(e) = unsafe { crate::error::status_to_result(status) } {
|
if let Err(e) = unsafe { Error::result_from_status(status) } {
|
||||||
ctx.inner.emplace_value(Err(e));
|
ctx.inner.emplace_value(Err(e));
|
||||||
ctx.inner.wake();
|
ctx.inner.wake();
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -35,90 +35,10 @@ impl SessionBuilder {
|
|||||||
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
|
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
|
||||||
#[cfg_attr(docsrs, doc(cfg(all(feature = "fetch-models", feature = "std"))))]
|
#[cfg_attr(docsrs, doc(cfg(all(feature = "fetch-models", feature = "std"))))]
|
||||||
pub fn commit_from_url(&mut self, model_url: impl AsRef<str>) -> Result<Session> {
|
pub fn commit_from_url(&mut self, model_url: impl AsRef<str>) -> Result<Session> {
|
||||||
let downloaded_path = SessionBuilder::download(model_url.as_ref())?;
|
let downloaded_path = download_model(model_url.as_ref())?;
|
||||||
self.commit_from_file(downloaded_path)
|
self.commit_from_file(downloaded_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
|
|
||||||
fn download(url: &str) -> Result<PathBuf> {
|
|
||||||
use ureq::{
|
|
||||||
config::Config,
|
|
||||||
tls::{RootCerts, TlsConfig, TlsProvider}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut download_dir = ort_sys::internal::dirs::cache_dir()
|
|
||||||
.expect("could not determine cache directory")
|
|
||||||
.join("models");
|
|
||||||
if std::fs::create_dir_all(&download_dir).is_err() {
|
|
||||||
download_dir = std::env::current_dir().expect("Failed to obtain current working directory");
|
|
||||||
}
|
|
||||||
|
|
||||||
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
|
|
||||||
let _ = write!(&mut s, "{:02x}", b);
|
|
||||||
s
|
|
||||||
});
|
|
||||||
let model_filepath = download_dir.join(&model_filename);
|
|
||||||
if model_filepath.exists() {
|
|
||||||
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
|
|
||||||
Ok(model_filepath)
|
|
||||||
} else {
|
|
||||||
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");
|
|
||||||
|
|
||||||
let agent = Config::builder()
|
|
||||||
.tls_config(
|
|
||||||
TlsConfig::builder()
|
|
||||||
.root_certs(RootCerts::WebPki)
|
|
||||||
.provider(if cfg!(any(feature = "tls-rustls", feature = "tls-rustls-no-provider")) {
|
|
||||||
TlsProvider::Rustls
|
|
||||||
} else if cfg!(any(feature = "tls-native", feature = "tls-native-vendored")) {
|
|
||||||
TlsProvider::NativeTls
|
|
||||||
} else {
|
|
||||||
return Err(Error::new(
|
|
||||||
"No TLS provider configured. When using `fetch-models` with HTTPS URLs, a `tls-*` feature must be enabled."
|
|
||||||
));
|
|
||||||
})
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
.new_agent();
|
|
||||||
|
|
||||||
let resp = agent.get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?;
|
|
||||||
|
|
||||||
let len = resp
|
|
||||||
.headers()
|
|
||||||
.get("Content-Length")
|
|
||||||
.and_then(|h| h.to_str().ok())
|
|
||||||
.and_then(|s| s.parse::<usize>().ok())
|
|
||||||
.expect("Missing Content-Length header");
|
|
||||||
crate::info!(len, "Downloading {} bytes", len);
|
|
||||||
|
|
||||||
let mut reader = resp.into_body().into_with_config().limit(u64::MAX).reader();
|
|
||||||
let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier()));
|
|
||||||
|
|
||||||
let f = std::fs::File::create(&temp_filepath).expect("Failed to create model file");
|
|
||||||
let mut writer = std::io::BufWriter::new(f);
|
|
||||||
|
|
||||||
let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?;
|
|
||||||
if bytes_io_count != len as u64 {
|
|
||||||
return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}")));
|
|
||||||
}
|
|
||||||
|
|
||||||
drop(writer);
|
|
||||||
|
|
||||||
match std::fs::rename(&temp_filepath, &model_filepath) {
|
|
||||||
Ok(()) => Ok(model_filepath),
|
|
||||||
Err(e) => {
|
|
||||||
if model_filepath.exists() {
|
|
||||||
let _ = std::fs::remove_file(temp_filepath);
|
|
||||||
Ok(model_filepath)
|
|
||||||
} else {
|
|
||||||
Err(Error::new(format!("Failed to download model: {e}")))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Loads an ONNX model from a file and builds the session.
|
/// Loads an ONNX model from a file and builds the session.
|
||||||
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
|
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||||
@@ -331,3 +251,81 @@ impl SessionBuilder {
|
|||||||
EditableSession::new(session_ptr, self)
|
EditableSession::new(session_ptr, self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
|
||||||
|
fn download_model(url: &str) -> Result<PathBuf> {
|
||||||
|
use ureq::{
|
||||||
|
config::Config,
|
||||||
|
tls::{RootCerts, TlsConfig, TlsProvider}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut download_dir = ort_sys::internal::dirs::cache_dir()
|
||||||
|
.expect("could not determine cache directory")
|
||||||
|
.join("models");
|
||||||
|
if std::fs::create_dir_all(&download_dir).is_err() {
|
||||||
|
download_dir = std::env::current_dir().expect("Failed to obtain current working directory");
|
||||||
|
}
|
||||||
|
|
||||||
|
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
|
||||||
|
let _ = write!(&mut s, "{:02x}", b);
|
||||||
|
s
|
||||||
|
});
|
||||||
|
let model_filepath = download_dir.join(&model_filename);
|
||||||
|
if model_filepath.exists() {
|
||||||
|
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
|
||||||
|
Ok(model_filepath)
|
||||||
|
} else {
|
||||||
|
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");
|
||||||
|
|
||||||
|
let agent = Config::builder()
|
||||||
|
.tls_config(
|
||||||
|
TlsConfig::builder()
|
||||||
|
.root_certs(RootCerts::WebPki)
|
||||||
|
.provider(if cfg!(any(feature = "tls-rustls", feature = "tls-rustls-no-provider")) {
|
||||||
|
TlsProvider::Rustls
|
||||||
|
} else if cfg!(any(feature = "tls-native", feature = "tls-native-vendored")) {
|
||||||
|
TlsProvider::NativeTls
|
||||||
|
} else {
|
||||||
|
return Err(Error::new("No TLS provider configured. When using `fetch-models` with HTTPS URLs, a `tls-*` feature must be enabled."));
|
||||||
|
})
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
.new_agent();
|
||||||
|
|
||||||
|
let resp = agent.get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?;
|
||||||
|
|
||||||
|
let len = resp
|
||||||
|
.headers()
|
||||||
|
.get("Content-Length")
|
||||||
|
.and_then(|h| h.to_str().ok())
|
||||||
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
|
.expect("Missing Content-Length header");
|
||||||
|
crate::info!(len, "Downloading {} bytes", len);
|
||||||
|
|
||||||
|
let mut reader = resp.into_body().into_with_config().limit(u64::MAX).reader();
|
||||||
|
let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier()));
|
||||||
|
|
||||||
|
let f = std::fs::File::create(&temp_filepath).expect("Failed to create model file");
|
||||||
|
let mut writer = std::io::BufWriter::new(f);
|
||||||
|
|
||||||
|
let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?;
|
||||||
|
if bytes_io_count != len as u64 {
|
||||||
|
return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
drop(writer);
|
||||||
|
|
||||||
|
match std::fs::rename(&temp_filepath, &model_filepath) {
|
||||||
|
Ok(()) => Ok(model_filepath),
|
||||||
|
Err(e) => {
|
||||||
|
if model_filepath.exists() {
|
||||||
|
let _ = std::fs::remove_file(temp_filepath);
|
||||||
|
Ok(model_filepath)
|
||||||
|
} else {
|
||||||
|
Err(Error::new(format!("Failed to download model: {e}")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
use super::SessionBuilder;
|
use super::{BuilderResult, SessionBuilder};
|
||||||
use crate::Result;
|
|
||||||
|
|
||||||
// https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
|
// https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
|
||||||
|
|
||||||
@@ -7,17 +6,21 @@ impl SessionBuilder {
|
|||||||
/// Enable/disable the usage of prepacking.
|
/// Enable/disable the usage of prepacking.
|
||||||
///
|
///
|
||||||
/// This option is **enabled** by default.
|
/// This option is **enabled** by default.
|
||||||
pub fn with_prepacking(mut self, enable: bool) -> Result<Self> {
|
pub fn with_prepacking(mut self, enable: bool) -> BuilderResult {
|
||||||
self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" })?;
|
match self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" }) {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Use allocators from the registered environment.
|
/// Use allocators from the registered environment.
|
||||||
///
|
///
|
||||||
/// This option is **disabled** by default.
|
/// This option is **disabled** by default.
|
||||||
pub fn with_env_allocators(mut self) -> Result<Self> {
|
pub fn with_env_allocators(mut self) -> BuilderResult {
|
||||||
self.add_config_entry("session.use_env_allocators", "1")?;
|
match self.add_config_entry("session.use_env_allocators", "1") {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Disables subnormal floats by enabling the denormals-are-zero and flush-to-zero flags for all threads in the
|
/// Disables subnormal floats by enabling the denormals-are-zero and flush-to-zero flags for all threads in the
|
||||||
@@ -28,94 +31,137 @@ impl SessionBuilder {
|
|||||||
/// giving faster & more consistent performance, but lower accuracy (in cases where subnormals are involved).
|
/// giving faster & more consistent performance, but lower accuracy (in cases where subnormals are involved).
|
||||||
///
|
///
|
||||||
/// This option is **disabled** by default, as it may hurt model accuracy.
|
/// This option is **disabled** by default, as it may hurt model accuracy.
|
||||||
pub fn with_flush_to_zero(mut self) -> Result<Self> {
|
pub fn with_flush_to_zero(mut self) -> BuilderResult {
|
||||||
self.add_config_entry("session.set_denormal_as_zero", "1")?;
|
match self.add_config_entry("session.set_denormal_as_zero", "1") {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable/disable fusion for quantized models in QDQ (`QuantizeLinear`/`DequantizeLinear`) format.
|
/// Enable/disable fusion for quantized models in QDQ (`QuantizeLinear`/`DequantizeLinear`) format.
|
||||||
///
|
///
|
||||||
/// This option is **enabled** by default for all EPs except DirectML.
|
/// This option is **enabled** by default for all EPs except DirectML.
|
||||||
pub fn with_quant_qdq(mut self, enable: bool) -> Result<Self> {
|
pub fn with_quant_qdq(mut self, enable: bool) -> BuilderResult {
|
||||||
self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" })?;
|
match self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" }) {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable/disable the optimization step removing double QDQ nodes.
|
/// Enable/disable the optimization step removing double QDQ nodes.
|
||||||
///
|
///
|
||||||
/// This option is **enabled** by default.
|
/// This option is **enabled** by default.
|
||||||
pub fn with_double_qdq_remover(mut self, enable: bool) -> Result<Self> {
|
pub fn with_double_qdq_remover(mut self, enable: bool) -> BuilderResult {
|
||||||
self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" })?;
|
match self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" }) {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable the removal of Q/DQ node pairs once all QDQ handling has been completed.
|
/// Enable the removal of Q/DQ node pairs once all QDQ handling has been completed.
|
||||||
///
|
///
|
||||||
/// This option is **disabled** by default.
|
/// This option is **disabled** by default.
|
||||||
pub fn with_qdq_cleanup(mut self) -> Result<Self> {
|
pub fn with_qdq_cleanup(mut self) -> BuilderResult {
|
||||||
self.add_config_entry("session.enable_quant_qdq_cleanup", "1")?;
|
match self.add_config_entry("session.enable_quant_qdq_cleanup", "1") {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable fast tanh-based GELU approximation (like PyTorch's `nn.GELU(approximate='tanh')`).
|
/// Enable fast tanh-based GELU approximation (like PyTorch's `nn.GELU(approximate='tanh')`).
|
||||||
///
|
///
|
||||||
/// This option is **disabled** by default, as it may impact results.
|
/// This option is **disabled** by default, as it may impact results.
|
||||||
pub fn with_approximate_gelu(mut self) -> Result<Self> {
|
pub fn with_approximate_gelu(mut self) -> BuilderResult {
|
||||||
self.add_config_entry("optimization.enable_gelu_approximation", "1")?;
|
match self.add_config_entry("optimization.enable_gelu_approximation", "1") {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable the `Cast` chain elimination optimization.
|
/// Enable the `Cast` chain elimination optimization.
|
||||||
///
|
///
|
||||||
/// This option is **disabled** by default, as it may impact results.
|
/// This option is **disabled** by default, as it may impact results.
|
||||||
pub fn with_cast_chain_elimination(mut self) -> Result<Self> {
|
pub fn with_cast_chain_elimination(mut self) -> BuilderResult {
|
||||||
self.add_config_entry("optimization.enable_cast_chain_elimination", "1")?;
|
match self.add_config_entry("optimization.enable_cast_chain_elimination", "1") {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable/disable ahead-of-time function inlining.
|
/// Enable/disable ahead-of-time function inlining.
|
||||||
///
|
///
|
||||||
/// This option is **enabled** by default.
|
/// This option is **enabled** by default.
|
||||||
pub fn with_aot_inlining(mut self, enable: bool) -> Result<Self> {
|
pub fn with_aot_inlining(mut self, enable: bool) -> BuilderResult {
|
||||||
self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" })?;
|
match self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" }) {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Accepts a comma-separated list of optimizers to disable.
|
/// Accepts a comma-separated list of optimizers to disable.
|
||||||
pub fn with_disabled_optimizers(mut self, optimizers: impl AsRef<str>) -> Result<Self> {
|
pub fn with_disabled_optimizers(mut self, optimizers: impl AsRef<str>) -> BuilderResult {
|
||||||
self.add_config_entry("optimization.disable_specified_optimizers", optimizers)?;
|
match self.add_config_entry("optimization.disable_specified_optimizers", optimizers) {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable using the device allocator for allocating initialized tensor memory, potentially bypassing arena
|
/// Enable using the device allocator for allocating initialized tensor memory, potentially bypassing arena
|
||||||
/// allocators.
|
/// allocators.
|
||||||
///
|
///
|
||||||
/// This option is **disabled** by default.
|
/// This option is **disabled** by default.
|
||||||
pub fn with_device_allocated_initializers(mut self) -> Result<Self> {
|
pub fn with_device_allocated_initializers(mut self) -> BuilderResult {
|
||||||
self.add_config_entry("session.use_device_allocator_for_initializers", "1")?;
|
match self.add_config_entry("session.use_device_allocator_for_initializers", "1") {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable/disable allowing the inter-op threads to spin for a short period before blocking.
|
/// Enable/disable allowing the inter-op threads to spin for a short period before blocking.
|
||||||
///
|
///
|
||||||
/// This option is **enabled** by default.
|
/// This option is **enabled** by default.
|
||||||
pub fn with_inter_op_spinning(mut self, enable: bool) -> Result<Self> {
|
pub fn with_inter_op_spinning(mut self, enable: bool) -> BuilderResult {
|
||||||
self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" })?;
|
match self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" }) {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable/disable allowing the intra-op threads to spin for a short period before blocking.
|
/// Enable/disable allowing the intra-op threads to spin for a short period before blocking.
|
||||||
///
|
///
|
||||||
/// This option is **enabled** by default.
|
/// This option is **enabled** by default.
|
||||||
pub fn with_intra_op_spinning(mut self, enable: bool) -> Result<Self> {
|
pub fn with_intra_op_spinning(mut self, enable: bool) -> BuilderResult {
|
||||||
self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" })?;
|
match self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" }) {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Disables falling back to the CPU for operations not supported by any other EP.
|
/// Disables falling back to the CPU for operations not supported by any other EP.
|
||||||
/// Models with graphs that cannot be placed entirely on the EP(s) will fail to commit.
|
/// Models with graphs that cannot be placed entirely on the EP(s) will fail to commit.
|
||||||
pub fn with_disable_cpu_fallback(mut self) -> Result<Self> {
|
pub fn with_disable_cpu_fallback(mut self) -> BuilderResult {
|
||||||
self.add_config_entry("session.disable_cpu_ep_fallback", "1")?;
|
match self.add_config_entry("session.disable_cpu_ep_fallback", "1") {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Uses slower U8U8 matrix multiplication in place of U8S8 matrix multiplication that could potentially overflow on
|
||||||
|
/// x86-64 platforms without the VNNI extension.
|
||||||
|
///
|
||||||
|
/// This should only be enabled if you encounter overflow issues with quantized models.
|
||||||
|
pub fn with_precise_qmm(mut self) -> BuilderResult {
|
||||||
|
match self.add_config_entry("session.x64quantprecision", "1") {
|
||||||
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Enables dynamic thread block sizing with the given base block size.
|
||||||
|
pub fn with_dynamic_block_base(mut self, size: u32) -> BuilderResult {
|
||||||
|
match self.add_config_entry("session.dynamic_block_base", size.to_string()) {
|
||||||
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,14 +7,13 @@ use core::{
|
|||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use super::SessionBuilder;
|
use super::{BuilderResult, SessionBuilder};
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
use crate::util::path_to_os_char;
|
use crate::util::path_to_os_char;
|
||||||
use crate::{
|
use crate::{
|
||||||
AsPointer, Error, ErrorCode,
|
AsPointer, Error, ErrorCode,
|
||||||
environment::{self, ThreadManager},
|
environment::{self, ThreadManager},
|
||||||
ep::{ExecutionProviderDispatch, apply_execution_providers},
|
ep::{ExecutionProviderDispatch, apply_execution_providers},
|
||||||
error::Result,
|
|
||||||
logging::{LogLevel, LoggerFunction},
|
logging::{LogLevel, LoggerFunction},
|
||||||
memory::MemoryInfo,
|
memory::MemoryInfo,
|
||||||
operator::OperatorDomain,
|
operator::OperatorDomain,
|
||||||
@@ -36,9 +35,11 @@ impl SessionBuilder {
|
|||||||
/// - **Indiscriminate use of [`SessionBuilder::with_execution_providers`] in a library** (e.g. always enabling
|
/// - **Indiscriminate use of [`SessionBuilder::with_execution_providers`] in a library** (e.g. always enabling
|
||||||
/// CUDA) **is discouraged** unless you allow the user to configure the execution providers by providing a `Vec`
|
/// CUDA) **is discouraged** unless you allow the user to configure the execution providers by providing a `Vec`
|
||||||
/// of [`ExecutionProviderDispatch`]es.
|
/// of [`ExecutionProviderDispatch`]es.
|
||||||
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Result<Self> {
|
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> BuilderResult {
|
||||||
apply_execution_providers(&mut self, execution_providers.as_ref(), "session options")?;
|
match apply_execution_providers(&mut self, execution_providers.as_ref(), "session options") {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configure the session to use a number of threads to parallelize the execution within nodes. If ONNX Runtime was
|
/// Configure the session to use a number of threads to parallelize the execution within nodes. If ONNX Runtime was
|
||||||
@@ -48,9 +49,11 @@ impl SessionBuilder {
|
|||||||
///
|
///
|
||||||
/// For configuring the number of threads used when the session execution mode is set to `Parallel`, see
|
/// For configuring the number of threads used when the session execution mode is set to `Parallel`, see
|
||||||
/// [`SessionBuilder::with_inter_threads()`].
|
/// [`SessionBuilder::with_inter_threads()`].
|
||||||
pub fn with_intra_threads(mut self, num_threads: usize) -> Result<Self> {
|
pub fn with_intra_threads(mut self, num_threads: usize) -> BuilderResult {
|
||||||
ortsys![unsafe SetIntraOpNumThreads(self.ptr_mut(), num_threads as _)?];
|
match ortsys![@ort: unsafe SetIntraOpNumThreads(self.ptr_mut(), num_threads as _) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configure the session to use a number of threads to parallelize the execution of the graph. If nodes can be run
|
/// Configure the session to use a number of threads to parallelize the execution of the graph. If nodes can be run
|
||||||
@@ -60,9 +63,11 @@ impl SessionBuilder {
|
|||||||
///
|
///
|
||||||
/// For configuring the number of threads used to parallelize the execution within nodes, see
|
/// For configuring the number of threads used to parallelize the execution within nodes, see
|
||||||
/// [`SessionBuilder::with_intra_threads()`].
|
/// [`SessionBuilder::with_intra_threads()`].
|
||||||
pub fn with_inter_threads(mut self, num_threads: usize) -> Result<Self> {
|
pub fn with_inter_threads(mut self, num_threads: usize) -> BuilderResult {
|
||||||
ortsys![unsafe SetInterOpNumThreads(self.ptr_mut(), num_threads as _)?];
|
match ortsys![@ort: unsafe SetInterOpNumThreads(self.ptr_mut(), num_threads as _) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable/disable the parallel execution mode for this session. By default, this is disabled.
|
/// Enable/disable the parallel execution mode for this session. By default, this is disabled.
|
||||||
@@ -70,21 +75,25 @@ impl SessionBuilder {
|
|||||||
/// Parallel execution can improve performance for models with many branches, at the cost of higher memory usage.
|
/// Parallel execution can improve performance for models with many branches, at the cost of higher memory usage.
|
||||||
/// You can configure the number of threads used to parallelize the execution of the graph via
|
/// You can configure the number of threads used to parallelize the execution of the graph via
|
||||||
/// [`SessionBuilder::with_inter_threads()`].
|
/// [`SessionBuilder::with_inter_threads()`].
|
||||||
pub fn with_parallel_execution(mut self, parallel_execution: bool) -> Result<Self> {
|
pub fn with_parallel_execution(mut self, parallel_execution: bool) -> BuilderResult {
|
||||||
let execution_mode = if parallel_execution {
|
let execution_mode = if parallel_execution {
|
||||||
ort_sys::ExecutionMode::ORT_PARALLEL
|
ort_sys::ExecutionMode::ORT_PARALLEL
|
||||||
} else {
|
} else {
|
||||||
ort_sys::ExecutionMode::ORT_SEQUENTIAL
|
ort_sys::ExecutionMode::ORT_SEQUENTIAL
|
||||||
};
|
};
|
||||||
ortsys![unsafe SetSessionExecutionMode(self.ptr_mut(), execution_mode)?];
|
match ortsys![@ort: unsafe SetSessionExecutionMode(self.ptr_mut(), execution_mode) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the session's optimization level. See [`GraphOptimizationLevel`] for more information on the different
|
/// Set the session's optimization level. See [`GraphOptimizationLevel`] for more information on the different
|
||||||
/// optimization levels.
|
/// optimization levels.
|
||||||
pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> Result<Self> {
|
pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> BuilderResult {
|
||||||
ortsys![unsafe SetSessionGraphOptimizationLevel(self.ptr_mut(), opt_level.into())?];
|
match ortsys![@ort: unsafe SetSessionGraphOptimizationLevel(self.ptr_mut(), opt_level.into()) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// After performing optimization (configurable with [`SessionBuilder::with_optimization_level`]), serializes the
|
/// After performing optimization (configurable with [`SessionBuilder::with_optimization_level`]), serializes the
|
||||||
@@ -93,10 +102,12 @@ impl SessionBuilder {
|
|||||||
/// Note that the file will only be created after the model is committed.
|
/// Note that the file will only be created after the model is committed.
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||||
pub fn with_optimized_model_path<S: AsRef<Path>>(mut self, path: S) -> Result<Self> {
|
pub fn with_optimized_model_path<S: AsRef<Path>>(mut self, path: S) -> BuilderResult {
|
||||||
let path = crate::util::path_to_os_char(path);
|
let path = crate::util::path_to_os_char(path);
|
||||||
ortsys![unsafe SetOptimizedModelFilePath(self.ptr_mut(), path.as_ptr())?];
|
match ortsys![@ort: unsafe SetOptimizedModelFilePath(self.ptr_mut(), path.as_ptr()) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enables profiling. Profile information will be written to `profiling_file` after profiling completes.
|
/// Enables profiling. Profile information will be written to `profiling_file` after profiling completes.
|
||||||
@@ -105,29 +116,37 @@ impl SessionBuilder {
|
|||||||
/// [`Session::end_profiling`]: crate::session::Session::end_profiling
|
/// [`Session::end_profiling`]: crate::session::Session::end_profiling
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||||
pub fn with_profiling<S: AsRef<Path>>(mut self, profiling_file: S) -> Result<Self> {
|
pub fn with_profiling<S: AsRef<Path>>(mut self, profiling_file: S) -> BuilderResult {
|
||||||
let profiling_file = crate::util::path_to_os_char(profiling_file);
|
let profiling_file = crate::util::path_to_os_char(profiling_file);
|
||||||
ortsys![unsafe EnableProfiling(self.ptr_mut(), profiling_file.as_ptr())?];
|
match ortsys![@ort: unsafe EnableProfiling(self.ptr_mut(), profiling_file.as_ptr()) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enables/disables memory pattern optimization. Disable it if the input size varies, e.g. with dynamic batch sizes.
|
/// Enables/disables memory pattern optimization. Disable it if the input size varies, e.g. with dynamic batch sizes.
|
||||||
pub fn with_memory_pattern(mut self, enable: bool) -> Result<Self> {
|
pub fn with_memory_pattern(mut self, enable: bool) -> BuilderResult {
|
||||||
if enable {
|
let result = if enable {
|
||||||
ortsys![unsafe EnableMemPattern(self.ptr_mut())?];
|
ortsys![@ort: unsafe EnableMemPattern(self.ptr_mut()) as Result]
|
||||||
} else {
|
} else {
|
||||||
ortsys![unsafe DisableMemPattern(self.ptr_mut())?];
|
ortsys![@ort: unsafe DisableMemPattern(self.ptr_mut()) as Result]
|
||||||
|
};
|
||||||
|
match result {
|
||||||
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
}
|
}
|
||||||
Ok(self)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configure this session to use a custom allocator, rather than the global default. This allocator is responsible
|
/// Configure this session to use a custom allocator, rather than the global default. This allocator is responsible
|
||||||
/// for allocating the *metadata* associated with values -- not the contents of the values themselves; that's
|
/// for allocating the *metadata* associated with values -- not the contents of the values themselves; that's
|
||||||
/// handled by the active execution providers. As such, only CPU-accessible allocators are allowed.
|
/// handled by the active execution providers. As such, only CPU-accessible allocators are allowed.
|
||||||
pub fn with_allocator(mut self, info: MemoryInfo) -> Result<Self> {
|
pub fn with_allocator(mut self, info: MemoryInfo) -> BuilderResult {
|
||||||
if !info.is_cpu_accessible() {
|
if !info.is_cpu_accessible() {
|
||||||
return Err(Error::new_with_code(ErrorCode::InvalidArgument, "SessionBuilder::with_allocator may only use a CPU-accessible allocator"));
|
return Err(
|
||||||
|
Error::new_with_code(ErrorCode::InvalidArgument, "SessionBuilder::with_allocator may only use a CPU-accessible allocator").with_recover(self)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.memory_info = Some(Arc::new(info));
|
self.memory_info = Some(Arc::new(info));
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
@@ -135,117 +154,145 @@ impl SessionBuilder {
|
|||||||
/// Registers a custom operator library at the given library path.
|
/// Registers a custom operator library at the given library path.
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||||
pub fn with_operator_library(mut self, lib_path: impl AsRef<Path>) -> Result<Self> {
|
pub fn with_operator_library(mut self, lib_path: impl AsRef<Path>) -> BuilderResult {
|
||||||
let path_cstr = path_to_os_char(lib_path);
|
let path_cstr = path_to_os_char(lib_path);
|
||||||
ortsys![unsafe RegisterCustomOpsLibrary_V2(self.ptr_mut(), path_cstr.as_ptr())?];
|
match ortsys![@ort: unsafe RegisterCustomOpsLibrary_V2(self.ptr_mut(), path_cstr.as_ptr()) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enables [`onnxruntime-extensions`](https://github.com/microsoft/onnxruntime-extensions) custom operators.
|
/// Enables [`onnxruntime-extensions`](https://github.com/microsoft/onnxruntime-extensions) custom operators.
|
||||||
pub fn with_extensions(mut self) -> Result<Self> {
|
pub fn with_extensions(mut self) -> BuilderResult {
|
||||||
ortsys![unsafe EnableOrtCustomOps(self.ptr_mut())?];
|
match ortsys![@ort: unsafe EnableOrtCustomOps(self.ptr_mut()) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_operators(mut self, domain: impl Into<Arc<OperatorDomain>>) -> Result<Self> {
|
pub fn with_operators(mut self, domain: impl Into<Arc<OperatorDomain>>) -> BuilderResult {
|
||||||
let domain: Arc<OperatorDomain> = domain.into();
|
let domain: Arc<OperatorDomain> = domain.into();
|
||||||
ortsys![unsafe AddCustomOpDomain(self.ptr_mut(), domain.ptr().cast_mut())?];
|
match ortsys![@ort: unsafe AddCustomOpDomain(self.ptr_mut(), domain.ptr().cast_mut()) as Result] {
|
||||||
self.operator_domains.push(domain);
|
Ok(()) => {
|
||||||
Ok(self)
|
self.operator_domains.push(domain);
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enables/disables deterministic computation.
|
/// Enables/disables deterministic computation.
|
||||||
///
|
///
|
||||||
/// The default (non-deterministic) kernels will typically use faster algorithms that may introduce slight variance.
|
/// The default (non-deterministic) kernels will typically use faster algorithms that may introduce slight variance.
|
||||||
/// Enabling deterministic compute will output reproducible results, but may come at a performance penalty.
|
/// Enabling deterministic compute will output reproducible results, but may come at a performance penalty.
|
||||||
pub fn with_deterministic_compute(mut self, enable: bool) -> Result<Self> {
|
pub fn with_deterministic_compute(mut self, enable: bool) -> BuilderResult {
|
||||||
ortsys![unsafe SetDeterministicCompute(self.ptr_mut(), enable)?];
|
match ortsys![@ort: unsafe SetDeterministicCompute(self.ptr_mut(), enable) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> Result<Self> {
|
pub fn with_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> BuilderResult {
|
||||||
let ptr = self.ptr_mut();
|
let ptr = self.ptr_mut();
|
||||||
let value: Arc<DynValue> = value.into();
|
let value: Arc<DynValue> = value.into();
|
||||||
with_cstr(name.as_ref().as_bytes(), &|name| {
|
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddInitializer(ptr, name.as_ptr(), value.ptr()) as Result]) {
|
||||||
ortsys![unsafe AddInitializer(ptr, name.as_ptr(), value.ptr())?];
|
Ok(()) => {
|
||||||
Ok(())
|
self.initializers.push(value);
|
||||||
})?;
|
Ok(self)
|
||||||
self.initializers.push(value);
|
}
|
||||||
Ok(self)
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_external_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> Result<Self> {
|
pub fn with_external_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> BuilderResult {
|
||||||
let ptr = self.ptr_mut();
|
let ptr = self.ptr_mut();
|
||||||
let value: Arc<DynValue> = value.into();
|
let value: Arc<DynValue> = value.into();
|
||||||
with_cstr(name.as_ref().as_bytes(), &|name| {
|
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddExternalInitializers(ptr, &name.as_ptr(), &value.ptr(), 1) as Result]) {
|
||||||
ortsys![unsafe AddExternalInitializers(ptr, &name.as_ptr(), &value.ptr(), 1)?];
|
Ok(()) => {
|
||||||
Ok(())
|
self.initializers.push(value);
|
||||||
})?;
|
Ok(self)
|
||||||
self.initializers.push(value);
|
}
|
||||||
Ok(self)
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "std", feature = "api-18"))]
|
#[cfg(all(feature = "std", feature = "api-18"))]
|
||||||
#[cfg_attr(docsrs, doc(cfg(all(feature = "std", feature = "api-18"))))]
|
#[cfg_attr(docsrs, doc(cfg(all(feature = "std", feature = "api-18"))))]
|
||||||
pub fn with_external_initializer_file_in_memory(mut self, file_name: impl AsRef<Path>, buffer: alloc::borrow::Cow<'static, [u8]>) -> Result<Self> {
|
pub fn with_external_initializer_file_in_memory(mut self, file_name: impl AsRef<Path>, buffer: alloc::borrow::Cow<'static, [u8]>) -> BuilderResult {
|
||||||
let file_name = path_to_os_char(file_name);
|
let file_name = path_to_os_char(file_name);
|
||||||
let sizes = [buffer.len()];
|
let sizes = [buffer.len()];
|
||||||
ortsys![unsafe AddExternalInitializersFromMemory(self.ptr_mut(), &file_name.as_ptr(), &buffer.as_ptr().cast::<core::ffi::c_char>().cast_mut(), sizes.as_ptr(), 1)?];
|
match ortsys![@ort:
|
||||||
self.external_initializer_buffers.push(buffer);
|
unsafe AddExternalInitializersFromMemory(
|
||||||
Ok(self)
|
self.ptr_mut(),
|
||||||
|
&file_name.as_ptr(),
|
||||||
|
&buffer.as_ptr().cast::<core::ffi::c_char>().cast_mut(),
|
||||||
|
sizes.as_ptr(),
|
||||||
|
1
|
||||||
|
) as Result
|
||||||
|
] {
|
||||||
|
Ok(()) => {
|
||||||
|
self.external_initializer_buffers.push(buffer);
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_log_id(mut self, id: impl AsRef<str>) -> Result<Self> {
|
pub fn with_log_id(mut self, id: impl AsRef<str>) -> BuilderResult {
|
||||||
let ptr = self.ptr_mut();
|
let ptr = self.ptr_mut();
|
||||||
with_cstr(id.as_ref().as_bytes(), &|id| {
|
match with_cstr(id.as_ref().as_bytes(), &|id| ortsys![@ort: unsafe SetSessionLogId(ptr, id.as_ptr()) as Result]) {
|
||||||
ortsys![unsafe SetSessionLogId(ptr, id.as_ptr())?];
|
Ok(()) => Ok(self),
|
||||||
Ok(())
|
Err(e) => Err(e.with_recover(self))
|
||||||
})?;
|
}
|
||||||
Ok(self)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_dimension_override(mut self, name: impl AsRef<str>, size: i64) -> Result<Self> {
|
pub fn with_dimension_override(mut self, name: impl AsRef<str>, size: i64) -> BuilderResult {
|
||||||
let ptr = self.ptr_mut();
|
let ptr = self.ptr_mut();
|
||||||
with_cstr(name.as_ref().as_bytes(), &|name| {
|
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddFreeDimensionOverrideByName(ptr, name.as_ptr(), size) as Result]) {
|
||||||
ortsys![unsafe AddFreeDimensionOverrideByName(ptr, name.as_ptr(), size)?];
|
Ok(()) => Ok(self),
|
||||||
Ok(())
|
Err(e) => Err(e.with_recover(self))
|
||||||
})?;
|
}
|
||||||
Ok(self)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_dimension_override_by_denotation(mut self, denotation: impl AsRef<str>, size: i64) -> Result<Self> {
|
pub fn with_dimension_override_by_denotation(mut self, denotation: impl AsRef<str>, size: i64) -> BuilderResult {
|
||||||
let ptr = self.ptr_mut();
|
let ptr = self.ptr_mut();
|
||||||
with_cstr(denotation.as_ref().as_bytes(), &|denotation| {
|
match with_cstr(denotation.as_ref().as_bytes(), &|denotation| ortsys![@ort: unsafe AddFreeDimensionOverride(ptr, denotation.as_ptr(), size) as Result])
|
||||||
ortsys![unsafe AddFreeDimensionOverride(ptr, denotation.as_ptr(), size)?];
|
{
|
||||||
Ok(())
|
Ok(()) => Ok(self),
|
||||||
})?;
|
Err(e) => Err(e.with_recover(self))
|
||||||
Ok(self)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_prepacked_weights(mut self, weights: &PrepackedWeights) -> Result<Self> {
|
pub fn with_prepacked_weights(mut self, weights: &PrepackedWeights) -> BuilderResult {
|
||||||
self.prepacked_weights = Some(weights.clone());
|
self.prepacked_weights = Some(weights.clone());
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configures this environment to use its own thread pool instead of defaulting to the
|
/// Configures this session to use its own thread pool instead of defaulting to the
|
||||||
/// [`Environment`](crate::environment::Environment)'s global thread pool if one was defined.
|
/// [`Environment`](crate::environment::Environment)'s global thread pool if one was defined.
|
||||||
pub fn with_independent_thread_pool(mut self) -> Result<Self> {
|
pub fn with_independent_thread_pool(mut self) -> BuilderResult {
|
||||||
self.no_global_thread_pool = true;
|
self.no_global_thread_pool = true;
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_no_environment_execution_providers(mut self) -> Result<Self> {
|
pub fn with_no_environment_execution_providers(mut self) -> BuilderResult {
|
||||||
self.no_env_eps = true;
|
self.no_env_eps = true;
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
|
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> BuilderResult {
|
||||||
let manager = Arc::new(manager);
|
let manager = Arc::new(manager);
|
||||||
ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?];
|
let ptr = self.ptr_mut();
|
||||||
ortsys![unsafe SessionOptionsSetCustomCreateThreadFn(self.ptr_mut(), Some(environment::thread_create::<T>))?];
|
match ortsys![@ort: unsafe SessionOptionsSetCustomThreadCreationOptions(ptr, (&*manager as *const T) as *mut c_void) as Result]
|
||||||
ortsys![unsafe SessionOptionsSetCustomJoinThreadFn(self.ptr_mut(), Some(environment::thread_join::<T>))?];
|
.and_then(|()| ortsys![@ort: unsafe SessionOptionsSetCustomCreateThreadFn(ptr, Some(environment::thread_create::<T>)) as Result])
|
||||||
self.thread_manager = Some(manager as Arc<dyn Any>);
|
.and_then(|()| ortsys![@ort: unsafe SessionOptionsSetCustomJoinThreadFn(ptr, Some(environment::thread_join::<T>)) as Result])
|
||||||
Ok(self)
|
{
|
||||||
|
Ok(()) => {
|
||||||
|
self.thread_manager = Some(manager as Arc<dyn Any>);
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configures this session to use a custom logger function.
|
/// Configures this session to use a custom logger function.
|
||||||
@@ -268,11 +315,15 @@ impl SessionBuilder {
|
|||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn with_logger(mut self, logger: LoggerFunction) -> Result<Self> {
|
pub fn with_logger(mut self, logger: LoggerFunction) -> BuilderResult {
|
||||||
let logger = Arc::new(logger);
|
let logger = Arc::new(logger);
|
||||||
ortsys![unsafe SetUserLoggingFunction(self.ptr_mut(), crate::logging::custom_logger, Arc::as_ptr(&logger) as *mut c_void)?];
|
match ortsys![@ort: unsafe SetUserLoggingFunction(self.ptr_mut(), crate::logging::custom_logger, Arc::as_ptr(&logger) as *mut c_void) as Result] {
|
||||||
self.logger = Some(logger);
|
Ok(()) => {
|
||||||
Ok(self)
|
self.logger = Some(logger);
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets the severity level for messages logged by this session.
|
/// Sets the severity level for messages logged by this session.
|
||||||
@@ -281,15 +332,19 @@ impl SessionBuilder {
|
|||||||
/// precedence, i.e. if the application was initialized with `ort`'s log level set to `warn` via the `RUST_LOG`
|
/// precedence, i.e. if the application was initialized with `ort`'s log level set to `warn` via the `RUST_LOG`
|
||||||
/// environment variable or similar, setting a session's log severity level to `verbose` will still have it only
|
/// environment variable or similar, setting a session's log severity level to `verbose` will still have it only
|
||||||
/// log `warn` messages or higher.
|
/// log `warn` messages or higher.
|
||||||
pub fn with_log_level(mut self, level: LogLevel) -> Result<Self> {
|
pub fn with_log_level(mut self, level: LogLevel) -> BuilderResult {
|
||||||
ortsys![unsafe SetSessionLogSeverityLevel(self.ptr_mut(), ort_sys::OrtLoggingLevel::from(level) as _)?];
|
match ortsys![@ort: unsafe SetSessionLogSeverityLevel(self.ptr_mut(), ort_sys::OrtLoggingLevel::from(level) as _) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Controls the level of verbosity for messages logged under [`LogLevel::Verbose`]; higher values = more verbose.
|
/// Controls the level of verbosity for messages logged under [`LogLevel::Verbose`]; higher values = more verbose.
|
||||||
pub fn with_log_verbosity(mut self, verbosity: c_int) -> Result<Self> {
|
pub fn with_log_verbosity(mut self, verbosity: c_int) -> BuilderResult {
|
||||||
ortsys![unsafe SetSessionLogVerbosityLevel(self.ptr_mut(), verbosity)?];
|
match ortsys![@ort: unsafe SetSessionLogVerbosityLevel(self.ptr_mut(), verbosity) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Automatically select & register an execution provider according to the given [`policy`](AutoDevicePolicy) based
|
/// Automatically select & register an execution provider according to the given [`policy`](AutoDevicePolicy) based
|
||||||
@@ -309,9 +364,11 @@ impl SessionBuilder {
|
|||||||
/// ```
|
/// ```
|
||||||
#[cfg(feature = "api-22")]
|
#[cfg(feature = "api-22")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
||||||
pub fn with_auto_device(mut self, policy: AutoDevicePolicy) -> Result<Self> {
|
pub fn with_auto_device(mut self, policy: AutoDevicePolicy) -> BuilderResult {
|
||||||
ortsys![unsafe SessionOptionsSetEpSelectionPolicy(self.ptr_mut(), policy.into())?];
|
match ortsys![@ort: unsafe SessionOptionsSetEpSelectionPolicy(self.ptr_mut(), policy.into()) as Result] {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,22 @@ mod impl_options;
|
|||||||
pub use self::editable::*;
|
pub use self::editable::*;
|
||||||
pub use self::impl_options::*;
|
pub use self::impl_options::*;
|
||||||
|
|
||||||
|
/// `Result` type returned by [`SessionBuilder`] methods.
|
||||||
|
///
|
||||||
|
/// This type supports [error recovery](Error::recover):
|
||||||
|
/// ```
|
||||||
|
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||||
|
/// # fn main() -> ort::Result<()> {
|
||||||
|
/// let session = Session::builder()?
|
||||||
|
/// .with_optimization_level(GraphOptimizationLevel::All)
|
||||||
|
/// // Optimization isn't enabled in minimal builds of ONNX Runtime, so this returns an error. We can just ignore it.
|
||||||
|
/// .unwrap_or_else(|e| e.recover())
|
||||||
|
/// .commit_from_file("tests/data/upsample.onnx")?;
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
pub type BuilderResult = Result<SessionBuilder, Error<SessionBuilder>>;
|
||||||
|
|
||||||
/// Creates a session using the builder pattern.
|
/// Creates a session using the builder pattern.
|
||||||
///
|
///
|
||||||
/// Once configured, use the
|
/// Once configured, use the
|
||||||
@@ -146,7 +162,7 @@ impl SessionBuilder {
|
|||||||
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||||
/// # use std::{thread, time::Duration};
|
/// # use std::{thread, time::Duration};
|
||||||
/// # fn main() -> ort::Result<()> {
|
/// # fn main() -> ort::Result<()> {
|
||||||
/// let builder = Session::builder()?
|
/// let mut builder = Session::builder()?
|
||||||
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
|
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||||
/// .with_intra_threads(1)?;
|
/// .with_intra_threads(1)?;
|
||||||
///
|
///
|
||||||
@@ -168,9 +184,11 @@ impl SessionBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Adds a custom configuration entry to the session.
|
/// Adds a custom configuration entry to the session.
|
||||||
pub fn with_config_entry(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
|
pub fn with_config_entry(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> BuilderResult {
|
||||||
self.add_config_entry(key.as_ref(), value.as_ref())?;
|
match self.add_config_entry(key.as_ref(), value.as_ref()) {
|
||||||
Ok(self)
|
Ok(()) => Ok(self),
|
||||||
|
Err(e) => Err(e.with_recover(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,7 +221,7 @@ impl LoadCanceler {
|
|||||||
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||||
/// # use std::{thread, time::Duration};
|
/// # use std::{thread, time::Duration};
|
||||||
/// # fn main() -> ort::Result<()> {
|
/// # fn main() -> ort::Result<()> {
|
||||||
/// let builder = Session::builder()?
|
/// let mut builder = Session::builder()?
|
||||||
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
|
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||||
/// .with_intra_threads(1)?;
|
/// .with_intra_threads(1)?;
|
||||||
///
|
///
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ use smallvec::SmallVec;
|
|||||||
use crate::{
|
use crate::{
|
||||||
AsPointer,
|
AsPointer,
|
||||||
environment::Environment,
|
environment::Environment,
|
||||||
error::{Error, ErrorCode, Result, status_to_result},
|
error::{Error, ErrorCode, Result},
|
||||||
memory::Allocator,
|
memory::Allocator,
|
||||||
ortsys,
|
ortsys,
|
||||||
util::{AllocatedString, STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS, with_cstr, with_cstr_ptr_array},
|
util::{AllocatedString, STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS, with_cstr, with_cstr_ptr_array},
|
||||||
@@ -708,7 +708,7 @@ pub(crate) mod io {
|
|||||||
) -> Result<usize> {
|
) -> Result<usize> {
|
||||||
let mut num_nodes = 0;
|
let mut num_nodes = 0;
|
||||||
let status = unsafe { f(session_ptr.as_ptr(), &mut num_nodes) };
|
let status = unsafe { f(session_ptr.as_ptr(), &mut num_nodes) };
|
||||||
unsafe { status_to_result(status) }?;
|
unsafe { Error::result_from_status(status) }?;
|
||||||
Ok(num_nodes)
|
Ok(num_nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -721,7 +721,7 @@ pub(crate) mod io {
|
|||||||
let mut name_ptr: *mut c_char = ptr::null_mut();
|
let mut name_ptr: *mut c_char = ptr::null_mut();
|
||||||
|
|
||||||
let status = unsafe { f(session_ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_ptr) };
|
let status = unsafe { f(session_ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_ptr) };
|
||||||
unsafe { status_to_result(status) }?;
|
unsafe { Error::result_from_status(status) }?;
|
||||||
if name_ptr.is_null() {
|
if name_ptr.is_null() {
|
||||||
crate::util::cold();
|
crate::util::cold();
|
||||||
return Err(crate::Error::new("expected `name_ptr` to not be null"));
|
return Err(crate::Error::new("expected `name_ptr` to not be null"));
|
||||||
@@ -738,7 +738,7 @@ pub(crate) mod io {
|
|||||||
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
|
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
|
||||||
|
|
||||||
let status = unsafe { f(session_ptr.as_ptr(), i, &mut typeinfo_ptr) };
|
let status = unsafe { f(session_ptr.as_ptr(), i, &mut typeinfo_ptr) };
|
||||||
unsafe { status_to_result(status) }?;
|
unsafe { Error::result_from_status(status) }?;
|
||||||
let Some(typeinfo_ptr) = NonNull::new(typeinfo_ptr) else {
|
let Some(typeinfo_ptr) = NonNull::new(typeinfo_ptr) else {
|
||||||
crate::util::cold();
|
crate::util::cold();
|
||||||
return Err(crate::Error::new("expected `typeinfo_ptr` to not be null"));
|
return Err(crate::Error::new("expected `typeinfo_ptr` to not be null"));
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ use super::{Checkpoint, Optimizer, training_api};
|
|||||||
use crate::{
|
use crate::{
|
||||||
AsPointer,
|
AsPointer,
|
||||||
environment::Environment,
|
environment::Environment,
|
||||||
error::{Result, status_to_result},
|
error::{Error, Result},
|
||||||
memory::Allocator,
|
memory::Allocator,
|
||||||
ortsys,
|
ortsys,
|
||||||
session::{RunOptions, SessionInputValue, SessionInputs, SessionOutputs, builder::SessionBuilder},
|
session::{RunOptions, SessionInputValue, SessionInputs, SessionOutputs, builder::SessionBuilder},
|
||||||
@@ -329,11 +329,11 @@ fn extract_io_names(
|
|||||||
) -> ort_sys::OrtStatusPtr
|
) -> ort_sys::OrtStatusPtr
|
||||||
) -> Result<Vec<String>> {
|
) -> Result<Vec<String>> {
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
unsafe { status_to_result(get_count(ptr.as_ptr(), &mut count)) }?;
|
unsafe { Error::result_from_status(get_count(ptr.as_ptr(), &mut count)) }?;
|
||||||
(0..count)
|
(0..count)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
let mut name_bytes: *const c_char = ptr::null();
|
let mut name_bytes: *const c_char = ptr::null();
|
||||||
unsafe { status_to_result(get_name(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)) }?;
|
unsafe { Error::result_from_status(get_name(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)) }?;
|
||||||
let name = match char_p_to_string(name_bytes) {
|
let name = match char_p_to_string(name_bytes) {
|
||||||
Ok(name) => name,
|
Ok(name) => name,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|||||||
Reference in New Issue
Block a user