mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
fix: use separate error type for load-dynamic init, closes #560
This commit is contained in:
@@ -677,7 +677,7 @@ pub fn init() -> EnvironmentBuilder {
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use ort::ep;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # fn main() -> Result<(), ort::LoadDynamicError> {
|
||||
/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib");
|
||||
/// ort::init_from(lib_path.join("onnxruntime.dll"))?
|
||||
/// .with_execution_providers([ep::CUDA::default().build()])
|
||||
@@ -694,7 +694,7 @@ pub fn init() -> EnvironmentBuilder {
|
||||
#[cfg(all(feature = "load-dynamic", not(target_arch = "wasm32")))]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))]
|
||||
#[must_use = "commit() must be called in order for the environment to take effect"]
|
||||
pub fn init_from<P: AsRef<std::path::Path>>(path: P) -> Result<EnvironmentBuilder> {
|
||||
crate::load_dylib_from_path(path.as_ref())?;
|
||||
pub fn init_from<P: AsRef<std::path::Path>>(path: P) -> Result<EnvironmentBuilder, crate::LoadDynamicError> {
|
||||
crate::load_dynamic::init(path.as_ref())?;
|
||||
Ok(EnvironmentBuilder::new())
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ macro_rules! define_ep_register {
|
||||
#[cfg(all(feature = "load-dynamic", not(target_arch = "wasm32")))]
|
||||
#[allow(non_snake_case)]
|
||||
let $symbol = unsafe {
|
||||
let dylib = $crate::G_ORT_LIB.get().expect("dylib not yet initialized");
|
||||
let dylib = $crate::load_dynamic::G_ORT_LIB.get().expect("dylib not yet initialized");
|
||||
let symbol: ::core::result::Result<
|
||||
::libloading::Symbol<unsafe extern "C" fn($($id: $type),*) -> $rt>,
|
||||
::libloading::Error
|
||||
|
||||
121
src/lib.rs
121
src/lib.rs
@@ -86,58 +86,91 @@ pub use self::{
|
||||
pub const MINOR_VERSION: u32 = ort_sys::ORT_API_VERSION;
|
||||
|
||||
#[cfg(all(feature = "load-dynamic", not(target_arch = "wasm32")))]
|
||||
pub(crate) static G_ORT_LIB: OnceLock<libloading::Library> = OnceLock::new();
|
||||
pub(crate) mod load_dynamic {
|
||||
use core::{ffi::CStr, fmt};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[cfg(all(feature = "load-dynamic", not(target_arch = "wasm32")))]
|
||||
pub(crate) fn load_dylib_from_path(path: &std::path::Path) -> Result<bool> {
|
||||
let mut inserter = Some(|| -> crate::Result<libloading::Library> {
|
||||
use core::cmp::Ordering;
|
||||
use crate::{MINOR_VERSION, util::OnceLock};
|
||||
|
||||
let absolute_path = if path.is_absolute() {
|
||||
path.to_path_buf()
|
||||
} else {
|
||||
let relative = std::env::current_exe()
|
||||
.expect("could not get current executable path")
|
||||
.parent()
|
||||
.expect("executable is root?")
|
||||
.join(path);
|
||||
if relative.exists() { relative } else { path.to_path_buf() }
|
||||
};
|
||||
let lib =
|
||||
unsafe { libloading::Library::new(&absolute_path) }.map_err(|e| Error::new(format!("failed to load from `{}`: {e}", absolute_path.display())))?;
|
||||
#[derive(Debug)]
|
||||
pub enum LoadError {
|
||||
Dlopen { error: libloading::Error, path: PathBuf },
|
||||
MissingApi { path: PathBuf },
|
||||
BadVersion { version_str: String, path: PathBuf }
|
||||
}
|
||||
|
||||
let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const ort_sys::OrtApiBase> =
|
||||
unsafe { lib.get(b"OrtGetApiBase") }.map_err(|_| Error::new("expected `OrtGetApiBase` to be present in libonnxruntime"))?;
|
||||
let base: *const ort_sys::OrtApiBase = unsafe { base_getter() };
|
||||
assert!(!base.is_null());
|
||||
impl fmt::Display for LoadError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Dlopen { error, path } => f.write_fmt(format_args!("failed to load from `{}`: {error}", path.display())),
|
||||
Self::MissingApi { path } => f.write_fmt(format_args!("{} does not export `OrtGetApiBase`", path.display())),
|
||||
Self::BadVersion { version_str, path } => f.write_fmt(format_args!(
|
||||
"ort {} is not compatible with the ONNX Runtime binary found at `{}`; expected version >= '1.{}.x', but got '{version_str}'",
|
||||
env!("CARGO_PKG_VERSION"),
|
||||
path.display(),
|
||||
super::MINOR_VERSION
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let version_string = unsafe { ((*base).GetVersionString)() };
|
||||
let version_string = unsafe { CStr::from_ptr(version_string) }.to_string_lossy();
|
||||
impl core::error::Error for LoadError {}
|
||||
|
||||
let lib_minor_version = version_string.split('.').nth(1).map_or(0, |x| x.parse::<u32>().unwrap_or(0));
|
||||
match lib_minor_version.cmp(&MINOR_VERSION) {
|
||||
Ordering::Less => {
|
||||
return Err(Error::new(format!(
|
||||
"ort {} is not compatible with the ONNX Runtime binary found at `{}`; expected version >= '1.{MINOR_VERSION}.x', but got '{version_string}'",
|
||||
pub(crate) static G_ORT_LIB: OnceLock<libloading::Library> = OnceLock::new();
|
||||
|
||||
pub(crate) fn init(path: &std::path::Path) -> Result<bool, LoadError> {
|
||||
let mut inserter = Some(|| -> Result<libloading::Library, LoadError> {
|
||||
use core::cmp::Ordering;
|
||||
|
||||
let absolute_path = if path.is_absolute() {
|
||||
path.to_path_buf()
|
||||
} else {
|
||||
let relative = std::env::current_exe()
|
||||
.expect("could not get current executable path")
|
||||
.parent()
|
||||
.expect("executable is root?")
|
||||
.join(path);
|
||||
if relative.exists() { relative } else { path.to_path_buf() }
|
||||
};
|
||||
let lib = unsafe { libloading::Library::new(&absolute_path) }.map_err(|e| LoadError::Dlopen {
|
||||
error: e,
|
||||
path: absolute_path.clone()
|
||||
})?;
|
||||
|
||||
let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const ort_sys::OrtApiBase> =
|
||||
unsafe { lib.get(b"OrtGetApiBase") }.map_err(|_| LoadError::MissingApi { path: absolute_path.clone() })?;
|
||||
let base: *const ort_sys::OrtApiBase = unsafe { base_getter() };
|
||||
assert!(!base.is_null());
|
||||
|
||||
let version_string = unsafe { ((*base).GetVersionString)() };
|
||||
let version_string = unsafe { CStr::from_ptr(version_string) }.to_string_lossy();
|
||||
|
||||
let lib_minor_version = version_string.split('.').nth(1).map_or(0, |x| x.parse::<u32>().unwrap_or(0));
|
||||
match lib_minor_version.cmp(&MINOR_VERSION) {
|
||||
Ordering::Less => {
|
||||
return Err(LoadError::BadVersion {
|
||||
version_str: version_string.to_string(),
|
||||
path: absolute_path
|
||||
});
|
||||
}
|
||||
Ordering::Greater => crate::info!(
|
||||
"ort {} was designed for ONNX Runtime '1.{MINOR_VERSION}.x' and may have compatibility issues with the ONNX Runtime binary found at `{}`, which is version '{version_string}'",
|
||||
env!("CARGO_PKG_VERSION"),
|
||||
absolute_path.display()
|
||||
)));
|
||||
}
|
||||
Ordering::Greater => crate::info!(
|
||||
"ort {} was designed for ONNX Runtime '1.{MINOR_VERSION}.x' and may have compatibility issues with the ONNX Runtime binary found at `{}`, which is version '{version_string}'",
|
||||
env!("CARGO_PKG_VERSION"),
|
||||
absolute_path.display()
|
||||
),
|
||||
Ordering::Equal => {}
|
||||
};
|
||||
),
|
||||
Ordering::Equal => {}
|
||||
};
|
||||
|
||||
crate::info!("Loaded ONNX Runtime dylib from \"{}\"; version '{version_string}'", absolute_path.display());
|
||||
crate::info!("Loaded ONNX Runtime dylib from \"{}\"; version '{version_string}'", absolute_path.display());
|
||||
|
||||
Ok(lib)
|
||||
});
|
||||
G_ORT_LIB.get_or_try_init(|| (unsafe { inserter.take().unwrap_unchecked() })())?;
|
||||
Ok(inserter.is_none())
|
||||
Ok(lib)
|
||||
});
|
||||
G_ORT_LIB.get_or_try_init(|| (unsafe { inserter.take().unwrap_unchecked() })())?;
|
||||
Ok(inserter.is_none())
|
||||
}
|
||||
}
|
||||
#[cfg(all(feature = "load-dynamic", not(target_arch = "wasm32")))]
|
||||
pub use self::load_dynamic::LoadError as LoadDynamicError;
|
||||
|
||||
/// Returns information about the build of ONNX Runtime used, including version, Git commit, and compile flags.
|
||||
///
|
||||
@@ -182,6 +215,8 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
fn setup_api() -> ApiPointer {
|
||||
#[cfg(feature = "load-dynamic")]
|
||||
let base = unsafe {
|
||||
use crate::load_dynamic::G_ORT_LIB;
|
||||
|
||||
let dylib = if let Some(handle) = G_ORT_LIB.get() {
|
||||
handle
|
||||
} else {
|
||||
@@ -195,7 +230,7 @@ fn setup_api() -> ApiPointer {
|
||||
_ => "libonnxruntime.dylib".to_owned()
|
||||
}
|
||||
.into();
|
||||
load_dylib_from_path(&path).expect("Failed to load ONNX Runtime dylib");
|
||||
load_dynamic::init(&path).expect("Failed to load ONNX Runtime dylib");
|
||||
G_ORT_LIB.get_unchecked()
|
||||
};
|
||||
let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const ort_sys::OrtApiBase> = dylib
|
||||
|
||||
Reference in New Issue
Block a user