mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: support alternative backends
This commit is contained in:
@@ -62,6 +62,8 @@ download-binaries = [ "ort-sys/download-binaries" ]
|
||||
load-dynamic = [ "libloading", "ort-sys/load-dynamic" ]
|
||||
copy-dylibs = [ "ort-sys/copy-dylibs" ]
|
||||
|
||||
alternative-backend = [ "ort-sys/disable-linking" ]
|
||||
|
||||
cuda = [ "ort-sys/cuda" ]
|
||||
tensorrt = [ "ort-sys/tensorrt" ]
|
||||
openvino = [ "ort-sys/openvino" ]
|
||||
|
||||
@@ -21,6 +21,8 @@ download-binaries = [ "ureq", "tar", "flate2", "sha2" ]
|
||||
load-dynamic = []
|
||||
copy-dylibs = []
|
||||
|
||||
disable-linking = []
|
||||
|
||||
cuda = []
|
||||
tensorrt = []
|
||||
openvino = []
|
||||
|
||||
@@ -563,7 +563,10 @@ fn real_main(link: bool) {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
if env::var("DOCS_RS").is_ok() {
|
||||
if env::var("DOCS_RS").is_ok() || cfg!(feature = "disable-linking") {
|
||||
// On docs.rs, A) we don't need to link, and B) we don't have network, so we couldn't download anything if we wanted to.
|
||||
// If `disable-linking` is specified, presumably the application will configure a custom backend, and the crate
|
||||
// providing said backend will have its own linking logic, so no need to do anything.
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
41
src/lib.rs
41
src/lib.rs
@@ -33,12 +33,7 @@ pub mod value;
|
||||
|
||||
#[cfg(feature = "load-dynamic")]
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
ffi::CStr,
|
||||
os::raw::c_char,
|
||||
ptr::{self, NonNull},
|
||||
sync::OnceLock
|
||||
};
|
||||
use std::{ffi::CStr, os::raw::c_char, ptr::NonNull, sync::OnceLock};
|
||||
|
||||
pub use ort_sys as sys;
|
||||
|
||||
@@ -109,6 +104,12 @@ pub fn info() -> &'static str {
|
||||
unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(str.cast::<u8>(), len)) }
|
||||
}
|
||||
|
||||
struct ApiPointer(NonNull<ort_sys::OrtApi>);
|
||||
unsafe impl Send for ApiPointer {}
|
||||
unsafe impl Sync for ApiPointer {}
|
||||
|
||||
static G_ORT_API: OnceLock<ApiPointer> = OnceLock::new();
|
||||
|
||||
/// Returns a reference to the global [`ort_sys::OrtApi`] object.
|
||||
///
|
||||
/// # Panics
|
||||
@@ -116,12 +117,14 @@ pub fn info() -> &'static str {
|
||||
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
|
||||
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
|
||||
pub fn api() -> &'static ort_sys::OrtApi {
|
||||
struct ApiPointer(NonNull<ort_sys::OrtApi>);
|
||||
unsafe impl Send for ApiPointer {}
|
||||
unsafe impl Sync for ApiPointer {}
|
||||
|
||||
static G_ORT_API: OnceLock<ApiPointer> = OnceLock::new();
|
||||
|
||||
#[cfg(feature = "alternative-backend")]
|
||||
let ptr = G_ORT_API
|
||||
.get()
|
||||
.expect(
|
||||
"attempted to use `ort` APIs before initializing a backend\nwhen the `alternative-backend` feature is enabled, `ort::set_api` must be called to configure the `OrtApi` used by the library"
|
||||
)
|
||||
.0;
|
||||
#[cfg(not(feature = "alternative-backend"))]
|
||||
let ptr = G_ORT_API
|
||||
.get_or_init(|| {
|
||||
#[cfg(feature = "load-dynamic")]
|
||||
@@ -131,7 +134,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
.get(b"OrtGetApiBase")
|
||||
.expect("`OrtGetApiBase` must be present in ONNX Runtime dylib");
|
||||
let base: *const ort_sys::OrtApiBase = base_getter();
|
||||
assert_ne!(base, ptr::null());
|
||||
assert!(!base.is_null());
|
||||
|
||||
let version_string = ((*base).GetVersionString)();
|
||||
let version_string = CStr::from_ptr(version_string).to_string_lossy();
|
||||
@@ -157,7 +160,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
#[cfg(not(feature = "load-dynamic"))]
|
||||
unsafe {
|
||||
let base: *const ort_sys::OrtApiBase = ort_sys::OrtGetApiBase();
|
||||
assert_ne!(base, ptr::null());
|
||||
assert!(!base.is_null());
|
||||
let api: *const ort_sys::OrtApi = ((*base).GetApi)(ort_sys::ORT_API_VERSION);
|
||||
ApiPointer(NonNull::new(api.cast_mut()).expect("Failed to initialize ORT API"))
|
||||
}
|
||||
@@ -166,6 +169,16 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
unsafe { ptr.as_ref() }
|
||||
}
|
||||
|
||||
pub fn set_api(api: ort_sys::OrtApi) -> bool {
|
||||
match G_ORT_API.set(ApiPointer(unsafe { NonNull::new_unchecked(Box::leak(Box::new(api))) })) {
|
||||
Ok(()) => true,
|
||||
Err(api) => {
|
||||
drop(unsafe { Box::from_raw(api.0.as_ptr()) });
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait to access raw pointers from safe types which wrap unsafe [`ort_sys`] types.
|
||||
pub trait AsPointer {
|
||||
/// This safe type's corresponding [`ort_sys`] type.
|
||||
|
||||
Reference in New Issue
Block a user