feat: support alternative backends

This commit is contained in:
Carson M.
2024-12-06 21:04:35 -06:00
parent 9f7ea0450a
commit d866abfc62
4 changed files with 35 additions and 15 deletions

View File

@@ -62,6 +62,8 @@ download-binaries = [ "ort-sys/download-binaries" ]
load-dynamic = [ "libloading", "ort-sys/load-dynamic" ]
copy-dylibs = [ "ort-sys/copy-dylibs" ]
alternative-backend = [ "ort-sys/disable-linking" ]
cuda = [ "ort-sys/cuda" ]
tensorrt = [ "ort-sys/tensorrt" ]
openvino = [ "ort-sys/openvino" ]

View File

@@ -21,6 +21,8 @@ download-binaries = [ "ureq", "tar", "flate2", "sha2" ]
load-dynamic = []
copy-dylibs = []
disable-linking = []
cuda = []
tensorrt = []
openvino = []

View File

@@ -563,7 +563,10 @@ fn real_main(link: bool) {
}
fn main() {
if env::var("DOCS_RS").is_ok() {
if env::var("DOCS_RS").is_ok() || cfg!(feature = "disable-linking") {
// On docs.rs, A) we don't need to link, and B) we don't have network, so we couldn't download anything if we wanted to.
// If `disable-linking` is specified, presumably the application will configure a custom backend, and the crate
// providing said backend will have its own linking logic, so no need to do anything.
return;
}

View File

@@ -33,12 +33,7 @@ pub mod value;
#[cfg(feature = "load-dynamic")]
use std::sync::Arc;
use std::{
ffi::CStr,
os::raw::c_char,
ptr::{self, NonNull},
sync::OnceLock
};
use std::{ffi::CStr, os::raw::c_char, ptr::NonNull, sync::OnceLock};
pub use ort_sys as sys;
@@ -109,6 +104,12 @@ pub fn info() -> &'static str {
unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(str.cast::<u8>(), len)) }
}
struct ApiPointer(NonNull<ort_sys::OrtApi>);
unsafe impl Send for ApiPointer {}
unsafe impl Sync for ApiPointer {}
static G_ORT_API: OnceLock<ApiPointer> = OnceLock::new();
/// Returns a reference to the global [`ort_sys::OrtApi`] object.
///
/// # Panics
@@ -116,12 +117,14 @@ pub fn info() -> &'static str {
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
pub fn api() -> &'static ort_sys::OrtApi {
struct ApiPointer(NonNull<ort_sys::OrtApi>);
unsafe impl Send for ApiPointer {}
unsafe impl Sync for ApiPointer {}
static G_ORT_API: OnceLock<ApiPointer> = OnceLock::new();
#[cfg(feature = "alternative-backend")]
let ptr = G_ORT_API
.get()
.expect(
"attempted to use `ort` APIs before initializing a backend\nwhen the `alternative-backend` feature is enabled, `ort::set_api` must be called to configure the `OrtApi` used by the library"
)
.0;
#[cfg(not(feature = "alternative-backend"))]
let ptr = G_ORT_API
.get_or_init(|| {
#[cfg(feature = "load-dynamic")]
@@ -131,7 +134,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
.get(b"OrtGetApiBase")
.expect("`OrtGetApiBase` must be present in ONNX Runtime dylib");
let base: *const ort_sys::OrtApiBase = base_getter();
assert_ne!(base, ptr::null());
assert!(!base.is_null());
let version_string = ((*base).GetVersionString)();
let version_string = CStr::from_ptr(version_string).to_string_lossy();
@@ -157,7 +160,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
#[cfg(not(feature = "load-dynamic"))]
unsafe {
let base: *const ort_sys::OrtApiBase = ort_sys::OrtGetApiBase();
assert_ne!(base, ptr::null());
assert!(!base.is_null());
let api: *const ort_sys::OrtApi = ((*base).GetApi)(ort_sys::ORT_API_VERSION);
ApiPointer(NonNull::new(api.cast_mut()).expect("Failed to initialize ORT API"))
}
@@ -166,6 +169,16 @@ pub fn api() -> &'static ort_sys::OrtApi {
unsafe { ptr.as_ref() }
}
pub fn set_api(api: ort_sys::OrtApi) -> bool {
match G_ORT_API.set(ApiPointer(unsafe { NonNull::new_unchecked(Box::leak(Box::new(api))) })) {
Ok(()) => true,
Err(api) => {
drop(unsafe { Box::from_raw(api.0.as_ptr()) });
false
}
}
}
/// Trait to access raw pointers from safe types which wrap unsafe [`ort_sys`] types.
pub trait AsPointer {
/// This safe type's corresponding [`ort_sys`] type.