mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: recover from SessionBuilder errors
This commit is contained in:
@@ -22,7 +22,7 @@
|
||||
//! # use std::path::PathBuf;
|
||||
//! # use ort::{compiler::ModelCompiler, session::Session, ep};
|
||||
//! # fn main() -> ort::Result<()> {
|
||||
//! let session_options = Session::builder()?.with_execution_providers([ep::CoreML::default()
|
||||
//! let mut session_options = Session::builder()?.with_execution_providers([ep::CoreML::default()
|
||||
//! .with_model_format(ep::coreml::ModelFormat::MLProgram)
|
||||
//! .build()])?;
|
||||
//!
|
||||
|
||||
@@ -263,7 +263,7 @@ impl Model {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn into_session(self, mut options: SessionBuilder) -> Result<Session> {
|
||||
pub fn into_session(self, options: &SessionBuilder) -> Result<Session> {
|
||||
let mut session_ptr = ptr::null_mut();
|
||||
ortsys![@editor:
|
||||
unsafe CreateSessionFromModel(
|
||||
|
||||
@@ -32,7 +32,7 @@ fn test_identity_graph() -> Result<()> {
|
||||
let mut model = Model::new([Opset::new(ONNX_DOMAIN, 22)?])?;
|
||||
model.add_graph(graph)?;
|
||||
|
||||
let mut session = model.into_session(SessionBuilder::new()?)?;
|
||||
let mut session = model.into_session(&SessionBuilder::new()?)?;
|
||||
let output = session
|
||||
.run(inputs![Tensor::<f32>::from_array((Shape::new([5]), vec![1.0f32; 5]))?])?
|
||||
.remove("output")
|
||||
@@ -76,7 +76,7 @@ fn test_mul_graph() -> Result<()> {
|
||||
let mut model = Model::new([Opset::new(ONNX_DOMAIN, 22)?])?;
|
||||
model.add_graph(graph)?;
|
||||
|
||||
let mut session = model.into_session(SessionBuilder::new()?)?;
|
||||
let mut session = model.into_session(&SessionBuilder::new()?)?;
|
||||
let output = session
|
||||
.run(inputs![Tensor::<f32>::from_array((Shape::new([5]), vec![2.0f32; 5]))?])?
|
||||
.remove("output")
|
||||
|
||||
@@ -101,9 +101,9 @@ impl Environment {
|
||||
/// Sets the global log level.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{environment::Environment, logging::LogLevel};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # use ort::logging::LogLevel;
|
||||
/// let env = ort::environment::get_environment()?;
|
||||
/// let env = Environment::current()?;
|
||||
///
|
||||
/// env.set_log_level(LogLevel::Warning);
|
||||
/// # Ok(())
|
||||
@@ -130,8 +130,9 @@ impl Environment {
|
||||
/// no longer be needed.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::environment::Environment;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let env = ort::environment::get_environment()?;
|
||||
/// let env = Environment::current()?;
|
||||
///
|
||||
/// let _ = env.register_ep_library("CUDA", "/path/to/onnxruntime_providers_cuda.dll");
|
||||
/// # Ok(())
|
||||
|
||||
@@ -45,7 +45,7 @@ impl ExecutionProvider for ACL {
|
||||
|
||||
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, enable_fast_math: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
|
||||
return Ok(unsafe {
|
||||
crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.ptr_mut(), self.fast_math.into()))
|
||||
crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.ptr_mut(), self.fast_math.into()))
|
||||
}?);
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ impl ExecutionProvider for ArmNN {
|
||||
|
||||
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
|
||||
return Ok(unsafe {
|
||||
crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.ptr_mut(), self.use_arena.into()))
|
||||
crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.ptr_mut(), self.use_arena.into()))
|
||||
}?);
|
||||
}
|
||||
|
||||
|
||||
@@ -102,11 +102,11 @@ impl ExecutionProvider for DirectML {
|
||||
fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
|
||||
#[cfg(any(feature = "load-dynamic", feature = "directml"))]
|
||||
{
|
||||
use crate::AsPointer;
|
||||
use crate::{AsPointer, Error};
|
||||
|
||||
let api = get_dml_api()?;
|
||||
if let Some(device_id) = self.device_id {
|
||||
unsafe { crate::error::status_to_result((api.SessionOptionsAppendExecutionProvider_DML)(session_builder.ptr_mut(), device_id as _)) }?;
|
||||
unsafe { Error::result_from_status((api.SessionOptionsAppendExecutionProvider_DML)(session_builder.ptr_mut(), device_id as _)) }?;
|
||||
} else {
|
||||
let device_options = ort_sys::OrtDmlDeviceOptions {
|
||||
Filter: match self.device_filter {
|
||||
@@ -120,7 +120,7 @@ impl ExecutionProvider for DirectML {
|
||||
PerformancePreference::MinimumPower => ort_sys::OrtDmlPerformancePreference::MinimumPower
|
||||
}
|
||||
};
|
||||
unsafe { crate::error::status_to_result((api.SessionOptionsAppendExecutionProvider_DML2)(session_builder.ptr_mut(), &device_options)) }?;
|
||||
unsafe { Error::result_from_status((api.SessionOptionsAppendExecutionProvider_DML2)(session_builder.ptr_mut(), &device_options)) }?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
|
||||
@@ -78,7 +78,7 @@ impl ExecutionProvider for NNAPI {
|
||||
if self.cpu_only {
|
||||
flags |= 0x008;
|
||||
}
|
||||
return Ok(unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_builder.ptr_mut(), flags)) }?);
|
||||
return Ok(unsafe { crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_builder.ptr_mut(), flags)) }?);
|
||||
}
|
||||
|
||||
Err(RegisterError::MissingFeature)
|
||||
|
||||
@@ -22,7 +22,7 @@ impl ExecutionProvider for RKNPU {
|
||||
use crate::AsPointer;
|
||||
|
||||
super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_RKNPU(options: *mut ort_sys::OrtSessionOptions) -> ort_sys::OrtStatusPtr);
|
||||
return Ok(unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_RKNPU(session_builder.ptr_mut())) }?);
|
||||
return Ok(unsafe { crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_RKNPU(session_builder.ptr_mut())) }?);
|
||||
}
|
||||
|
||||
Err(RegisterError::MissingFeature)
|
||||
|
||||
@@ -100,7 +100,7 @@ impl ExecutionProvider for TVM {
|
||||
}
|
||||
let options_string = alloc::ffi::CString::new(option_string.join(",")).expect("invalid option string");
|
||||
return Ok(unsafe {
|
||||
crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_Tvm(session_builder.ptr_mut(), options_string.as_ptr()))
|
||||
crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_Tvm(session_builder.ptr_mut(), options_string.as_ptr()))
|
||||
}?);
|
||||
}
|
||||
|
||||
|
||||
203
src/error.rs
203
src/error.rs
@@ -3,11 +3,17 @@ use alloc::{
|
||||
format,
|
||||
string::{String, ToString}
|
||||
};
|
||||
use core::{convert::Infallible, error::Error as CoreError, ffi::c_char, fmt, ptr};
|
||||
use core::{
|
||||
convert::Infallible,
|
||||
error::Error as CoreError,
|
||||
ffi::c_char,
|
||||
fmt,
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
|
||||
use crate::{
|
||||
ortsys,
|
||||
util::{char_p_to_string, with_cstr}
|
||||
util::{char_p_to_string, cold, with_cstr}
|
||||
};
|
||||
|
||||
/// Type alias for the `Result` type returned by `ort` functions.
|
||||
@@ -27,61 +33,164 @@ impl<T> IntoStatus for Result<T, Error> {
|
||||
}
|
||||
}
|
||||
|
||||
/// An error returned by any `ort` API.
|
||||
#[derive(Debug)]
|
||||
pub struct Error {
|
||||
struct ErrorInternal {
|
||||
code: ErrorCode,
|
||||
msg: String
|
||||
message: String,
|
||||
cause: Option<Box<dyn CoreError + Send + Sync + 'static>>,
|
||||
status_ptr: NonNull<ort_sys::OrtStatus>
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Wrap a custom, user-provided error in an [`ort::Error`](Error)..
|
||||
unsafe impl Send for ErrorInternal {}
|
||||
unsafe impl Sync for ErrorInternal {}
|
||||
|
||||
impl ErrorInternal {
|
||||
#[cold]
|
||||
pub(crate) unsafe fn from_ptr(ptr: NonNull<ort_sys::OrtStatus>) -> Self {
|
||||
let code = ErrorCode::from(ortsys![unsafe GetErrorCode(ptr.as_ptr())]);
|
||||
let raw: *const c_char = ortsys![unsafe GetErrorMessage(ptr.as_ptr())];
|
||||
match char_p_to_string(raw) {
|
||||
Ok(message) => ErrorInternal {
|
||||
code,
|
||||
message,
|
||||
cause: None,
|
||||
status_ptr: ptr
|
||||
},
|
||||
Err(err) => ErrorInternal {
|
||||
code,
|
||||
message: format!("(failed to convert UTF-8: {err})"),
|
||||
cause: None,
|
||||
status_ptr: ptr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ErrorInternal {
|
||||
fn drop(&mut self) {
|
||||
ortsys![unsafe ReleaseStatus(self.status_ptr.as_ptr())];
|
||||
}
|
||||
}
|
||||
|
||||
/// An error returned by any `ort` API.
|
||||
pub struct Error<R = ()> {
|
||||
recover: R,
|
||||
inner: Box<ErrorInternal>
|
||||
}
|
||||
|
||||
impl<R> fmt::Debug for Error<R> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Error")
|
||||
.field("code", &self.inner.code)
|
||||
.field("message", &self.message())
|
||||
.field("ptr", &self.inner.status_ptr.as_ptr())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Error<()> {
|
||||
/// Converts an [`ort_sys::OrtStatusPtr`] to a [`Result`].
|
||||
///
|
||||
/// This takes ownership of the status pointer.
|
||||
///
|
||||
/// # Safety
|
||||
/// `ptr` must be a valid `OrtStatusPtr` returned from an `ort-sys` API.
|
||||
#[inline]
|
||||
pub unsafe fn result_from_status(ptr: ort_sys::OrtStatusPtr) -> Result<(), Self> {
|
||||
match NonNull::new(ptr.0) {
|
||||
None => Ok(()),
|
||||
Some(ptr) => {
|
||||
cold();
|
||||
|
||||
Err(Self {
|
||||
recover: (),
|
||||
inner: Box::new(unsafe { ErrorInternal::from_ptr(ptr) })
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a custom, user-provided error in an [`ort::Error`](Error).
|
||||
///
|
||||
/// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort`
|
||||
/// related operation fails.
|
||||
pub fn wrap<T: CoreError + Send + Sync + 'static>(err: T) -> Self {
|
||||
Error {
|
||||
code: ErrorCode::GenericFailure,
|
||||
msg: err.to_string()
|
||||
}
|
||||
Self::new_internal(ErrorCode::GenericFailure, err.to_string(), Some(Box::new(err)))
|
||||
}
|
||||
|
||||
/// Creates a custom [`Error`] with the given message.
|
||||
pub fn new(msg: impl Into<String>) -> Self {
|
||||
Error {
|
||||
code: ErrorCode::GenericFailure,
|
||||
msg: msg.into()
|
||||
}
|
||||
Self::new_internal(ErrorCode::GenericFailure, msg, None)
|
||||
}
|
||||
|
||||
/// Creates a custom [`Error`] with the given [`ErrorCode`] and message.
|
||||
pub fn new_with_code(code: ErrorCode, msg: impl Into<String>) -> Self {
|
||||
Error { code, msg: msg.into() }
|
||||
Self::new_internal(code, msg, None)
|
||||
}
|
||||
|
||||
fn new_internal(code: ErrorCode, message: impl Into<String>, cause: Option<Box<dyn CoreError + Send + Sync + 'static>>) -> Self {
|
||||
let message = message.into();
|
||||
let ptr = with_cstr(message.as_bytes(), &|message| Ok(ortsys![unsafe CreateStatus(code.into(), message.as_ptr())])).expect("invalid error message");
|
||||
Self {
|
||||
recover: (),
|
||||
inner: Box::new(ErrorInternal {
|
||||
code,
|
||||
message,
|
||||
cause,
|
||||
status_ptr: unsafe { NonNull::new_unchecked(ptr.0) }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_recover<R>(self, recover: R) -> Error<R> {
|
||||
Error { recover, inner: self.inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Error<R> {
|
||||
pub fn code(&self) -> ErrorCode {
|
||||
self.code
|
||||
self.inner.code
|
||||
}
|
||||
|
||||
pub fn message(&self) -> &str {
|
||||
self.msg.as_str()
|
||||
self.inner.message.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
impl<R: Sized> Error<R> {
|
||||
/// Recovers from this error.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let session = Session::builder()?
|
||||
/// .with_optimization_level(GraphOptimizationLevel::All)
|
||||
/// // Optimization isn't enabled in minimal builds of ONNX Runtime, so throws an error. We can just ignore it.
|
||||
/// .unwrap_or_else(|e| e.recover())
|
||||
/// .commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn recover(self) -> R {
|
||||
self.recover
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> fmt::Display for Error<R> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(&self.msg)
|
||||
f.write_str(&self.inner.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl CoreError for Error {}
|
||||
impl<R> CoreError for Error<R> {
|
||||
fn source(&self) -> Option<&(dyn CoreError + 'static)> {
|
||||
self.inner.cause.as_ref().map(|x| &**x as &dyn CoreError)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Box<dyn CoreError + Send + Sync + 'static>> for Error {
|
||||
fn from(err: Box<dyn CoreError + Send + Sync + 'static>) -> Self {
|
||||
Error {
|
||||
code: ErrorCode::GenericFailure,
|
||||
msg: err.to_string()
|
||||
}
|
||||
Error::new_internal(ErrorCode::GenericFailure, err.to_string(), Some(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,6 +230,12 @@ impl From<alloc::ffi::IntoStringError> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error<crate::session::builder::SessionBuilder>> for Error<()> {
|
||||
fn from(err: Error<crate::session::builder::SessionBuilder>) -> Self {
|
||||
Self { recover: (), inner: err.inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum ErrorCode {
|
||||
@@ -177,39 +292,3 @@ impl From<ErrorCode> for ort_sys::OrtErrorCode {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts an [`ort_sys::OrtStatusPtr`] to a [`Result`].
|
||||
///
|
||||
/// **Note that this frees `status`!**
|
||||
///
|
||||
/// # Safety
|
||||
/// The value contained in `status` must be a valid [`ort_sys::OrtStatus`] pointer, or a null pointer (in which case the
|
||||
/// result will be `Ok`).
|
||||
#[inline]
|
||||
pub unsafe fn status_to_result(status: ort_sys::OrtStatusPtr) -> Result<(), Error> {
|
||||
let status = status.0;
|
||||
if status.is_null() {
|
||||
Ok(())
|
||||
} else {
|
||||
#[cold]
|
||||
fn status_to_error(status: *mut ort_sys::OrtStatus) -> Error {
|
||||
let code = ErrorCode::from(ortsys![unsafe GetErrorCode(status)]);
|
||||
let raw: *const c_char = ortsys![unsafe GetErrorMessage(status)];
|
||||
match char_p_to_string(raw) {
|
||||
Ok(msg) => {
|
||||
ortsys![unsafe ReleaseStatus(status)];
|
||||
Error { code, msg }
|
||||
}
|
||||
Err(err) => {
|
||||
ortsys![unsafe ReleaseStatus(status)];
|
||||
Error {
|
||||
code,
|
||||
msg: format!("(failed to convert UTF-8: {err})")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(status_to_error(status))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,19 +288,19 @@ macro_rules! ortsys {
|
||||
unsafe { ($crate::api().$method)($($n),*) }
|
||||
};
|
||||
(@ort: unsafe $method:ident($($n:expr),*) as Result) => {
|
||||
unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }
|
||||
unsafe { $crate::error::Error::result_from_status(($crate::api().$method)($($n),+)) }
|
||||
};
|
||||
(@$api:ident: unsafe $method:ident($($n:expr),*)) => {
|
||||
unsafe { ($crate::api::$api().unwrap().$method)($($n),+) }
|
||||
};
|
||||
(@$api:ident: unsafe $method:ident($($n:expr),*)?) => {
|
||||
$crate::api::$api().and_then(|api| unsafe { $crate::error::status_to_result((api.$method)($($n),+)) })?
|
||||
$crate::api::$api().and_then(|api| unsafe { $crate::error::Error::result_from_status((api.$method)($($n),+)) })?
|
||||
};
|
||||
(@$api:ident: unsafe $method:ident($($n:expr),*)?; nonNull($($check:ident),+)$(;)?) => {
|
||||
$crate::api::$api().and_then(|api| unsafe { $crate::error::status_to_result((api.$method)($($n),+)) })?;
|
||||
$crate::api::$api().and_then(|api| unsafe { $crate::error::Error::result_from_status((api.$method)($($n),+)) })?;
|
||||
ortsys![@nonNull?; $($check),+];
|
||||
};
|
||||
(@$api:ident: unsafe $method:ident($($n:expr),*) as Result) => {
|
||||
$crate::api::$api().and_then(|api| unsafe { $crate::error::status_to_result((api.$method)($($n),+)) })
|
||||
$crate::api::$api().and_then(|api| unsafe { $crate::error::Error::result_from_status((api.$method)($($n),+)) })
|
||||
};
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ use std::sync::Mutex;
|
||||
use smallvec::SmallVec;
|
||||
|
||||
use crate::{
|
||||
Error,
|
||||
error::Result,
|
||||
session::{SessionOutputs, SharedSessionInner, UntypedRunOptions},
|
||||
util::{STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS},
|
||||
@@ -121,7 +122,7 @@ pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mu
|
||||
|
||||
crate::logging::drop!(AsyncInferenceContext, user_data);
|
||||
|
||||
if let Err(e) = unsafe { crate::error::status_to_result(status) } {
|
||||
if let Err(e) = unsafe { Error::result_from_status(status) } {
|
||||
ctx.inner.emplace_value(Err(e));
|
||||
ctx.inner.wake();
|
||||
return;
|
||||
|
||||
@@ -35,90 +35,10 @@ impl SessionBuilder {
|
||||
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
|
||||
#[cfg_attr(docsrs, doc(cfg(all(feature = "fetch-models", feature = "std"))))]
|
||||
pub fn commit_from_url(&mut self, model_url: impl AsRef<str>) -> Result<Session> {
|
||||
let downloaded_path = SessionBuilder::download(model_url.as_ref())?;
|
||||
let downloaded_path = download_model(model_url.as_ref())?;
|
||||
self.commit_from_file(downloaded_path)
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
|
||||
fn download(url: &str) -> Result<PathBuf> {
|
||||
use ureq::{
|
||||
config::Config,
|
||||
tls::{RootCerts, TlsConfig, TlsProvider}
|
||||
};
|
||||
|
||||
let mut download_dir = ort_sys::internal::dirs::cache_dir()
|
||||
.expect("could not determine cache directory")
|
||||
.join("models");
|
||||
if std::fs::create_dir_all(&download_dir).is_err() {
|
||||
download_dir = std::env::current_dir().expect("Failed to obtain current working directory");
|
||||
}
|
||||
|
||||
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
|
||||
let _ = write!(&mut s, "{:02x}", b);
|
||||
s
|
||||
});
|
||||
let model_filepath = download_dir.join(&model_filename);
|
||||
if model_filepath.exists() {
|
||||
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
|
||||
Ok(model_filepath)
|
||||
} else {
|
||||
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");
|
||||
|
||||
let agent = Config::builder()
|
||||
.tls_config(
|
||||
TlsConfig::builder()
|
||||
.root_certs(RootCerts::WebPki)
|
||||
.provider(if cfg!(any(feature = "tls-rustls", feature = "tls-rustls-no-provider")) {
|
||||
TlsProvider::Rustls
|
||||
} else if cfg!(any(feature = "tls-native", feature = "tls-native-vendored")) {
|
||||
TlsProvider::NativeTls
|
||||
} else {
|
||||
return Err(Error::new(
|
||||
"No TLS provider configured. When using `fetch-models` with HTTPS URLs, a `tls-*` feature must be enabled."
|
||||
));
|
||||
})
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
.new_agent();
|
||||
|
||||
let resp = agent.get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?;
|
||||
|
||||
let len = resp
|
||||
.headers()
|
||||
.get("Content-Length")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|s| s.parse::<usize>().ok())
|
||||
.expect("Missing Content-Length header");
|
||||
crate::info!(len, "Downloading {} bytes", len);
|
||||
|
||||
let mut reader = resp.into_body().into_with_config().limit(u64::MAX).reader();
|
||||
let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier()));
|
||||
|
||||
let f = std::fs::File::create(&temp_filepath).expect("Failed to create model file");
|
||||
let mut writer = std::io::BufWriter::new(f);
|
||||
|
||||
let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?;
|
||||
if bytes_io_count != len as u64 {
|
||||
return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}")));
|
||||
}
|
||||
|
||||
drop(writer);
|
||||
|
||||
match std::fs::rename(&temp_filepath, &model_filepath) {
|
||||
Ok(()) => Ok(model_filepath),
|
||||
Err(e) => {
|
||||
if model_filepath.exists() {
|
||||
let _ = std::fs::remove_file(temp_filepath);
|
||||
Ok(model_filepath)
|
||||
} else {
|
||||
Err(Error::new(format!("Failed to download model: {e}")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads an ONNX model from a file and builds the session.
|
||||
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
@@ -331,3 +251,81 @@ impl SessionBuilder {
|
||||
EditableSession::new(session_ptr, self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "fetch-models", feature = "std", not(target_arch = "wasm32")))]
|
||||
fn download_model(url: &str) -> Result<PathBuf> {
|
||||
use ureq::{
|
||||
config::Config,
|
||||
tls::{RootCerts, TlsConfig, TlsProvider}
|
||||
};
|
||||
|
||||
let mut download_dir = ort_sys::internal::dirs::cache_dir()
|
||||
.expect("could not determine cache directory")
|
||||
.join("models");
|
||||
if std::fs::create_dir_all(&download_dir).is_err() {
|
||||
download_dir = std::env::current_dir().expect("Failed to obtain current working directory");
|
||||
}
|
||||
|
||||
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
|
||||
let _ = write!(&mut s, "{:02x}", b);
|
||||
s
|
||||
});
|
||||
let model_filepath = download_dir.join(&model_filename);
|
||||
if model_filepath.exists() {
|
||||
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
|
||||
Ok(model_filepath)
|
||||
} else {
|
||||
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");
|
||||
|
||||
let agent = Config::builder()
|
||||
.tls_config(
|
||||
TlsConfig::builder()
|
||||
.root_certs(RootCerts::WebPki)
|
||||
.provider(if cfg!(any(feature = "tls-rustls", feature = "tls-rustls-no-provider")) {
|
||||
TlsProvider::Rustls
|
||||
} else if cfg!(any(feature = "tls-native", feature = "tls-native-vendored")) {
|
||||
TlsProvider::NativeTls
|
||||
} else {
|
||||
return Err(Error::new("No TLS provider configured. When using `fetch-models` with HTTPS URLs, a `tls-*` feature must be enabled."));
|
||||
})
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
.new_agent();
|
||||
|
||||
let resp = agent.get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?;
|
||||
|
||||
let len = resp
|
||||
.headers()
|
||||
.get("Content-Length")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|s| s.parse::<usize>().ok())
|
||||
.expect("Missing Content-Length header");
|
||||
crate::info!(len, "Downloading {} bytes", len);
|
||||
|
||||
let mut reader = resp.into_body().into_with_config().limit(u64::MAX).reader();
|
||||
let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier()));
|
||||
|
||||
let f = std::fs::File::create(&temp_filepath).expect("Failed to create model file");
|
||||
let mut writer = std::io::BufWriter::new(f);
|
||||
|
||||
let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?;
|
||||
if bytes_io_count != len as u64 {
|
||||
return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}")));
|
||||
}
|
||||
|
||||
drop(writer);
|
||||
|
||||
match std::fs::rename(&temp_filepath, &model_filepath) {
|
||||
Ok(()) => Ok(model_filepath),
|
||||
Err(e) => {
|
||||
if model_filepath.exists() {
|
||||
let _ = std::fs::remove_file(temp_filepath);
|
||||
Ok(model_filepath)
|
||||
} else {
|
||||
Err(Error::new(format!("Failed to download model: {e}")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use super::SessionBuilder;
|
||||
use crate::Result;
|
||||
use super::{BuilderResult, SessionBuilder};
|
||||
|
||||
// https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
|
||||
|
||||
@@ -7,17 +6,21 @@ impl SessionBuilder {
|
||||
/// Enable/disable the usage of prepacking.
|
||||
///
|
||||
/// This option is **enabled** by default.
|
||||
pub fn with_prepacking(mut self, enable: bool) -> Result<Self> {
|
||||
self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" })?;
|
||||
Ok(self)
|
||||
pub fn with_prepacking(mut self, enable: bool) -> BuilderResult {
|
||||
match self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" }) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Use allocators from the registered environment.
|
||||
///
|
||||
/// This option is **disabled** by default.
|
||||
pub fn with_env_allocators(mut self) -> Result<Self> {
|
||||
self.add_config_entry("session.use_env_allocators", "1")?;
|
||||
Ok(self)
|
||||
pub fn with_env_allocators(mut self) -> BuilderResult {
|
||||
match self.add_config_entry("session.use_env_allocators", "1") {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Disables subnormal floats by enabling the denormals-are-zero and flush-to-zero flags for all threads in the
|
||||
@@ -28,94 +31,137 @@ impl SessionBuilder {
|
||||
/// giving faster & more consistent performance, but lower accuracy (in cases where subnormals are involved).
|
||||
///
|
||||
/// This option is **disabled** by default, as it may hurt model accuracy.
|
||||
pub fn with_flush_to_zero(mut self) -> Result<Self> {
|
||||
self.add_config_entry("session.set_denormal_as_zero", "1")?;
|
||||
Ok(self)
|
||||
pub fn with_flush_to_zero(mut self) -> BuilderResult {
|
||||
match self.add_config_entry("session.set_denormal_as_zero", "1") {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable/disable fusion for quantized models in QDQ (`QuantizeLinear`/`DequantizeLinear`) format.
|
||||
///
|
||||
/// This option is **enabled** by default for all EPs except DirectML.
|
||||
pub fn with_quant_qdq(mut self, enable: bool) -> Result<Self> {
|
||||
self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" })?;
|
||||
Ok(self)
|
||||
pub fn with_quant_qdq(mut self, enable: bool) -> BuilderResult {
|
||||
match self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" }) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable/disable the optimization step removing double QDQ nodes.
|
||||
///
|
||||
/// This option is **enabled** by default.
|
||||
pub fn with_double_qdq_remover(mut self, enable: bool) -> Result<Self> {
|
||||
self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" })?;
|
||||
Ok(self)
|
||||
pub fn with_double_qdq_remover(mut self, enable: bool) -> BuilderResult {
|
||||
match self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" }) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable the removal of Q/DQ node pairs once all QDQ handling has been completed.
|
||||
///
|
||||
/// This option is **disabled** by default.
|
||||
pub fn with_qdq_cleanup(mut self) -> Result<Self> {
|
||||
self.add_config_entry("session.enable_quant_qdq_cleanup", "1")?;
|
||||
Ok(self)
|
||||
pub fn with_qdq_cleanup(mut self) -> BuilderResult {
|
||||
match self.add_config_entry("session.enable_quant_qdq_cleanup", "1") {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable fast tanh-based GELU approximation (like PyTorch's `nn.GELU(approximate='tanh')`).
|
||||
///
|
||||
/// This option is **disabled** by default, as it may impact results.
|
||||
pub fn with_approximate_gelu(mut self) -> Result<Self> {
|
||||
self.add_config_entry("optimization.enable_gelu_approximation", "1")?;
|
||||
Ok(self)
|
||||
pub fn with_approximate_gelu(mut self) -> BuilderResult {
|
||||
match self.add_config_entry("optimization.enable_gelu_approximation", "1") {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable the `Cast` chain elimination optimization.
|
||||
///
|
||||
/// This option is **disabled** by default, as it may impact results.
|
||||
pub fn with_cast_chain_elimination(mut self) -> Result<Self> {
|
||||
self.add_config_entry("optimization.enable_cast_chain_elimination", "1")?;
|
||||
Ok(self)
|
||||
pub fn with_cast_chain_elimination(mut self) -> BuilderResult {
|
||||
match self.add_config_entry("optimization.enable_cast_chain_elimination", "1") {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable/disable ahead-of-time function inlining.
|
||||
///
|
||||
/// This option is **enabled** by default.
|
||||
pub fn with_aot_inlining(mut self, enable: bool) -> Result<Self> {
|
||||
self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" })?;
|
||||
Ok(self)
|
||||
pub fn with_aot_inlining(mut self, enable: bool) -> BuilderResult {
|
||||
match self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" }) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Accepts a comma-separated list of optimizers to disable.
|
||||
pub fn with_disabled_optimizers(mut self, optimizers: impl AsRef<str>) -> Result<Self> {
|
||||
self.add_config_entry("optimization.disable_specified_optimizers", optimizers)?;
|
||||
Ok(self)
|
||||
pub fn with_disabled_optimizers(mut self, optimizers: impl AsRef<str>) -> BuilderResult {
|
||||
match self.add_config_entry("optimization.disable_specified_optimizers", optimizers) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable using the device allocator for allocating initialized tensor memory, potentially bypassing arena
|
||||
/// allocators.
|
||||
///
|
||||
/// This option is **disabled** by default.
|
||||
pub fn with_device_allocated_initializers(mut self) -> Result<Self> {
|
||||
self.add_config_entry("session.use_device_allocator_for_initializers", "1")?;
|
||||
Ok(self)
|
||||
pub fn with_device_allocated_initializers(mut self) -> BuilderResult {
|
||||
match self.add_config_entry("session.use_device_allocator_for_initializers", "1") {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable/disable allowing the inter-op threads to spin for a short period before blocking.
|
||||
///
|
||||
/// This option is **enabled** by defualt.
|
||||
pub fn with_inter_op_spinning(mut self, enable: bool) -> Result<Self> {
|
||||
self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" })?;
|
||||
Ok(self)
|
||||
pub fn with_inter_op_spinning(mut self, enable: bool) -> BuilderResult {
|
||||
match self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" }) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable/disable allowing the intra-op threads to spin for a short period before blocking.
|
||||
///
|
||||
/// This option is **enabled** by defualt.
|
||||
pub fn with_intra_op_spinning(mut self, enable: bool) -> Result<Self> {
|
||||
self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" })?;
|
||||
Ok(self)
|
||||
pub fn with_intra_op_spinning(mut self, enable: bool) -> BuilderResult {
|
||||
match self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" }) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Disables falling back to the CPU for operations not supported by any other EP.
|
||||
/// Models with graphs that cannot be placed entirely on the EP(s) will fail to commit.
|
||||
pub fn with_disable_cpu_fallback(mut self) -> Result<Self> {
|
||||
self.add_config_entry("session.disable_cpu_ep_fallback", "1")?;
|
||||
Ok(self)
|
||||
pub fn with_disable_cpu_fallback(mut self) -> BuilderResult {
|
||||
match self.add_config_entry("session.disable_cpu_ep_fallback", "1") {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Uses slower U8U8 matrix multiplication in place of U8S8 matrix multiplication that could potentially overflow on
|
||||
/// x86-64 platforms without the VNNI extension.
|
||||
///
|
||||
/// This should only be enabled if you encounter overflow issues with quantized models.
|
||||
pub fn with_precise_qmm(mut self) -> BuilderResult {
|
||||
match self.add_config_entry("session.x64quantprecision", "1") {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enables dynamic thread block sizing with the given base block size.
|
||||
pub fn with_dynamic_block_base(mut self, size: u32) -> BuilderResult {
|
||||
match self.add_config_entry("session.dynamic_block_base", size.to_string()) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,14 +7,13 @@ use core::{
|
||||
#[cfg(feature = "std")]
|
||||
use std::path::Path;
|
||||
|
||||
use super::SessionBuilder;
|
||||
use super::{BuilderResult, SessionBuilder};
|
||||
#[cfg(feature = "std")]
|
||||
use crate::util::path_to_os_char;
|
||||
use crate::{
|
||||
AsPointer, Error, ErrorCode,
|
||||
environment::{self, ThreadManager},
|
||||
ep::{ExecutionProviderDispatch, apply_execution_providers},
|
||||
error::Result,
|
||||
logging::{LogLevel, LoggerFunction},
|
||||
memory::MemoryInfo,
|
||||
operator::OperatorDomain,
|
||||
@@ -36,9 +35,11 @@ impl SessionBuilder {
|
||||
/// - **Indiscriminate use of [`SessionBuilder::with_execution_providers`] in a library** (e.g. always enabling
|
||||
/// CUDA) **is discouraged** unless you allow the user to configure the execution providers by providing a `Vec`
|
||||
/// of [`ExecutionProviderDispatch`]es.
|
||||
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Result<Self> {
|
||||
apply_execution_providers(&mut self, execution_providers.as_ref(), "session options")?;
|
||||
Ok(self)
|
||||
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> BuilderResult {
|
||||
match apply_execution_providers(&mut self, execution_providers.as_ref(), "session options") {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Configure the session to use a number of threads to parallelize the execution within nodes. If ONNX Runtime was
|
||||
@@ -48,9 +49,11 @@ impl SessionBuilder {
|
||||
///
|
||||
/// For configuring the number of threads used when the session execution mode is set to `Parallel`, see
|
||||
/// [`SessionBuilder::with_inter_threads()`].
|
||||
pub fn with_intra_threads(mut self, num_threads: usize) -> Result<Self> {
|
||||
ortsys![unsafe SetIntraOpNumThreads(self.ptr_mut(), num_threads as _)?];
|
||||
Ok(self)
|
||||
pub fn with_intra_threads(mut self, num_threads: usize) -> BuilderResult {
|
||||
match ortsys![@ort: unsafe SetIntraOpNumThreads(self.ptr_mut(), num_threads as _) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Configure the session to use a number of threads to parallelize the execution of the graph. If nodes can be run
|
||||
@@ -60,9 +63,11 @@ impl SessionBuilder {
|
||||
///
|
||||
/// For configuring the number of threads used to parallelize the execution within nodes, see
|
||||
/// [`SessionBuilder::with_intra_threads()`].
|
||||
pub fn with_inter_threads(mut self, num_threads: usize) -> Result<Self> {
|
||||
ortsys![unsafe SetInterOpNumThreads(self.ptr_mut(), num_threads as _)?];
|
||||
Ok(self)
|
||||
pub fn with_inter_threads(mut self, num_threads: usize) -> BuilderResult {
|
||||
match ortsys![@ort: unsafe SetInterOpNumThreads(self.ptr_mut(), num_threads as _) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable/disable the parallel execution mode for this session. By default, this is disabled.
|
||||
@@ -70,21 +75,25 @@ impl SessionBuilder {
|
||||
/// Parallel execution can improve performance for models with many branches, at the cost of higher memory usage.
|
||||
/// You can configure the amount of threads used to parallelize the execution of the graph via
|
||||
/// [`SessionBuilder::with_inter_threads()`].
|
||||
pub fn with_parallel_execution(mut self, parallel_execution: bool) -> Result<Self> {
|
||||
pub fn with_parallel_execution(mut self, parallel_execution: bool) -> BuilderResult {
|
||||
let execution_mode = if parallel_execution {
|
||||
ort_sys::ExecutionMode::ORT_PARALLEL
|
||||
} else {
|
||||
ort_sys::ExecutionMode::ORT_SEQUENTIAL
|
||||
};
|
||||
ortsys![unsafe SetSessionExecutionMode(self.ptr_mut(), execution_mode)?];
|
||||
Ok(self)
|
||||
match ortsys![@ort: unsafe SetSessionExecutionMode(self.ptr_mut(), execution_mode) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the session's optimization level. See [`GraphOptimizationLevel`] for more information on the different
|
||||
/// optimization levels.
|
||||
pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> Result<Self> {
|
||||
ortsys![unsafe SetSessionGraphOptimizationLevel(self.ptr_mut(), opt_level.into())?];
|
||||
Ok(self)
|
||||
pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> BuilderResult {
|
||||
match ortsys![@ort: unsafe SetSessionGraphOptimizationLevel(self.ptr_mut(), opt_level.into()) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// After performing optimization (configurable with [`SessionBuilder::with_optimization_level`]), serializes the
|
||||
@@ -93,10 +102,12 @@ impl SessionBuilder {
|
||||
/// Note that the file will only be created after the model is committed.
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn with_optimized_model_path<S: AsRef<Path>>(mut self, path: S) -> Result<Self> {
|
||||
pub fn with_optimized_model_path<S: AsRef<Path>>(mut self, path: S) -> BuilderResult {
|
||||
let path = crate::util::path_to_os_char(path);
|
||||
ortsys![unsafe SetOptimizedModelFilePath(self.ptr_mut(), path.as_ptr())?];
|
||||
Ok(self)
|
||||
match ortsys![@ort: unsafe SetOptimizedModelFilePath(self.ptr_mut(), path.as_ptr()) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enables profiling. Profile information will be writen to `profiling_file` after profiling completes.
|
||||
@@ -105,29 +116,37 @@ impl SessionBuilder {
|
||||
/// [`Session::end_profiling`]: crate::session::Session::end_profiling
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn with_profiling<S: AsRef<Path>>(mut self, profiling_file: S) -> Result<Self> {
|
||||
pub fn with_profiling<S: AsRef<Path>>(mut self, profiling_file: S) -> BuilderResult {
|
||||
let profiling_file = crate::util::path_to_os_char(profiling_file);
|
||||
ortsys![unsafe EnableProfiling(self.ptr_mut(), profiling_file.as_ptr())?];
|
||||
Ok(self)
|
||||
match ortsys![@ort: unsafe EnableProfiling(self.ptr_mut(), profiling_file.as_ptr()) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enables/disables memory pattern optimization. Disable it if the input size varies, i.e., dynamic batch
|
||||
pub fn with_memory_pattern(mut self, enable: bool) -> Result<Self> {
|
||||
if enable {
|
||||
ortsys![unsafe EnableMemPattern(self.ptr_mut())?];
|
||||
pub fn with_memory_pattern(mut self, enable: bool) -> BuilderResult {
|
||||
let result = if enable {
|
||||
ortsys![@ort: unsafe EnableMemPattern(self.ptr_mut()) as Result]
|
||||
} else {
|
||||
ortsys![unsafe DisableMemPattern(self.ptr_mut())?];
|
||||
ortsys![@ort: unsafe DisableMemPattern(self.ptr_mut()) as Result]
|
||||
};
|
||||
match result {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Configure this session to use a custom allocator, rather than the global default. This allocator is responsible
|
||||
/// for allocating the *metadata* associated with values -- not the contents of the values themselves; that's
|
||||
/// handled by the active execution providers. As such, only CPU-accessible allocators are allowed.
|
||||
pub fn with_allocator(mut self, info: MemoryInfo) -> Result<Self> {
|
||||
pub fn with_allocator(mut self, info: MemoryInfo) -> BuilderResult {
|
||||
if !info.is_cpu_accessible() {
|
||||
return Err(Error::new_with_code(ErrorCode::InvalidArgument, "SessionBuilder::with_allocator may only use a CPU-accessible allocator"));
|
||||
return Err(
|
||||
Error::new_with_code(ErrorCode::InvalidArgument, "SessionBuilder::with_allocator may only use a CPU-accessible allocator").with_recover(self)
|
||||
);
|
||||
}
|
||||
|
||||
self.memory_info = Some(Arc::new(info));
|
||||
Ok(self)
|
||||
}
|
||||
@@ -135,118 +154,146 @@ impl SessionBuilder {
|
||||
/// Registers a custom operator library at the given library path.
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn with_operator_library(mut self, lib_path: impl AsRef<Path>) -> Result<Self> {
|
||||
pub fn with_operator_library(mut self, lib_path: impl AsRef<Path>) -> BuilderResult {
|
||||
let path_cstr = path_to_os_char(lib_path);
|
||||
ortsys![unsafe RegisterCustomOpsLibrary_V2(self.ptr_mut(), path_cstr.as_ptr())?];
|
||||
Ok(self)
|
||||
match ortsys![@ort: unsafe RegisterCustomOpsLibrary_V2(self.ptr_mut(), path_cstr.as_ptr()) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enables [`onnxruntime-extensions`](https://github.com/microsoft/onnxruntime-extensions) custom operators.
|
||||
pub fn with_extensions(mut self) -> Result<Self> {
|
||||
ortsys![unsafe EnableOrtCustomOps(self.ptr_mut())?];
|
||||
Ok(self)
|
||||
pub fn with_extensions(mut self) -> BuilderResult {
|
||||
match ortsys![@ort: unsafe EnableOrtCustomOps(self.ptr_mut()) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_operators(mut self, domain: impl Into<Arc<OperatorDomain>>) -> Result<Self> {
|
||||
pub fn with_operators(mut self, domain: impl Into<Arc<OperatorDomain>>) -> BuilderResult {
|
||||
let domain: Arc<OperatorDomain> = domain.into();
|
||||
ortsys![unsafe AddCustomOpDomain(self.ptr_mut(), domain.ptr().cast_mut())?];
|
||||
match ortsys![@ort: unsafe AddCustomOpDomain(self.ptr_mut(), domain.ptr().cast_mut()) as Result] {
|
||||
Ok(()) => {
|
||||
self.operator_domains.push(domain);
|
||||
Ok(self)
|
||||
}
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enables/disables deterministic computation.
|
||||
///
|
||||
/// The default (non-deterministic) kernels will typically use faster algorithms that may introduce slight variance.
|
||||
/// Enabling deterministic compute will output reproducible results, but may come at a performance penalty.
|
||||
pub fn with_deterministic_compute(mut self, enable: bool) -> Result<Self> {
|
||||
ortsys![unsafe SetDeterministicCompute(self.ptr_mut(), enable)?];
|
||||
Ok(self)
|
||||
pub fn with_deterministic_compute(mut self, enable: bool) -> BuilderResult {
|
||||
match ortsys![@ort: unsafe SetDeterministicCompute(self.ptr_mut(), enable) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> Result<Self> {
|
||||
pub fn with_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> BuilderResult {
|
||||
let ptr = self.ptr_mut();
|
||||
let value: Arc<DynValue> = value.into();
|
||||
with_cstr(name.as_ref().as_bytes(), &|name| {
|
||||
ortsys![unsafe AddInitializer(ptr, name.as_ptr(), value.ptr())?];
|
||||
Ok(())
|
||||
})?;
|
||||
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddInitializer(ptr, name.as_ptr(), value.ptr()) as Result]) {
|
||||
Ok(()) => {
|
||||
self.initializers.push(value);
|
||||
Ok(self)
|
||||
}
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_external_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> Result<Self> {
|
||||
pub fn with_external_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> BuilderResult {
|
||||
let ptr = self.ptr_mut();
|
||||
let value: Arc<DynValue> = value.into();
|
||||
with_cstr(name.as_ref().as_bytes(), &|name| {
|
||||
ortsys![unsafe AddExternalInitializers(ptr, &name.as_ptr(), &value.ptr(), 1)?];
|
||||
Ok(())
|
||||
})?;
|
||||
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddExternalInitializers(ptr, &name.as_ptr(), &value.ptr(), 1) as Result]) {
|
||||
Ok(()) => {
|
||||
self.initializers.push(value);
|
||||
Ok(self)
|
||||
}
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "std", feature = "api-18"))]
|
||||
#[cfg_attr(docsrs, doc(cfg(all(feature = "std", feature = "api-18"))))]
|
||||
pub fn with_external_initializer_file_in_memory(mut self, file_name: impl AsRef<Path>, buffer: alloc::borrow::Cow<'static, [u8]>) -> Result<Self> {
|
||||
pub fn with_external_initializer_file_in_memory(mut self, file_name: impl AsRef<Path>, buffer: alloc::borrow::Cow<'static, [u8]>) -> BuilderResult {
|
||||
let file_name = path_to_os_char(file_name);
|
||||
let sizes = [buffer.len()];
|
||||
ortsys![unsafe AddExternalInitializersFromMemory(self.ptr_mut(), &file_name.as_ptr(), &buffer.as_ptr().cast::<core::ffi::c_char>().cast_mut(), sizes.as_ptr(), 1)?];
|
||||
match ortsys![@ort:
|
||||
unsafe AddExternalInitializersFromMemory(
|
||||
self.ptr_mut(),
|
||||
&file_name.as_ptr(),
|
||||
&buffer.as_ptr().cast::<core::ffi::c_char>().cast_mut(),
|
||||
sizes.as_ptr(),
|
||||
1
|
||||
) as Result
|
||||
] {
|
||||
Ok(()) => {
|
||||
self.external_initializer_buffers.push(buffer);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn with_log_id(mut self, id: impl AsRef<str>) -> Result<Self> {
|
||||
let ptr = self.ptr_mut();
|
||||
with_cstr(id.as_ref().as_bytes(), &|id| {
|
||||
ortsys![unsafe SetSessionLogId(ptr, id.as_ptr())?];
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(self)
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_dimension_override(mut self, name: impl AsRef<str>, size: i64) -> Result<Self> {
|
||||
pub fn with_log_id(mut self, id: impl AsRef<str>) -> BuilderResult {
|
||||
let ptr = self.ptr_mut();
|
||||
with_cstr(name.as_ref().as_bytes(), &|name| {
|
||||
ortsys![unsafe AddFreeDimensionOverrideByName(ptr, name.as_ptr(), size)?];
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(self)
|
||||
match with_cstr(id.as_ref().as_bytes(), &|id| ortsys![@ort: unsafe SetSessionLogId(ptr, id.as_ptr()) as Result]) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_dimension_override_by_denotation(mut self, denotation: impl AsRef<str>, size: i64) -> Result<Self> {
|
||||
pub fn with_dimension_override(mut self, name: impl AsRef<str>, size: i64) -> BuilderResult {
|
||||
let ptr = self.ptr_mut();
|
||||
with_cstr(denotation.as_ref().as_bytes(), &|denotation| {
|
||||
ortsys![unsafe AddFreeDimensionOverride(ptr, denotation.as_ptr(), size)?];
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(self)
|
||||
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddFreeDimensionOverrideByName(ptr, name.as_ptr(), size) as Result]) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_prepacked_weights(mut self, weights: &PrepackedWeights) -> Result<Self> {
|
||||
pub fn with_dimension_override_by_denotation(mut self, denotation: impl AsRef<str>, size: i64) -> BuilderResult {
|
||||
let ptr = self.ptr_mut();
|
||||
match with_cstr(denotation.as_ref().as_bytes(), &|denotation| ortsys![@ort: unsafe AddFreeDimensionOverride(ptr, denotation.as_ptr(), size) as Result])
|
||||
{
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_prepacked_weights(mut self, weights: &PrepackedWeights) -> BuilderResult {
|
||||
self.prepacked_weights = Some(weights.clone());
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Configures this environment to use its own thread pool instead of defaulting to the
|
||||
/// [`Environment`](crate::environment::Environment)'s global thread pool if one was defined.
|
||||
pub fn with_independent_thread_pool(mut self) -> Result<Self> {
|
||||
pub fn with_independent_thread_pool(mut self) -> BuilderResult {
|
||||
self.no_global_thread_pool = true;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn with_no_environment_execution_providers(mut self) -> Result<Self> {
|
||||
pub fn with_no_environment_execution_providers(mut self) -> BuilderResult {
|
||||
self.no_env_eps = true;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
|
||||
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> BuilderResult {
|
||||
let manager = Arc::new(manager);
|
||||
ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?];
|
||||
ortsys![unsafe SessionOptionsSetCustomCreateThreadFn(self.ptr_mut(), Some(environment::thread_create::<T>))?];
|
||||
ortsys![unsafe SessionOptionsSetCustomJoinThreadFn(self.ptr_mut(), Some(environment::thread_join::<T>))?];
|
||||
let ptr = self.ptr_mut();
|
||||
match ortsys![@ort: unsafe SessionOptionsSetCustomThreadCreationOptions(ptr, (&*manager as *const T) as *mut c_void) as Result]
|
||||
.and_then(|()| ortsys![@ort: unsafe SessionOptionsSetCustomCreateThreadFn(ptr, Some(environment::thread_create::<T>)) as Result])
|
||||
.and_then(|()| ortsys![@ort: unsafe SessionOptionsSetCustomJoinThreadFn(ptr, Some(environment::thread_join::<T>)) as Result])
|
||||
{
|
||||
Ok(()) => {
|
||||
self.thread_manager = Some(manager as Arc<dyn Any>);
|
||||
Ok(self)
|
||||
}
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Configures this session to use a custom logger function.
|
||||
///
|
||||
@@ -268,12 +315,16 @@ impl SessionBuilder {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn with_logger(mut self, logger: LoggerFunction) -> Result<Self> {
|
||||
pub fn with_logger(mut self, logger: LoggerFunction) -> BuilderResult {
|
||||
let logger = Arc::new(logger);
|
||||
ortsys![unsafe SetUserLoggingFunction(self.ptr_mut(), crate::logging::custom_logger, Arc::as_ptr(&logger) as *mut c_void)?];
|
||||
match ortsys![@ort: unsafe SetUserLoggingFunction(self.ptr_mut(), crate::logging::custom_logger, Arc::as_ptr(&logger) as *mut c_void) as Result] {
|
||||
Ok(()) => {
|
||||
self.logger = Some(logger);
|
||||
Ok(self)
|
||||
}
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the severity level for messages logged by this session.
|
||||
///
|
||||
@@ -281,15 +332,19 @@ impl SessionBuilder {
|
||||
/// precedence, i.e. if the application was initialized with `ort`'s log level set to `warn` via the `RUST_LOG`
|
||||
/// environment variable or similar, setting a session's log severity level to `verbose` will still have it only
|
||||
/// log `warn` messages or higher.`
|
||||
pub fn with_log_level(mut self, level: LogLevel) -> Result<Self> {
|
||||
ortsys![unsafe SetSessionLogSeverityLevel(self.ptr_mut(), ort_sys::OrtLoggingLevel::from(level) as _)?];
|
||||
Ok(self)
|
||||
pub fn with_log_level(mut self, level: LogLevel) -> BuilderResult {
|
||||
match ortsys![@ort: unsafe SetSessionLogSeverityLevel(self.ptr_mut(), ort_sys::OrtLoggingLevel::from(level) as _) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Controls the level of verbosity for messages logged under [`LogLevel::Verbose`]; higher values = more verbose.
|
||||
pub fn with_log_verbosity(mut self, verbosity: c_int) -> Result<Self> {
|
||||
ortsys![unsafe SetSessionLogVerbosityLevel(self.ptr_mut(), verbosity)?];
|
||||
Ok(self)
|
||||
pub fn with_log_verbosity(mut self, verbosity: c_int) -> BuilderResult {
|
||||
match ortsys![@ort: unsafe SetSessionLogVerbosityLevel(self.ptr_mut(), verbosity) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Automatically select & register an execution provider according to the given [`policy`](AutoDevicePolicy) based
|
||||
@@ -309,9 +364,11 @@ impl SessionBuilder {
|
||||
/// ```
|
||||
#[cfg(feature = "api-22")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
||||
pub fn with_auto_device(mut self, policy: AutoDevicePolicy) -> Result<Self> {
|
||||
ortsys![unsafe SessionOptionsSetEpSelectionPolicy(self.ptr_mut(), policy.into())?];
|
||||
Ok(self)
|
||||
pub fn with_auto_device(mut self, policy: AutoDevicePolicy) -> BuilderResult {
|
||||
match ortsys![@ort: unsafe SessionOptionsSetEpSelectionPolicy(self.ptr_mut(), policy.into()) as Result] {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,22 @@ mod impl_options;
|
||||
pub use self::editable::*;
|
||||
pub use self::impl_options::*;
|
||||
|
||||
/// `Result` type returned by [`SessionBuilder`] methods.
|
||||
///
|
||||
/// This type supports [error recovery](Error::recover):
|
||||
/// ```
|
||||
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let session = Session::builder()?
|
||||
/// .with_optimization_level(GraphOptimizationLevel::All)
|
||||
/// // Optimization isn't enabled in minimal builds of ONNX Runtime, so throws an error. We can just ignore it.
|
||||
/// .unwrap_or_else(|e| e.recover())
|
||||
/// .commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub type BuilderResult = Result<SessionBuilder, Error<SessionBuilder>>;
|
||||
|
||||
/// Creates a session using the builder pattern.
|
||||
///
|
||||
/// Once configured, use the
|
||||
@@ -146,7 +162,7 @@ impl SessionBuilder {
|
||||
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||
/// # use std::{thread, time::Duration};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let builder = Session::builder()?
|
||||
/// let mut builder = Session::builder()?
|
||||
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
/// .with_intra_threads(1)?;
|
||||
///
|
||||
@@ -168,9 +184,11 @@ impl SessionBuilder {
|
||||
}
|
||||
|
||||
/// Adds a custom configuration entry to the session.
|
||||
pub fn with_config_entry(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
|
||||
self.add_config_entry(key.as_ref(), value.as_ref())?;
|
||||
Ok(self)
|
||||
pub fn with_config_entry(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> BuilderResult {
|
||||
match self.add_config_entry(key.as_ref(), value.as_ref()) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,7 +221,7 @@ impl LoadCanceler {
|
||||
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||
/// # use std::{thread, time::Duration};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let builder = Session::builder()?
|
||||
/// let mut builder = Session::builder()?
|
||||
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
/// .with_intra_threads(1)?;
|
||||
///
|
||||
|
||||
@@ -34,7 +34,7 @@ use smallvec::SmallVec;
|
||||
use crate::{
|
||||
AsPointer,
|
||||
environment::Environment,
|
||||
error::{Error, ErrorCode, Result, status_to_result},
|
||||
error::{Error, ErrorCode, Result},
|
||||
memory::Allocator,
|
||||
ortsys,
|
||||
util::{AllocatedString, STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS, with_cstr, with_cstr_ptr_array},
|
||||
@@ -708,7 +708,7 @@ pub(crate) mod io {
|
||||
) -> Result<usize> {
|
||||
let mut num_nodes = 0;
|
||||
let status = unsafe { f(session_ptr.as_ptr(), &mut num_nodes) };
|
||||
unsafe { status_to_result(status) }?;
|
||||
unsafe { Error::result_from_status(status) }?;
|
||||
Ok(num_nodes)
|
||||
}
|
||||
|
||||
@@ -721,7 +721,7 @@ pub(crate) mod io {
|
||||
let mut name_ptr: *mut c_char = ptr::null_mut();
|
||||
|
||||
let status = unsafe { f(session_ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_ptr) };
|
||||
unsafe { status_to_result(status) }?;
|
||||
unsafe { Error::result_from_status(status) }?;
|
||||
if name_ptr.is_null() {
|
||||
crate::util::cold();
|
||||
return Err(crate::Error::new("expected `name_ptr` to not be null"));
|
||||
@@ -738,7 +738,7 @@ pub(crate) mod io {
|
||||
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
|
||||
|
||||
let status = unsafe { f(session_ptr.as_ptr(), i, &mut typeinfo_ptr) };
|
||||
unsafe { status_to_result(status) }?;
|
||||
unsafe { Error::result_from_status(status) }?;
|
||||
let Some(typeinfo_ptr) = NonNull::new(typeinfo_ptr) else {
|
||||
crate::util::cold();
|
||||
return Err(crate::Error::new("expected `typeinfo_ptr` to not be null"));
|
||||
|
||||
@@ -11,7 +11,7 @@ use super::{Checkpoint, Optimizer, training_api};
|
||||
use crate::{
|
||||
AsPointer,
|
||||
environment::Environment,
|
||||
error::{Result, status_to_result},
|
||||
error::{Error, Result},
|
||||
memory::Allocator,
|
||||
ortsys,
|
||||
session::{RunOptions, SessionInputValue, SessionInputs, SessionOutputs, builder::SessionBuilder},
|
||||
@@ -329,11 +329,11 @@ fn extract_io_names(
|
||||
) -> ort_sys::OrtStatusPtr
|
||||
) -> Result<Vec<String>> {
|
||||
let mut count = 0;
|
||||
unsafe { status_to_result(get_count(ptr.as_ptr(), &mut count)) }?;
|
||||
unsafe { Error::result_from_status(get_count(ptr.as_ptr(), &mut count)) }?;
|
||||
(0..count)
|
||||
.map(|i| {
|
||||
let mut name_bytes: *const c_char = ptr::null();
|
||||
unsafe { status_to_result(get_name(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)) }?;
|
||||
unsafe { Error::result_from_status(get_name(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)) }?;
|
||||
let name = match char_p_to_string(name_bytes) {
|
||||
Ok(name) => name,
|
||||
Err(e) => {
|
||||
|
||||
Reference in New Issue
Block a user