mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: #![no_std] (#343)
This commit is contained in:
21
Cargo.toml
21
Cargo.toml
@@ -52,18 +52,23 @@ strip = true
|
||||
codegen-units = 1
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = [ "ndarray", "half", "training", "fetch-models", "load-dynamic", "copy-dylibs" ]
|
||||
features = [ "std", "ndarray", "half", "training", "fetch-models", "load-dynamic", "copy-dylibs" ]
|
||||
targets = ["x86_64-unknown-linux-gnu"]
|
||||
rustdoc-args = [ "--cfg", "docsrs" ]
|
||||
|
||||
[features]
|
||||
default = [ "ndarray", "half", "tracing", "download-binaries", "copy-dylibs" ]
|
||||
default = [ "std", "ndarray", "half", "tracing", "download-binaries", "copy-dylibs" ]
|
||||
|
||||
std = [ "ort-sys/std", "ndarray/std", "tracing?/std" ]
|
||||
training = [ "ort-sys/training" ]
|
||||
|
||||
fetch-models = [ "ureq", "sha2" ]
|
||||
ndarray = [ "dep:ndarray" ]
|
||||
half = [ "dep:half" ]
|
||||
tracing = [ "dep:tracing" ]
|
||||
|
||||
fetch-models = [ "std", "dep:ureq", "dep:sha2" ]
|
||||
download-binaries = [ "ort-sys/download-binaries" ]
|
||||
load-dynamic = [ "libloading", "ort-sys/load-dynamic" ]
|
||||
load-dynamic = [ "std", "libloading", "ort-sys/load-dynamic" ]
|
||||
copy-dylibs = [ "ort-sys/copy-dylibs" ]
|
||||
|
||||
alternative-backend = [ "ort-sys/disable-linking" ]
|
||||
@@ -87,14 +92,14 @@ cann = [ "ort-sys/cann" ]
|
||||
qnn = [ "ort-sys/qnn" ]
|
||||
|
||||
[dependencies]
|
||||
ndarray = { version = "0.16", optional = true }
|
||||
ort-sys = { version = "=2.0.0-rc.9", path = "ort-sys" }
|
||||
ndarray = { version = "0.16", default-features = false, optional = true }
|
||||
ort-sys = { version = "=2.0.0-rc.9", path = "ort-sys", default-features = false }
|
||||
libloading = { version = "0.8", optional = true }
|
||||
|
||||
ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] }
|
||||
sha2 = { version = "0.10", optional = true }
|
||||
tracing = { version = "0.1", optional = true, default-features = false, features = [ "std" ] }
|
||||
half = { version = "2.1", optional = true }
|
||||
tracing = { version = "0.1", optional = true, default-features = false }
|
||||
half = { version = "2.1", default-features = false, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0"
|
||||
|
||||
@@ -16,7 +16,8 @@ authors = [
|
||||
include = [ "src/", "dist.txt", "build.rs", "LICENSE-APACHE", "LICENSE-MIT" ]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
default = [ "std" ]
|
||||
std = []
|
||||
training = []
|
||||
download-binaries = [ "ureq", "tar", "flate2", "sha2" ]
|
||||
load-dynamic = []
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#[cfg(feature = "download-binaries")]
|
||||
use std::fs;
|
||||
use std::{
|
||||
env, fs,
|
||||
env,
|
||||
path::{Path, PathBuf}
|
||||
};
|
||||
|
||||
@@ -14,10 +16,13 @@ const ENV_CXXSTDLIB: &str = "CXXSTDLIB"; // Used by the `cc` crate - we should m
|
||||
#[cfg(feature = "download-binaries")]
|
||||
const ORT_EXTRACT_DIR: &str = "onnxruntime";
|
||||
|
||||
#[cfg(feature = "download-binaries")]
|
||||
const DIST_TABLE: &str = include_str!("dist.txt");
|
||||
|
||||
#[path = "src/internal/mod.rs"]
|
||||
#[cfg(feature = "download-binaries")]
|
||||
mod internal;
|
||||
#[cfg(feature = "download-binaries")]
|
||||
use self::internal::dirs::cache_dir;
|
||||
|
||||
#[cfg(feature = "download-binaries")]
|
||||
@@ -43,6 +48,7 @@ fn fetch_file(source_url: &str) -> Vec<u8> {
|
||||
buffer
|
||||
}
|
||||
|
||||
#[cfg(feature = "download-binaries")]
|
||||
fn find_dist(target: &str, feature_set: &str) -> Option<(&'static str, &'static str)> {
|
||||
DIST_TABLE
|
||||
.split('\n')
|
||||
@@ -63,7 +69,7 @@ fn hex_str_to_bytes(c: impl AsRef<[u8]>) -> Vec<u8> {
|
||||
}
|
||||
}
|
||||
|
||||
c.as_ref().chunks(2).map(|n| nibble(n[0]) << 4 | nibble(n[1])).collect()
|
||||
c.as_ref().chunks(2).map(|n| (nibble(n[0]) << 4) | nibble(n[1])).collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "download-binaries")]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
// based on https://github.com/dirs-dev/dirs-sys-rs/blob/main/src/lib.rs
|
||||
|
||||
#![allow(unused)]
|
||||
|
||||
pub const PYKE_ROOT: &str = "ort.pyke.io";
|
||||
|
||||
#[cfg(all(target_os = "windows", target_arch = "x86"))]
|
||||
@@ -48,8 +50,8 @@ mod windows {
|
||||
pub const fn from_u128(uuid: u128) -> Self {
|
||||
Self {
|
||||
data1: (uuid >> 96) as u32,
|
||||
data2: (uuid >> 80 & 0xffff) as u16,
|
||||
data3: (uuid >> 64 & 0xffff) as u16,
|
||||
data2: ((uuid >> 80) & 0xffff) as u16,
|
||||
data3: ((uuid >> 64) & 0xffff) as u16,
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
data4: (uuid as u64).to_be_bytes()
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,11 @@
|
||||
//! An input adapter, allowing for loading many static inputs from disk at once.
|
||||
|
||||
use std::{
|
||||
path::Path,
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
};
|
||||
use alloc::sync::Arc;
|
||||
use core::ptr::{self, NonNull};
|
||||
#[cfg(feature = "std")]
|
||||
use std::path::Path;
|
||||
|
||||
use crate::{AsPointer, Result, memory::Allocator, ortsys, util};
|
||||
use crate::{AsPointer, Result, memory::Allocator, ortsys};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AdapterInner {
|
||||
@@ -88,8 +87,10 @@ impl Adapter {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn from_file(path: impl AsRef<Path>, allocator: Option<&Allocator>) -> Result<Self> {
|
||||
let path = util::path_to_os_char(path);
|
||||
let path = crate::util::path_to_os_char(path);
|
||||
let allocator_ptr = allocator.map(|c| c.ptr().cast_mut()).unwrap_or_else(ptr::null_mut);
|
||||
let mut ptr = ptr::null_mut();
|
||||
ortsys![unsafe CreateLoraAdapter(path.as_ptr(), allocator_ptr, &mut ptr)?];
|
||||
@@ -156,8 +157,6 @@ impl AsPointer for Adapter {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
|
||||
use super::Adapter;
|
||||
use crate::{
|
||||
session::{RunOptions, Session},
|
||||
@@ -166,13 +165,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_lora() -> crate::Result<()> {
|
||||
let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
|
||||
let lora = Adapter::from_file("tests/data/adapter.orl", None)?;
|
||||
let model = std::fs::read("tests/data/lora_model.onnx").expect("");
|
||||
let session = Session::builder()?.commit_from_memory(&model)?;
|
||||
let lora = std::fs::read("tests/data/adapter.orl").expect("");
|
||||
let lora = Adapter::from_memory(&lora, None)?;
|
||||
|
||||
let mut run_options = RunOptions::new()?;
|
||||
run_options.add_adapter(&lora)?;
|
||||
|
||||
let output: Tensor<f32> = model
|
||||
let output: Tensor<f32> = session
|
||||
.run_with_options(crate::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?], &run_options)?
|
||||
.remove("output")
|
||||
.expect("")
|
||||
@@ -188,16 +189,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_lora_from_memory() -> crate::Result<()> {
|
||||
let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
|
||||
let model = std::fs::read("tests/data/lora_model.onnx").expect("");
|
||||
let session = Session::builder()?.commit_from_memory(&model)?;
|
||||
|
||||
let lora_bytes = fs::read("tests/data/adapter.orl").expect("");
|
||||
let lora_bytes = std::fs::read("tests/data/adapter.orl").expect("");
|
||||
let lora = Adapter::from_memory(&lora_bytes, None)?;
|
||||
drop(lora_bytes);
|
||||
|
||||
let mut run_options = RunOptions::new()?;
|
||||
run_options.add_adapter(&lora)?;
|
||||
|
||||
let output: Tensor<f32> = model
|
||||
let output: Tensor<f32> = session
|
||||
.run_with_options(crate::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?], &run_options)?
|
||||
.remove("output")
|
||||
.expect("")
|
||||
|
||||
@@ -11,25 +11,18 @@
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use std::{
|
||||
use alloc::{boxed::Box, ffi::CString, string::String, vec::Vec};
|
||||
use core::{
|
||||
any::Any,
|
||||
ffi::CString,
|
||||
os::raw::c_void,
|
||||
ptr::{self, NonNull},
|
||||
sync::{Arc, RwLock}
|
||||
ffi::c_void,
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
|
||||
#[cfg(feature = "load-dynamic")]
|
||||
use crate::G_ORT_DYLIB_PATH;
|
||||
use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, ortsys};
|
||||
use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, ortsys, util::OnceLock};
|
||||
|
||||
struct EnvironmentSingleton {
|
||||
lock: RwLock<Option<Arc<Environment>>>
|
||||
}
|
||||
|
||||
unsafe impl Sync for EnvironmentSingleton {}
|
||||
|
||||
static G_ENV: EnvironmentSingleton = EnvironmentSingleton { lock: RwLock::new(None) };
|
||||
static G_ENV: OnceLock<Environment> = OnceLock::new();
|
||||
|
||||
/// An `Environment` is a process-global structure, under which [`Session`](crate::session::Session)s are created.
|
||||
///
|
||||
@@ -39,9 +32,7 @@ static G_ENV: EnvironmentSingleton = EnvironmentSingleton { lock: RwLock::new(No
|
||||
/// environments are also used to configure ONNX Runtime to send log messages through the [`tracing`] crate in Rust.
|
||||
///
|
||||
/// For ease of use, and since sessions require an environment to be created, `ort` will automatically create an
|
||||
/// environment if one is not configured via [`init`] (or [`init_from`]). [`init`] can be called at any point in the
|
||||
/// program (even after an environment has been automatically created), though every session created before the
|
||||
/// re-configuration would need to be re-created in order to use the config from the new environment.
|
||||
/// environment if one is not configured via [`init`] (or [`init_from`]).
|
||||
#[derive(Debug)]
|
||||
pub struct Environment {
|
||||
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
|
||||
@@ -70,17 +61,11 @@ impl Drop for Environment {
|
||||
|
||||
/// Gets a reference to the global environment, creating one if an environment has not been
|
||||
/// [`commit`](EnvironmentBuilder::commit)ted yet.
|
||||
pub fn get_environment() -> Result<Arc<Environment>> {
|
||||
let env = G_ENV.lock.read().expect("poisoned lock");
|
||||
if let Some(env) = env.as_ref() {
|
||||
Ok(Arc::clone(env))
|
||||
} else {
|
||||
// drop our read lock so we dont deadlock when `commit` takes a write lock
|
||||
drop(env);
|
||||
|
||||
pub fn get_environment() -> Result<&'static Environment> {
|
||||
G_ENV.get_or_try_init(|| {
|
||||
crate::debug!("Environment not yet initialized, creating a new one");
|
||||
Ok(EnvironmentBuilder::new().commit()?)
|
||||
}
|
||||
EnvironmentBuilder::new().commit_internal()
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -179,10 +164,14 @@ pub(crate) unsafe extern "system" fn thread_create<T: ThreadManager + Any>(
|
||||
worker: ort_thread_worker_fn
|
||||
};
|
||||
|
||||
let res = std::panic::catch_unwind(|| {
|
||||
let runner = || {
|
||||
let manager = unsafe { &mut *ort_custom_thread_creation_options.cast::<T>() };
|
||||
<T as ThreadManager>::create(manager, thread_worker)
|
||||
});
|
||||
};
|
||||
#[cfg(not(feature = "std"))]
|
||||
let res = Result::<_, crate::Error>::Ok(runner()); // dumb hack
|
||||
#[cfg(feature = "std")]
|
||||
let res = std::panic::catch_unwind(runner);
|
||||
match res {
|
||||
Ok(Ok(thread)) => (Box::leak(Box::new(thread)) as *mut <T as ThreadManager>::Thread)
|
||||
.cast_const()
|
||||
@@ -219,9 +208,9 @@ pub struct EnvironmentBuilder {
|
||||
impl EnvironmentBuilder {
|
||||
pub(crate) fn new() -> Self {
|
||||
EnvironmentBuilder {
|
||||
name: "default".to_string(),
|
||||
name: String::from("default"),
|
||||
telemetry: true,
|
||||
execution_providers: vec![],
|
||||
execution_providers: Vec::new(),
|
||||
global_thread_pool_options: None
|
||||
}
|
||||
}
|
||||
@@ -275,10 +264,9 @@ impl EnvironmentBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Commit the environment configuration and set the global environment.
|
||||
pub fn commit(self) -> Result<Arc<Environment>> {
|
||||
pub(crate) fn commit_internal(self) -> Result<Environment> {
|
||||
let (env_ptr, thread_manager, has_global_threadpool) = if let Some(mut thread_pool_options) = self.global_thread_pool_options {
|
||||
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
|
||||
let mut env_ptr: *mut ort_sys::OrtEnv = ptr::null_mut();
|
||||
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
@@ -307,7 +295,7 @@ impl EnvironmentBuilder {
|
||||
let thread_manager = thread_pool_options.thread_manager.take();
|
||||
(env_ptr, thread_manager, true)
|
||||
} else {
|
||||
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
|
||||
let mut env_ptr: *mut ort_sys::OrtEnv = ptr::null_mut();
|
||||
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
@@ -333,7 +321,7 @@ impl EnvironmentBuilder {
|
||||
|
||||
(env_ptr, None, false)
|
||||
};
|
||||
crate::debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");
|
||||
crate::debug!(env_ptr = alloc::format!("{env_ptr:?}").as_str(), "Environment created");
|
||||
|
||||
if self.telemetry {
|
||||
ortsys![unsafe EnableTelemetryEvents(env_ptr)?];
|
||||
@@ -341,21 +329,19 @@ impl EnvironmentBuilder {
|
||||
ortsys![unsafe DisableTelemetryEvents(env_ptr)?];
|
||||
}
|
||||
|
||||
let mut env_lock = G_ENV.lock.write().expect("poisoned lock");
|
||||
// drop global reference to previous environment
|
||||
if let Some(env_arc) = env_lock.take() {
|
||||
drop(env_arc);
|
||||
}
|
||||
let env = Arc::new(Environment {
|
||||
Ok(Environment {
|
||||
execution_providers: self.execution_providers,
|
||||
// we already asserted the env pointer is non-null in the `CreateEnvWithCustomLogger` call
|
||||
ptr: unsafe { NonNull::new_unchecked(env_ptr) },
|
||||
has_global_threadpool,
|
||||
_thread_manager: thread_manager
|
||||
});
|
||||
env_lock.replace(Arc::clone(&env));
|
||||
})
|
||||
}
|
||||
|
||||
Ok(env)
|
||||
/// Commit the environment configuration.
|
||||
pub fn commit(self) -> Result<bool> {
|
||||
let env = self.commit_internal()?;
|
||||
Ok(G_ENV.try_insert(env))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,6 +395,6 @@ pub fn init() -> EnvironmentBuilder {
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))]
|
||||
#[must_use = "commit() must be called in order for the environment to take effect"]
|
||||
pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
|
||||
let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string()));
|
||||
let _ = G_ORT_DYLIB_PATH.get_or_init(|| alloc::sync::Arc::new(path.to_string()));
|
||||
EnvironmentBuilder::new()
|
||||
}
|
||||
|
||||
31
src/error.rs
31
src/error.rs
@@ -1,9 +1,14 @@
|
||||
use std::{convert::Infallible, ffi::CString, fmt, ptr};
|
||||
use alloc::{
|
||||
ffi::CString,
|
||||
format,
|
||||
string::{String, ToString}
|
||||
};
|
||||
use core::{convert::Infallible, ffi::c_char, fmt, ptr};
|
||||
|
||||
use crate::{char_p_to_string, ortsys};
|
||||
|
||||
/// Type alias for the Result type returned by ORT functions.
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
||||
|
||||
pub(crate) trait IntoStatus {
|
||||
fn into_status(self) -> *mut ort_sys::OrtStatus;
|
||||
@@ -17,7 +22,7 @@ impl<T> IntoStatus for Result<T, Error> {
|
||||
};
|
||||
let message = message.map(|c| CString::new(c).unwrap_or_else(|_| unreachable!()));
|
||||
// message will be copied, so this shouldn't leak
|
||||
ortsys![unsafe CreateStatus(code, message.map(|c| c.as_ptr()).unwrap_or_else(std::ptr::null))]
|
||||
ortsys![unsafe CreateStatus(code, message.map(|c| c.as_ptr()).unwrap_or_else(ptr::null))]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +38,7 @@ impl Error {
|
||||
///
|
||||
/// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort`
|
||||
/// related operation fails.
|
||||
#[cfg(feature = "std")]
|
||||
pub fn wrap<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
|
||||
Error {
|
||||
code: ErrorCode::GenericFailure,
|
||||
@@ -40,6 +46,18 @@ impl Error {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a custom, user-provided error in an [`ort::Error`](Error)..
|
||||
///
|
||||
/// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort`
|
||||
/// related operation fails.
|
||||
#[cfg(not(feature = "std"))]
|
||||
pub fn wrap<T: core::fmt::Display + Send + Sync + 'static>(err: T) -> Self {
|
||||
Error {
|
||||
code: ErrorCode::GenericFailure,
|
||||
msg: err.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a custom [`Error`] with the given message.
|
||||
pub fn new(msg: impl Into<String>) -> Self {
|
||||
Error {
|
||||
@@ -68,6 +86,7 @@ impl fmt::Display for Error {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")] // sigh...
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
@@ -76,8 +95,8 @@ impl From<Infallible> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::ffi::NulError> for Error {
|
||||
fn from(e: std::ffi::NulError) -> Self {
|
||||
impl From<alloc::ffi::NulError> for Error {
|
||||
fn from(e: alloc::ffi::NulError) -> Self {
|
||||
Error::new(format!("Attempted to pass invalid string to C: {e}"))
|
||||
}
|
||||
}
|
||||
@@ -153,7 +172,7 @@ pub(crate) unsafe fn status_to_result(status: *mut ort_sys::OrtStatus) -> Result
|
||||
Ok(())
|
||||
} else {
|
||||
let code = ErrorCode::from(ortsys![unsafe GetErrorCode(status)]);
|
||||
let raw: *const std::os::raw::c_char = ortsys![unsafe GetErrorMessage(status)];
|
||||
let raw: *const c_char = ortsys![unsafe GetErrorMessage(status)];
|
||||
match char_p_to_string(raw) {
|
||||
Ok(msg) => {
|
||||
ortsys![unsafe ReleaseStatus(status)];
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::format;
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
|
||||
@@ -6,7 +8,7 @@ use crate::{
|
||||
|
||||
#[cfg(all(not(feature = "load-dynamic"), feature = "acl"))]
|
||||
extern "C" {
|
||||
fn OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr;
|
||||
fn OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, use_arena: core::ffi::c_int) -> ort_sys::OrtStatusPtr;
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
@@ -48,7 +50,7 @@ impl ExecutionProvider for ACLExecutionProvider {
|
||||
{
|
||||
use crate::AsPointer;
|
||||
|
||||
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr);
|
||||
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, use_arena: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
|
||||
return unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.ptr_mut(), self.use_arena.into())) };
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::format;
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
|
||||
@@ -6,7 +8,7 @@ use crate::{
|
||||
|
||||
#[cfg(all(not(feature = "load-dynamic"), feature = "armnn"))]
|
||||
extern "C" {
|
||||
fn OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr;
|
||||
fn OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: core::ffi::c_int) -> ort_sys::OrtStatusPtr;
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
@@ -48,7 +50,7 @@ impl ExecutionProvider for ArmNNExecutionProvider {
|
||||
{
|
||||
use crate::AsPointer;
|
||||
|
||||
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr);
|
||||
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
|
||||
return unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.ptr_mut(), self.use_arena.into())) };
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::{format, string::ToString};
|
||||
|
||||
use super::{ArbitrarilyConfigurableExecutionProvider, ExecutionProviderOptions};
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
@@ -144,7 +146,7 @@ impl ExecutionProvider for CANNExecutionProvider {
|
||||
{
|
||||
use crate::AsPointer;
|
||||
|
||||
let mut cann_options: *mut ort_sys::OrtCANNProviderOptions = std::ptr::null_mut();
|
||||
let mut cann_options: *mut ort_sys::OrtCANNProviderOptions = core::ptr::null_mut();
|
||||
crate::ortsys![unsafe CreateCANNProviderOptions(&mut cann_options)?];
|
||||
let ffi_options = self.options.to_ffi();
|
||||
if let Err(e) = unsafe {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::format;
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::ops::BitOr;
|
||||
use alloc::{format, string::ToString};
|
||||
use core::ops::BitOr;
|
||||
|
||||
use super::{ArbitrarilyConfigurableExecutionProvider, ExecutionProviderOptions};
|
||||
use crate::{
|
||||
@@ -269,7 +270,7 @@ impl ExecutionProvider for CUDAExecutionProvider {
|
||||
{
|
||||
use crate::AsPointer;
|
||||
|
||||
let mut cuda_options: *mut ort_sys::OrtCUDAProviderOptionsV2 = std::ptr::null_mut();
|
||||
let mut cuda_options: *mut ort_sys::OrtCUDAProviderOptionsV2 = core::ptr::null_mut();
|
||||
crate::ortsys![unsafe CreateCUDAProviderOptions(&mut cuda_options)?];
|
||||
let ffi_options = self.options.to_ffi();
|
||||
if let Err(e) = unsafe {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::format;
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
|
||||
@@ -6,7 +8,7 @@ use crate::{
|
||||
|
||||
#[cfg(all(not(feature = "load-dynamic"), feature = "directml"))]
|
||||
extern "C" {
|
||||
fn OrtSessionOptionsAppendExecutionProvider_DML(options: *mut ort_sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> ort_sys::OrtStatusPtr;
|
||||
fn OrtSessionOptionsAppendExecutionProvider_DML(options: *mut ort_sys::OrtSessionOptions, device_id: core::ffi::c_int) -> ort_sys::OrtStatusPtr;
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
@@ -48,7 +50,7 @@ impl ExecutionProvider for DirectMLExecutionProvider {
|
||||
{
|
||||
use crate::AsPointer;
|
||||
|
||||
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_DML(options: *mut ort_sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> ort_sys::OrtStatusPtr);
|
||||
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_DML(options: *mut ort_sys::OrtSessionOptions, device_id: core::ffi::c_int) -> ort_sys::OrtStatusPtr);
|
||||
return unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_DML(session_builder.ptr_mut(), self.device_id as _)) };
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::ffi::CString;
|
||||
use alloc::{ffi::CString, format};
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
@@ -87,6 +87,8 @@ impl ExecutionProvider for MIGraphXExecutionProvider {
|
||||
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
|
||||
#[cfg(any(feature = "load-dynamic", feature = "migraphx"))]
|
||||
{
|
||||
use core::ptr;
|
||||
|
||||
use crate::AsPointer;
|
||||
|
||||
let options = ort_sys::OrtMIGraphXProviderOptions {
|
||||
@@ -94,15 +96,11 @@ impl ExecutionProvider for MIGraphXExecutionProvider {
|
||||
migraphx_fp16_enable: self.enable_fp16.into(),
|
||||
migraphx_int8_enable: self.enable_int8.into(),
|
||||
migraphx_use_native_calibration_table: self.use_native_calibration_table.into(),
|
||||
migraphx_int8_calibration_table_name: self
|
||||
.int8_calibration_table_name
|
||||
.as_ref()
|
||||
.map(|c| c.as_ptr())
|
||||
.unwrap_or_else(std::ptr::null),
|
||||
migraphx_int8_calibration_table_name: self.int8_calibration_table_name.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null),
|
||||
migraphx_load_compiled_model: self.load_model_path.is_some().into(),
|
||||
migraphx_load_model_path: self.load_model_path.as_ref().map(|c| c.as_ptr()).unwrap_or_else(std::ptr::null),
|
||||
migraphx_load_model_path: self.load_model_path.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null),
|
||||
migraphx_save_compiled_model: self.save_model_path.is_some().into(),
|
||||
migraphx_save_model_path: self.save_model_path.as_ref().map(|c| c.as_ptr()).unwrap_or_else(std::ptr::null),
|
||||
migraphx_save_model_path: self.save_model_path.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null),
|
||||
migraphx_exhaustive_tune: self.exhaustive_tune
|
||||
};
|
||||
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.ptr_mut(), &options)?];
|
||||
|
||||
@@ -14,9 +14,14 @@
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
use std::{collections::HashMap, ffi::CString, fmt::Debug, os::raw::c_char, sync::Arc};
|
||||
use alloc::{ffi::CString, string::ToString, sync::Arc, vec::Vec};
|
||||
use core::{
|
||||
ffi::c_char,
|
||||
fmt::{self, Debug},
|
||||
ptr
|
||||
};
|
||||
|
||||
use crate::{char_p_to_string, error::Result, ortsys, session::builder::SessionBuilder};
|
||||
use crate::{char_p_to_string, error::Result, ortsys, session::builder::SessionBuilder, util::MiniMap};
|
||||
|
||||
pub mod cpu;
|
||||
pub use self::cpu::CPUExecutionProvider;
|
||||
@@ -94,7 +99,7 @@ pub trait ExecutionProvider: Send + Sync {
|
||||
/// enabled), you'll instead want to manually register this EP via [`ExecutionProvider::register`] and detect
|
||||
/// and handle any errors returned by that function.
|
||||
fn is_available(&self) -> Result<bool> {
|
||||
let mut providers: *mut *mut c_char = std::ptr::null_mut();
|
||||
let mut providers: *mut *mut c_char = ptr::null_mut();
|
||||
let mut num_providers = 0;
|
||||
ortsys![unsafe GetAvailableProviders(&mut providers, &mut num_providers)?];
|
||||
if providers.is_null() {
|
||||
@@ -180,7 +185,7 @@ impl ExecutionProviderDispatch {
|
||||
}
|
||||
|
||||
impl Debug for ExecutionProviderDispatch {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct(self.inner.as_str())
|
||||
.field("error_on_failure", &self.error_on_failure)
|
||||
.finish()
|
||||
@@ -188,7 +193,7 @@ impl Debug for ExecutionProviderDispatch {
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub(crate) struct ExecutionProviderOptions(HashMap<CString, CString>);
|
||||
pub(crate) struct ExecutionProviderOptions(MiniMap<CString, CString>);
|
||||
|
||||
impl ExecutionProviderOptions {
|
||||
pub fn set(&mut self, key: impl Into<Vec<u8>>, value: impl Into<Vec<u8>>) {
|
||||
@@ -231,14 +236,14 @@ macro_rules! get_ep_register {
|
||||
#[allow(non_snake_case)]
|
||||
let $symbol = unsafe {
|
||||
let dylib = $crate::lib_handle();
|
||||
let symbol: ::std::result::Result<
|
||||
let symbol: ::core::result::Result<
|
||||
::libloading::Symbol<unsafe extern "C" fn($($id: $type),*) -> $rt>,
|
||||
::libloading::Error
|
||||
> = dylib.get(stringify!($symbol).as_bytes());
|
||||
match symbol {
|
||||
Ok(symbol) => symbol.into_raw(),
|
||||
Err(e) => {
|
||||
return ::std::result::Result::Err($crate::Error::new(format!("Error attempting to load symbol `{}` from dynamic library: {}", stringify!($symbol), e)));
|
||||
return ::core::result::Result::Err($crate::Error::new(format!("Error attempting to load symbol `{}` from dynamic library: {}", stringify!($symbol), e)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::format;
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::{format, string::ToString};
|
||||
|
||||
use super::{ArbitrarilyConfigurableExecutionProvider, ExecutionProviderOptions};
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
@@ -51,7 +53,7 @@ impl ExecutionProvider for OneDNNExecutionProvider {
|
||||
{
|
||||
use crate::AsPointer;
|
||||
|
||||
let mut dnnl_options: *mut ort_sys::OrtDnnlProviderOptions = std::ptr::null_mut();
|
||||
let mut dnnl_options: *mut ort_sys::OrtDnnlProviderOptions = core::ptr::null_mut();
|
||||
crate::ortsys![unsafe CreateDnnlProviderOptions(&mut dnnl_options)?];
|
||||
let ffi_options = self.options.to_ffi();
|
||||
if let Err(e) = unsafe {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::os::raw::c_void;
|
||||
use alloc::{ffi::CString, format};
|
||||
use core::{ffi::c_void, ptr};
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
@@ -8,10 +9,10 @@ use crate::{
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpenVINOExecutionProvider {
|
||||
device_type: Option<String>,
|
||||
device_id: Option<String>,
|
||||
device_type: Option<CString>,
|
||||
device_id: Option<CString>,
|
||||
num_threads: usize,
|
||||
cache_dir: Option<String>,
|
||||
cache_dir: Option<CString>,
|
||||
context: *mut c_void,
|
||||
enable_opencl_throttling: bool,
|
||||
enable_dynamic_shapes: bool,
|
||||
@@ -28,7 +29,7 @@ impl Default for OpenVINOExecutionProvider {
|
||||
device_id: None,
|
||||
num_threads: 8,
|
||||
cache_dir: None,
|
||||
context: std::ptr::null_mut(),
|
||||
context: ptr::null_mut(),
|
||||
enable_opencl_throttling: false,
|
||||
enable_dynamic_shapes: false,
|
||||
enable_npu_fast_compile: false
|
||||
@@ -40,16 +41,16 @@ impl OpenVINOExecutionProvider {
|
||||
/// Overrides the accelerator hardware type and precision with these values at runtime. If this option is not
|
||||
/// explicitly set, default hardware and precision specified during build time is used.
|
||||
#[must_use]
|
||||
pub fn with_device_type(mut self, device_type: impl ToString) -> Self {
|
||||
self.device_type = Some(device_type.to_string());
|
||||
pub fn with_device_type(mut self, device_type: impl AsRef<str>) -> Self {
|
||||
self.device_type = Some(CString::new(device_type.as_ref()).expect("invalid string"));
|
||||
self
|
||||
}
|
||||
|
||||
/// Selects a particular hardware device for inference. If this option is not explicitly set, an arbitrary free
|
||||
/// device will be automatically selected by OpenVINO runtime.
|
||||
#[must_use]
|
||||
pub fn with_device_id(mut self, device_id: impl ToString) -> Self {
|
||||
self.device_id = Some(device_id.to_string());
|
||||
pub fn with_device_id(mut self, device_id: impl AsRef<str>) -> Self {
|
||||
self.device_id = Some(CString::new(device_id.as_ref()).expect("invalid string"));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -63,8 +64,8 @@ impl OpenVINOExecutionProvider {
|
||||
|
||||
/// Explicitly specify the path to save and load the blobs, enabling model caching.
|
||||
#[must_use]
|
||||
pub fn with_cache_dir(mut self, dir: impl ToString) -> Self {
|
||||
self.cache_dir = Some(dir.to_string());
|
||||
pub fn with_cache_dir(mut self, dir: impl AsRef<str>) -> Self {
|
||||
self.cache_dir = Some(CString::new(dir.as_ref()).expect("invalid string"));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -123,33 +124,28 @@ impl ExecutionProvider for OpenVINOExecutionProvider {
|
||||
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
|
||||
#[cfg(any(feature = "load-dynamic", feature = "openvino"))]
|
||||
{
|
||||
use std::ffi::CString;
|
||||
use alloc::ffi::CString;
|
||||
use core::ffi::c_char;
|
||||
|
||||
use crate::AsPointer;
|
||||
|
||||
// Like TensorRT, the OpenVINO EP is also pretty picky about needing an environment by this point.
|
||||
let _ = crate::environment::get_environment();
|
||||
|
||||
let device_type = self.device_type.as_deref().map(CString::new).transpose()?;
|
||||
let device_id = self.device_id.as_deref().map(CString::new).transpose()?;
|
||||
let cache_dir = self.cache_dir.as_deref().map(CString::new).transpose()?;
|
||||
let openvino_options = ort_sys::OrtOpenVINOProviderOptions {
|
||||
device_type: device_type
|
||||
device_type: self
|
||||
.device_type
|
||||
.as_ref()
|
||||
.map_or_else(std::ptr::null, |x| x.as_bytes().as_ptr().cast::<std::ffi::c_char>()),
|
||||
device_id: device_id
|
||||
.as_ref()
|
||||
.map_or_else(std::ptr::null, |x| x.as_bytes().as_ptr().cast::<std::ffi::c_char>()),
|
||||
.map_or_else(ptr::null, |x| x.as_bytes().as_ptr().cast::<c_char>()),
|
||||
device_id: self.device_id.as_ref().map_or_else(ptr::null, |x| x.as_bytes().as_ptr().cast::<c_char>()),
|
||||
num_of_threads: self.num_threads,
|
||||
cache_dir: cache_dir
|
||||
.as_ref()
|
||||
.map_or_else(std::ptr::null, |x| x.as_bytes().as_ptr().cast::<std::ffi::c_char>()),
|
||||
cache_dir: self.cache_dir.as_ref().map_or_else(ptr::null, |x| x.as_bytes().as_ptr().cast::<c_char>()),
|
||||
context: self.context,
|
||||
enable_opencl_throttling: self.enable_opencl_throttling.into(),
|
||||
enable_dynamic_shapes: self.enable_dynamic_shapes.into(),
|
||||
enable_npu_fast_compile: self.enable_npu_fast_compile.into()
|
||||
};
|
||||
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_OpenVINO(session_builder.ptr_mut(), std::ptr::addr_of!(openvino_options))?];
|
||||
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_OpenVINO(session_builder.ptr_mut(), ptr::addr_of!(openvino_options))?];
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::{format, string::ToString};
|
||||
|
||||
use super::{ArbitrarilyConfigurableExecutionProvider, ExecutionProviderOptions};
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
@@ -206,10 +208,9 @@ impl ExecutionProvider for QNNExecutionProvider {
|
||||
use crate::AsPointer;
|
||||
|
||||
let ffi_options = self.options.to_ffi();
|
||||
let ep_name = std::ffi::CString::new("QNN").unwrap_or_else(|_| unreachable!());
|
||||
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider(
|
||||
session_builder.ptr_mut(),
|
||||
ep_name.as_ptr(),
|
||||
b"QNN\0".as_ptr().cast::<core::ffi::c_char>(),
|
||||
ffi_options.key_ptrs(),
|
||||
ffi_options.value_ptrs(),
|
||||
ffi_options.len(),
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::format;
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::os::raw::c_void;
|
||||
use alloc::format;
|
||||
use core::ffi::c_void;
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
@@ -134,6 +135,8 @@ impl ExecutionProvider for ROCmExecutionProvider {
|
||||
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
|
||||
#[cfg(any(feature = "load-dynamic", feature = "rocm"))]
|
||||
{
|
||||
use core::ptr;
|
||||
|
||||
use crate::AsPointer;
|
||||
|
||||
let rocm_options = ort_sys::OrtROCMProviderOptions {
|
||||
@@ -146,14 +149,14 @@ impl ExecutionProvider for ROCmExecutionProvider {
|
||||
},
|
||||
do_copy_in_default_stream: self.do_copy_in_default_stream.into(),
|
||||
has_user_compute_stream: self.user_compute_stream.is_some().into(),
|
||||
user_compute_stream: self.user_compute_stream.unwrap_or_else(std::ptr::null_mut),
|
||||
default_memory_arena_cfg: self.default_memory_arena_cfg.unwrap_or_else(std::ptr::null_mut),
|
||||
user_compute_stream: self.user_compute_stream.unwrap_or_else(ptr::null_mut),
|
||||
default_memory_arena_cfg: self.default_memory_arena_cfg.unwrap_or_else(ptr::null_mut),
|
||||
enable_hip_graph: self.enable_hip_graph.into(),
|
||||
tunable_op_enable: self.tunable_op_enable.into(),
|
||||
tunable_op_tuning_enable: self.tunable_op_tuning_enable.into(),
|
||||
tunable_op_max_tuning_duration_ms: self.tunable_op_max_tuning_duration_ms
|
||||
};
|
||||
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_ROCM(session_builder.ptr_mut(), std::ptr::addr_of!(rocm_options))?];
|
||||
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_ROCM(session_builder.ptr_mut(), ptr::addr_of!(rocm_options))?];
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::{format, string::ToString};
|
||||
|
||||
use super::{ArbitrarilyConfigurableExecutionProvider, ExecutionProviderOptions};
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
@@ -292,7 +294,7 @@ impl ExecutionProvider for TensorRTExecutionProvider {
|
||||
// environment initialized.
|
||||
let _ = crate::environment::get_environment();
|
||||
|
||||
let mut trt_options: *mut ort_sys::OrtTensorRTProviderOptionsV2 = std::ptr::null_mut();
|
||||
let mut trt_options: *mut ort_sys::OrtTensorRTProviderOptionsV2 = core::ptr::null_mut();
|
||||
crate::ortsys![unsafe CreateTensorRTProviderOptions(&mut trt_options)?];
|
||||
let ffi_options = self.options.to_ffi();
|
||||
if let Err(e) = unsafe {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::{format, string::String};
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
|
||||
@@ -6,7 +8,7 @@ use crate::{
|
||||
|
||||
#[cfg(all(not(feature = "load-dynamic"), feature = "tvm"))]
|
||||
extern "C" {
|
||||
fn OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut ort_sys::OrtSessionOptions, opt_str: *const std::os::raw::c_char) -> ort_sys::OrtStatusPtr;
|
||||
fn OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut ort_sys::OrtSessionOptions, opt_str: *const core::ffi::c_char) -> ort_sys::OrtStatusPtr;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -72,7 +74,7 @@ impl ExecutionProvider for TVMExecutionProvider {
|
||||
{
|
||||
use crate::AsPointer;
|
||||
|
||||
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut ort_sys::OrtSessionOptions, opt_str: *const std::os::raw::c_char) -> ort_sys::OrtStatusPtr);
|
||||
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut ort_sys::OrtSessionOptions, opt_str: *const core::ffi::c_char) -> ort_sys::OrtStatusPtr);
|
||||
let mut option_string = Vec::new();
|
||||
if let Some(check_hash) = self.check_hash {
|
||||
option_string.push(format!("check_hash:{}", if check_hash { "True" } else { "False" }));
|
||||
@@ -110,7 +112,7 @@ impl ExecutionProvider for TVMExecutionProvider {
|
||||
if let Some(to_nhwc) = self.to_nhwc {
|
||||
option_string.push(format!("to_nhwc:{}", if to_nhwc { "True" } else { "False" }));
|
||||
}
|
||||
let options_string = std::ffi::CString::new(option_string.join(",")).unwrap_or_else(|_| unreachable!());
|
||||
let options_string = alloc::ffi::CString::new(option_string.join(",")).expect("invalid option string");
|
||||
return unsafe { crate::error::status_to_result(OrtSessionOptionsAppendExecutionProvider_Tvm(session_builder.ptr_mut(), options_string.as_ptr())) };
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use alloc::{format, string::ToString};
|
||||
|
||||
use super::{ArbitrarilyConfigurableExecutionProvider, ExecutionProviderOptions};
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::num::NonZeroUsize;
|
||||
use alloc::{format, string::ToString};
|
||||
use core::num::NonZeroUsize;
|
||||
|
||||
use super::{ArbitrarilyConfigurableExecutionProvider, ExecutionProviderOptions};
|
||||
use crate::{
|
||||
@@ -54,10 +55,9 @@ impl ExecutionProvider for XNNPACKExecutionProvider {
|
||||
use crate::AsPointer;
|
||||
|
||||
let ffi_options = self.options.to_ffi();
|
||||
let ep_name = std::ffi::CString::new("XNNPACK").unwrap_or_else(|_| unreachable!());
|
||||
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider(
|
||||
session_builder.ptr_mut(),
|
||||
ep_name.as_ptr(),
|
||||
b"XNNPACK\0".as_ptr().cast::<core::ffi::c_char>(),
|
||||
ffi_options.key_ptrs(),
|
||||
ffi_options.value_ptrs(),
|
||||
ffi_options.len(),
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
//! Enables binding of session inputs and/or outputs to pre-allocated memory.
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
use alloc::{
|
||||
ffi::CString,
|
||||
string::{String, ToString},
|
||||
sync::Arc,
|
||||
vec::Vec
|
||||
};
|
||||
use core::{
|
||||
fmt::Debug,
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
slice
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -14,6 +18,7 @@ use crate::{
|
||||
memory::MemoryInfo,
|
||||
ortsys,
|
||||
session::{NoSelectedOutputs, RunOptions, Session, SharedSessionInner, output::SessionOutputs},
|
||||
util::MiniMap,
|
||||
value::{DynValue, Value, ValueInner, ValueTypeMarker}
|
||||
};
|
||||
|
||||
@@ -93,9 +98,8 @@ use crate::{
|
||||
#[derive(Debug)]
|
||||
pub struct IoBinding {
|
||||
ptr: NonNull<ort_sys::OrtIoBinding>,
|
||||
held_inputs: HashMap<String, Arc<ValueInner>>,
|
||||
output_names: Vec<String>,
|
||||
output_values: HashMap<String, DynValue>,
|
||||
held_inputs: MiniMap<String, Arc<ValueInner>>,
|
||||
output_values: MiniMap<String, Option<DynValue>>,
|
||||
session: Arc<SharedSessionInner>
|
||||
}
|
||||
|
||||
@@ -106,9 +110,8 @@ impl IoBinding {
|
||||
Ok(Self {
|
||||
ptr: unsafe { NonNull::new_unchecked(ptr) },
|
||||
session: session.inner(),
|
||||
held_inputs: HashMap::new(),
|
||||
output_names: Vec::new(),
|
||||
output_values: HashMap::new()
|
||||
held_inputs: MiniMap::new(),
|
||||
output_values: MiniMap::new()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -144,10 +147,7 @@ impl IoBinding {
|
||||
let name = name.as_ref();
|
||||
let cname = CString::new(name)?;
|
||||
ortsys![unsafe BindOutput(self.ptr_mut(), cname.as_ptr(), ort_value.ptr())?];
|
||||
self.output_names.push(name.to_string());
|
||||
// Clear the old bound output if we have any.
|
||||
drop(self.output_values.remove(name));
|
||||
self.output_values.insert(name.to_string(), ort_value.into_dyn());
|
||||
self.output_values.insert(name.to_string(), Some(ort_value.into_dyn()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -156,7 +156,7 @@ impl IoBinding {
|
||||
let name = name.as_ref();
|
||||
let cname = CString::new(name)?;
|
||||
ortsys![unsafe BindOutputToDevice(self.ptr_mut(), cname.as_ptr(), mem_info.ptr())?];
|
||||
self.output_names.push(name.to_string());
|
||||
self.output_values.insert(name.to_string(), None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -168,7 +168,6 @@ impl IoBinding {
|
||||
/// Clears all bound outputs specified by [`IoBinding::bind_output`] or [`IoBinding::bind_output_to_device`].
|
||||
pub fn clear_outputs(&mut self) {
|
||||
ortsys![unsafe ClearBoundOutputs(self.ptr_mut())];
|
||||
drop(self.output_names.drain(..));
|
||||
drop(self.output_values.drain());
|
||||
}
|
||||
/// Clears both the bound inputs & outputs; equivalent to [`IoBinding::clear_inputs`] followed by
|
||||
@@ -210,23 +209,25 @@ impl IoBinding {
|
||||
}
|
||||
|
||||
fn run_inner(&mut self, run_options: Option<&RunOptions<NoSelectedOutputs>>) -> Result<SessionOutputs<'_, '_>> {
|
||||
let run_options_ptr = if let Some(run_options) = run_options { run_options.ptr() } else { std::ptr::null() };
|
||||
let run_options_ptr = if let Some(run_options) = run_options { run_options.ptr() } else { ptr::null() };
|
||||
ortsys![unsafe RunWithBinding(self.session.ptr().cast_mut(), run_options_ptr, self.ptr())?];
|
||||
|
||||
let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Value> = self.output_values.values().map(|c| (c.ptr().cast_mut(), c)).collect();
|
||||
let mut count = self.output_names.len();
|
||||
// let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Value> = self.output_values.values().map(|c| (c.ptr().cast_mut(),
|
||||
// c)).collect();
|
||||
let mut count = self.output_values.len();
|
||||
if count > 0 {
|
||||
let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut();
|
||||
ortsys![unsafe GetBoundOutputValues(self.ptr(), self.session.allocator.ptr().cast_mut(), &mut output_values_ptr, &mut count)?; nonNull(output_values_ptr)];
|
||||
|
||||
let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count).to_vec() }
|
||||
let output_values = unsafe { slice::from_raw_parts(output_values_ptr, count).to_vec() }
|
||||
.into_iter()
|
||||
.map(|v| unsafe {
|
||||
if let Some(value) = owned_ptrs.get(&v) {
|
||||
.zip(self.output_values.iter())
|
||||
.map(|(ptr, (_, value))| unsafe {
|
||||
if let Some(value) = value {
|
||||
DynValue::clone_of(value)
|
||||
} else {
|
||||
DynValue::from_ptr(
|
||||
NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"),
|
||||
NonNull::new(ptr).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"),
|
||||
Some(Arc::clone(&self.session))
|
||||
)
|
||||
}
|
||||
@@ -236,7 +237,7 @@ impl IoBinding {
|
||||
// output values will be freed when the `Value`s in `SessionOutputs` drop
|
||||
|
||||
Ok(SessionOutputs::new_backed(
|
||||
self.output_names.iter().map(String::as_str).collect(),
|
||||
self.output_values.iter().map(|(k, _)| k.as_str()).collect(),
|
||||
output_values,
|
||||
&self.session.allocator,
|
||||
output_values_ptr.cast()
|
||||
@@ -265,19 +266,22 @@ impl Drop for IoBinding {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::cmp::Ordering;
|
||||
use core::cmp::Ordering;
|
||||
|
||||
use image::{ImageBuffer, Luma, Pixel};
|
||||
#[cfg(feature = "ndarray")]
|
||||
use ndarray::{Array2, Array4, Axis};
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
use crate::tensor::ArrayExtensions;
|
||||
use crate::{
|
||||
Result,
|
||||
memory::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType},
|
||||
session::Session,
|
||||
tensor::ArrayExtensions,
|
||||
value::{Tensor, TensorValueTypeMarker, Value}
|
||||
};
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
fn get_image() -> Array4<f32> {
|
||||
let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = image::open("tests/data/mnist_5.jpg").expect("failed to load image").to_luma8();
|
||||
ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| {
|
||||
@@ -287,6 +291,7 @@ mod tests {
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
fn extract_probabilities<T: TensorValueTypeMarker>(output: &Value<T>) -> Result<Vec<(usize, f32)>> {
|
||||
let mut probabilities: Vec<(usize, f32)> = output
|
||||
.try_extract_tensor()?
|
||||
@@ -301,6 +306,7 @@ mod tests {
|
||||
|
||||
// not terribly useful since CI is CPU-only, but it at least ensures the API won't segfault or something silly
|
||||
#[test]
|
||||
#[cfg(all(feature = "ndarray", feature = "fetch-models"))]
|
||||
fn test_mnist_input_bound() -> Result<()> {
|
||||
let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?;
|
||||
|
||||
@@ -318,6 +324,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(all(feature = "ndarray", feature = "fetch-models"))]
|
||||
fn test_mnist_input_output_bound() -> Result<()> {
|
||||
let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?;
|
||||
|
||||
@@ -337,6 +344,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(all(feature = "ndarray", feature = "fetch-models"))]
|
||||
fn test_send_iobinding() -> Result<()> {
|
||||
let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?;
|
||||
|
||||
@@ -361,6 +369,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(all(feature = "ndarray", feature = "fetch-models"))]
|
||||
fn test_mnist_clear_bounds() -> Result<()> {
|
||||
let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?;
|
||||
|
||||
|
||||
43
src/lib.rs
43
src/lib.rs
@@ -3,6 +3,7 @@
|
||||
#![allow(clippy::tabs_in_doc_comments, clippy::arc_with_non_send_sync)]
|
||||
#![allow(clippy::macro_metavars_in_unsafe)]
|
||||
#![warn(clippy::unwrap_used)]
|
||||
#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
|
||||
|
||||
//! <div align=center>
|
||||
//! <img src="https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/docs/trend-banner.png" width="350px">
|
||||
@@ -12,8 +13,14 @@
|
||||
//! `ort` is a Rust binding for [ONNX Runtime](https://onnxruntime.ai/). For information on how to get started with `ort`,
|
||||
//! see <https://ort.pyke.io/introduction>.
|
||||
|
||||
#[cfg(all(test, not(feature = "fetch-models")))]
|
||||
compile_error!("`cargo test --features fetch-models`!!1!");
|
||||
extern crate alloc;
|
||||
extern crate core;
|
||||
|
||||
#[doc(hidden)]
|
||||
pub mod __private {
|
||||
pub extern crate alloc;
|
||||
pub extern crate core;
|
||||
}
|
||||
|
||||
pub mod adapter;
|
||||
pub mod environment;
|
||||
@@ -35,14 +42,20 @@ pub(crate) mod util;
|
||||
pub mod value;
|
||||
|
||||
#[cfg(feature = "load-dynamic")]
|
||||
use std::sync::Arc;
|
||||
use std::{ffi::CStr, os::raw::c_char, ptr::NonNull, sync::OnceLock};
|
||||
use alloc::sync::Arc;
|
||||
use alloc::{borrow::ToOwned, boxed::Box, string::String};
|
||||
use core::{
|
||||
ffi::{CStr, c_char},
|
||||
ptr::NonNull,
|
||||
slice, str
|
||||
};
|
||||
|
||||
pub use ort_sys as sys;
|
||||
|
||||
#[cfg(feature = "load-dynamic")]
|
||||
pub use self::environment::init_from;
|
||||
pub(crate) use self::logging::{debug, error, info, trace, warning as warn};
|
||||
use self::util::OnceLock;
|
||||
pub use self::{
|
||||
environment::init,
|
||||
error::{Error, ErrorCode, Result}
|
||||
@@ -105,7 +118,7 @@ pub fn info() -> &'static str {
|
||||
while unsafe { *str.add(len) } != 0x00 {
|
||||
len += 1;
|
||||
}
|
||||
unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(str.cast::<u8>(), len)) }
|
||||
unsafe { str::from_utf8_unchecked(slice::from_raw_parts(str.cast::<u8>(), len)) }
|
||||
}
|
||||
|
||||
struct ApiPointer(NonNull<ort_sys::OrtApi>);
|
||||
@@ -133,6 +146,8 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
.get_or_init(|| {
|
||||
#[cfg(feature = "load-dynamic")]
|
||||
unsafe {
|
||||
use core::cmp::Ordering;
|
||||
|
||||
let dylib = lib_handle();
|
||||
let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const ort_sys::OrtApiBase> = dylib
|
||||
.get(b"OrtGetApiBase")
|
||||
@@ -146,17 +161,17 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
|
||||
let lib_minor_version = version_string.split('.').nth(1).map_or(0, |x| x.parse::<u32>().unwrap_or(0));
|
||||
match lib_minor_version.cmp(&MINOR_VERSION) {
|
||||
std::cmp::Ordering::Less => panic!(
|
||||
Ordering::Less => panic!(
|
||||
"ort {} is not compatible with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.{MINOR_VERSION}.x', but got '{version_string}'",
|
||||
env!("CARGO_PKG_VERSION"),
|
||||
dylib_path()
|
||||
),
|
||||
std::cmp::Ordering::Greater => crate::warn!(
|
||||
Ordering::Greater => crate::warn!(
|
||||
"ort {} may have compatibility issues with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.{MINOR_VERSION}.x', but got '{version_string}'",
|
||||
env!("CARGO_PKG_VERSION"),
|
||||
dylib_path()
|
||||
),
|
||||
std::cmp::Ordering::Equal => {}
|
||||
Ordering::Equal => {}
|
||||
};
|
||||
let api: *const ort_sys::OrtApi = ((*base).GetApi)(ort_sys::ORT_API_VERSION);
|
||||
ApiPointer(NonNull::new(api.cast_mut()).expect("Failed to initialize ORT API"))
|
||||
@@ -174,13 +189,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
}
|
||||
|
||||
pub fn set_api(api: ort_sys::OrtApi) -> bool {
|
||||
match G_ORT_API.set(ApiPointer(unsafe { NonNull::new_unchecked(Box::leak(Box::new(api))) })) {
|
||||
Ok(()) => true,
|
||||
Err(api) => {
|
||||
drop(unsafe { Box::from_raw(api.0.as_ptr()) });
|
||||
false
|
||||
}
|
||||
}
|
||||
G_ORT_API.try_insert(ApiPointer(unsafe { NonNull::new_unchecked(Box::leak(Box::new(api))) }))
|
||||
}
|
||||
|
||||
/// Trait to access raw pointers from safe types which wrap unsafe [`ort_sys`] types.
|
||||
@@ -246,12 +255,12 @@ pub(crate) fn char_p_to_string(raw: *const c_char) -> Result<String> {
|
||||
return Ok(String::new());
|
||||
}
|
||||
let c_string = unsafe { CStr::from_ptr(raw.cast_mut()).to_owned() };
|
||||
Ok(c_string.to_string_lossy().to_string())
|
||||
Ok(c_string.to_string_lossy().into())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::ffi::CString;
|
||||
use alloc::ffi::CString;
|
||||
|
||||
use super::*;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#[cfg(feature = "tracing")]
|
||||
use std::{
|
||||
use core::{
|
||||
ffi::{self, CStr},
|
||||
ptr
|
||||
};
|
||||
@@ -42,6 +42,7 @@ pub(crate) use warning;
|
||||
|
||||
#[cfg(not(feature = "tracing"))]
|
||||
pub fn default_log_level() -> ort_sys::OrtLoggingLevel {
|
||||
#[cfg(feature = "std")]
|
||||
match std::env::var("ORT_LOG").as_deref() {
|
||||
Ok("fatal") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL,
|
||||
Ok("error") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR,
|
||||
@@ -50,6 +51,8 @@ pub fn default_log_level() -> ort_sys::OrtLoggingLevel {
|
||||
Ok("verbose") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
|
||||
_ => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR
|
||||
}
|
||||
#[cfg(not(feature = "std"))]
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING
|
||||
}
|
||||
|
||||
/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
//! Types for managing memory & device allocations.
|
||||
|
||||
use std::{
|
||||
use alloc::sync::Arc;
|
||||
use core::{
|
||||
ffi::{c_char, c_int, c_void},
|
||||
mem,
|
||||
ptr::NonNull,
|
||||
sync::Arc
|
||||
ptr::{self, NonNull},
|
||||
slice, str
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -151,7 +152,7 @@ impl Allocator {
|
||||
/// Creates a new [`Allocator`] for the given session, to allocate memory on the device described in the
|
||||
/// [`MemoryInfo`].
|
||||
pub fn new(session: &Session, memory_info: MemoryInfo) -> Result<Self> {
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = ptr::null_mut();
|
||||
ortsys![unsafe CreateAllocator(session.ptr(), memory_info.ptr.as_ptr(), &mut allocator_ptr)?; nonNull(allocator_ptr)];
|
||||
Ok(Self {
|
||||
ptr: unsafe { NonNull::new_unchecked(allocator_ptr) },
|
||||
@@ -169,7 +170,7 @@ impl Default for Allocator {
|
||||
/// The allocator returned by this function is actually shared across all invocations (though this behavior is
|
||||
/// transparent to the user).
|
||||
fn default() -> Self {
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = ptr::null_mut();
|
||||
unsafe { status_to_result(ortsys![GetAllocatorWithDefaultOptions(&mut allocator_ptr); nonNull(allocator_ptr)]) }
|
||||
.expect("Failed to get default allocator");
|
||||
Self {
|
||||
@@ -389,7 +390,7 @@ impl MemoryInfo {
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemoryType) -> Result<Self> {
|
||||
let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
|
||||
let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = ptr::null_mut();
|
||||
ortsys![
|
||||
unsafe CreateMemoryInfo(allocation_device.as_str().as_ptr().cast(), allocator_type.into(), device_id, memory_type.into(), &mut memory_info_ptr)?;
|
||||
nonNull(memory_info_ptr)
|
||||
@@ -404,7 +405,7 @@ impl MemoryInfo {
|
||||
let mut is_tensor = 0;
|
||||
ortsys![unsafe IsTensor(value_ptr, &mut is_tensor)]; // infallible
|
||||
if is_tensor != 0 {
|
||||
let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut();
|
||||
let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = ptr::null_mut();
|
||||
// infallible, and `memory_info_ptr` will never be null
|
||||
ortsys![unsafe GetTensorMemoryInfo(value_ptr, &mut memory_info_ptr)];
|
||||
Some(Self::from_raw(unsafe { NonNull::new_unchecked(memory_info_ptr.cast_mut()) }, false))
|
||||
@@ -465,7 +466,7 @@ impl MemoryInfo {
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn allocation_device(&self) -> AllocationDevice {
|
||||
let mut name_ptr: *const c_char = std::ptr::null_mut();
|
||||
let mut name_ptr: *const c_char = ptr::null_mut();
|
||||
ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr)];
|
||||
|
||||
// SAFETY: `name_ptr` can never be null - `CreateMemoryInfo` internally checks against builtin device names, erroring
|
||||
@@ -478,7 +479,7 @@ impl MemoryInfo {
|
||||
|
||||
// SAFETY: ONNX Runtime internally only ever defines allocation device names as ASCII. can't wait for this to blow up
|
||||
// one day regardless
|
||||
let name = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(name_ptr.cast::<u8>(), len + 1)) };
|
||||
let name = unsafe { str::from_utf8_unchecked(slice::from_raw_parts(name_ptr.cast::<u8>(), len + 1)) };
|
||||
AllocationDevice(name)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
use std::{ffi::CString, os::raw::c_char, ptr::NonNull};
|
||||
use alloc::{ffi::CString, string::String, vec::Vec};
|
||||
use core::{
|
||||
ffi::c_char,
|
||||
ptr::{self, NonNull},
|
||||
slice
|
||||
};
|
||||
|
||||
use crate::{AsPointer, char_p_to_string, error::Result, memory::Allocator, ortsys};
|
||||
|
||||
@@ -15,7 +20,7 @@ impl<'s> ModelMetadata<'s> {
|
||||
|
||||
/// Gets the model description, returning an error if no description is present.
|
||||
pub fn description(&self) -> Result<String> {
|
||||
let mut str_bytes: *mut c_char = std::ptr::null_mut();
|
||||
let mut str_bytes: *mut c_char = ptr::null_mut();
|
||||
ortsys![unsafe ModelMetadataGetDescription(self.metadata_ptr.as_ptr(), self.allocator.ptr().cast_mut(), &mut str_bytes)?; nonNull(str_bytes)];
|
||||
|
||||
let value = match char_p_to_string(str_bytes) {
|
||||
@@ -31,7 +36,7 @@ impl<'s> ModelMetadata<'s> {
|
||||
|
||||
/// Gets the model producer name, returning an error if no producer name is present.
|
||||
pub fn producer(&self) -> Result<String> {
|
||||
let mut str_bytes: *mut c_char = std::ptr::null_mut();
|
||||
let mut str_bytes: *mut c_char = ptr::null_mut();
|
||||
ortsys![unsafe ModelMetadataGetProducerName(self.metadata_ptr.as_ptr(), self.allocator.ptr().cast_mut(), &mut str_bytes)?; nonNull(str_bytes)];
|
||||
|
||||
let value = match char_p_to_string(str_bytes) {
|
||||
@@ -47,7 +52,7 @@ impl<'s> ModelMetadata<'s> {
|
||||
|
||||
/// Gets the model name, returning an error if no name is present.
|
||||
pub fn name(&self) -> Result<String> {
|
||||
let mut str_bytes: *mut c_char = std::ptr::null_mut();
|
||||
let mut str_bytes: *mut c_char = ptr::null_mut();
|
||||
ortsys![unsafe ModelMetadataGetGraphName(self.metadata_ptr.as_ptr(), self.allocator.ptr().cast_mut(), &mut str_bytes)?; nonNull(str_bytes)];
|
||||
|
||||
let value = match char_p_to_string(str_bytes) {
|
||||
@@ -70,7 +75,7 @@ impl<'s> ModelMetadata<'s> {
|
||||
|
||||
/// Fetch the value of a custom metadata key. Returns `Ok(None)` if the key is not found.
|
||||
pub fn custom(&self, key: &str) -> Result<Option<String>> {
|
||||
let mut str_bytes: *mut c_char = std::ptr::null_mut();
|
||||
let mut str_bytes: *mut c_char = ptr::null_mut();
|
||||
let key_str = CString::new(key)?;
|
||||
ortsys![unsafe ModelMetadataLookupCustomMetadataMap(self.metadata_ptr.as_ptr(), self.allocator.ptr().cast_mut(), key_str.as_ptr(), &mut str_bytes)?];
|
||||
if !str_bytes.is_null() {
|
||||
@@ -89,11 +94,11 @@ impl<'s> ModelMetadata<'s> {
|
||||
}
|
||||
|
||||
pub fn custom_keys(&self) -> Result<Vec<String>> {
|
||||
let mut keys: *mut *mut c_char = std::ptr::null_mut();
|
||||
let mut keys: *mut *mut c_char = ptr::null_mut();
|
||||
let mut key_len = 0;
|
||||
ortsys![unsafe ModelMetadataGetCustomMetadataMapKeys(self.metadata_ptr.as_ptr(), self.allocator.ptr().cast_mut(), &mut keys, &mut key_len)?];
|
||||
if key_len != 0 && !keys.is_null() {
|
||||
let res = unsafe { std::slice::from_raw_parts(keys, key_len as usize) }
|
||||
let res = unsafe { slice::from_raw_parts(keys, key_len as usize) }
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let res = char_p_to_string(*c);
|
||||
@@ -104,7 +109,7 @@ impl<'s> ModelMetadata<'s> {
|
||||
unsafe { self.allocator.free(keys) };
|
||||
res
|
||||
} else {
|
||||
Ok(vec![])
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::{ffi::CString, ptr};
|
||||
use alloc::{boxed::Box, ffi::CString, vec::Vec};
|
||||
use core::ptr;
|
||||
|
||||
use super::{
|
||||
Operator, ShapeInferenceContext,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use std::{
|
||||
ffi::{CString, c_char, c_void},
|
||||
use alloc::{boxed::Box, ffi::CString, string::String, vec, vec::Vec};
|
||||
use core::{
|
||||
ffi::{c_char, c_void},
|
||||
ops::{Deref, DerefMut},
|
||||
ptr::{self, NonNull}
|
||||
ptr::{self, NonNull},
|
||||
slice
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -199,12 +201,12 @@ impl<T> Deref for ScratchBuffer<T> {
|
||||
type Target = [T];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
unsafe { std::slice::from_raw_parts(self.buffer.cast_const(), self.size) }
|
||||
unsafe { slice::from_raw_parts(self.buffer.cast_const(), self.size) }
|
||||
}
|
||||
}
|
||||
impl<T> DerefMut for ScratchBuffer<T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
unsafe { std::slice::from_raw_parts_mut(self.buffer, self.size) }
|
||||
unsafe { slice::from_raw_parts_mut(self.buffer, self.size) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,7 +285,7 @@ impl KernelContext {
|
||||
// unsafe KernelContext_GetScratchBuffer(
|
||||
// self.ptr.as_ptr(),
|
||||
// memory_info.ptr.as_ptr(),
|
||||
// len * std::mem::size_of::<T>(),
|
||||
// len * core::mem::size_of::<T>(),
|
||||
// &mut buffer
|
||||
// )?;
|
||||
// nonNull(buffer)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
//! Contains traits for implementing custom operator domains & kernels.
|
||||
|
||||
use std::{
|
||||
ffi::CString,
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
use alloc::{boxed::Box, ffi::CString, vec::Vec};
|
||||
use core::ptr::{self, NonNull};
|
||||
|
||||
pub(crate) mod bound;
|
||||
pub mod io;
|
||||
|
||||
@@ -77,9 +77,10 @@ impl Operator for CustomOpTwo {
|
||||
|
||||
#[test]
|
||||
fn test_custom_ops() -> crate::Result<()> {
|
||||
let model = std::fs::read("tests/data/custom_op_test.onnx").expect("");
|
||||
let session = Session::builder()?
|
||||
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?
|
||||
.commit_from_file("tests/data/custom_op_test.onnx")?;
|
||||
.commit_from_memory(&model)?;
|
||||
|
||||
let allocator = session.allocator();
|
||||
let mut value1 = Tensor::<f32>::new(allocator, [3, 5])?;
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
use std::{
|
||||
use alloc::{ffi::CString, sync::Arc};
|
||||
use core::{
|
||||
cell::UnsafeCell,
|
||||
ffi::{CString, c_char},
|
||||
ffi::c_char,
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
ops::Deref,
|
||||
pin::Pin,
|
||||
ptr::NonNull,
|
||||
sync::{Arc, Mutex},
|
||||
task::{Context, Poll, Waker}
|
||||
};
|
||||
use std::sync::Mutex;
|
||||
|
||||
use ort_sys::{OrtStatus, c_void};
|
||||
|
||||
|
||||
@@ -1,12 +1,22 @@
|
||||
use alloc::{boxed::Box, sync::Arc, vec::Vec};
|
||||
#[cfg(feature = "fetch-models")]
|
||||
use std::fmt::Write;
|
||||
use std::{any::Any, marker::PhantomData, path::Path, ptr::NonNull, sync::Arc};
|
||||
use core::fmt::Write;
|
||||
use core::{
|
||||
any::Any,
|
||||
ffi::c_void,
|
||||
marker::PhantomData,
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
#[cfg(feature = "std")]
|
||||
use std::path::Path;
|
||||
|
||||
use super::SessionBuilder;
|
||||
#[cfg(feature = "std")]
|
||||
use crate::error::{Error, ErrorCode};
|
||||
use crate::{
|
||||
AsPointer,
|
||||
environment::get_environment,
|
||||
error::{Error, ErrorCode, Result},
|
||||
error::Result,
|
||||
execution_providers::apply_execution_providers,
|
||||
memory::Allocator,
|
||||
ortsys,
|
||||
@@ -15,8 +25,8 @@ use crate::{
|
||||
|
||||
impl SessionBuilder {
|
||||
/// Downloads a pre-trained ONNX model from the given URL and builds the session.
|
||||
#[cfg(feature = "fetch-models")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
|
||||
#[cfg(all(feature = "fetch-models", feature = "std"))]
|
||||
#[cfg_attr(docsrs, doc(cfg(all(feature = "fetch-models", feature = "std"))))]
|
||||
pub fn commit_from_url(self, model_url: impl AsRef<str>) -> Result<Session> {
|
||||
let mut download_dir = ort_sys::internal::dirs::cache_dir()
|
||||
.expect("could not determine cache directory")
|
||||
@@ -75,6 +85,8 @@ impl SessionBuilder {
|
||||
}
|
||||
|
||||
/// Loads an ONNX model from a file and builds the session.
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn commit_from_file<P>(mut self, model_filepath_ref: P) -> Result<Session>
|
||||
where
|
||||
P: AsRef<Path>
|
||||
@@ -93,7 +105,7 @@ impl SessionBuilder {
|
||||
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
|
||||
}
|
||||
|
||||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
|
||||
let mut session_ptr: *mut ort_sys::OrtSession = ptr::null_mut();
|
||||
if let Some(prepacked_weights) = self.prepacked_weights.as_ref() {
|
||||
ortsys![unsafe CreateSessionWithPrepackedWeightsContainer(env.ptr(), model_path.as_ptr(), self.ptr(), prepacked_weights.ptr().cast_mut(), &mut session_ptr)?; nonNull(session_ptr)];
|
||||
} else {
|
||||
@@ -104,7 +116,7 @@ impl SessionBuilder {
|
||||
|
||||
let allocator = match &self.memory_info {
|
||||
Some(info) => {
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = ptr::null_mut();
|
||||
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr(), &mut allocator_ptr)?; nonNull(allocator_ptr)];
|
||||
unsafe { Allocator::from_raw_unchecked(allocator_ptr) }
|
||||
}
|
||||
@@ -133,8 +145,7 @@ impl SessionBuilder {
|
||||
inner: Arc::new(SharedSessionInner {
|
||||
session_ptr,
|
||||
allocator,
|
||||
_extras: extras,
|
||||
_environment: env
|
||||
_extras: extras
|
||||
}),
|
||||
inputs,
|
||||
outputs
|
||||
@@ -159,7 +170,7 @@ impl SessionBuilder {
|
||||
|
||||
/// Load an ONNX graph from memory and commit the session.
|
||||
pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result<Session> {
|
||||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
|
||||
let mut session_ptr: *mut ort_sys::OrtSession = ptr::null_mut();
|
||||
|
||||
let env = get_environment()?;
|
||||
apply_execution_providers(&mut self, env.execution_providers.iter().cloned())?;
|
||||
@@ -168,7 +179,7 @@ impl SessionBuilder {
|
||||
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
|
||||
}
|
||||
|
||||
let model_data = model_bytes.as_ptr().cast::<std::ffi::c_void>();
|
||||
let model_data = model_bytes.as_ptr().cast::<c_void>();
|
||||
let model_data_length = model_bytes.len();
|
||||
if let Some(prepacked_weights) = self.prepacked_weights.as_ref() {
|
||||
ortsys![
|
||||
@@ -186,7 +197,7 @@ impl SessionBuilder {
|
||||
|
||||
let allocator = match &self.memory_info {
|
||||
Some(info) => {
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = ptr::null_mut();
|
||||
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr(), &mut allocator_ptr)?; nonNull(allocator_ptr)];
|
||||
unsafe { Allocator::from_raw_unchecked(allocator_ptr) }
|
||||
}
|
||||
@@ -215,8 +226,7 @@ impl SessionBuilder {
|
||||
inner: Arc::new(SharedSessionInner {
|
||||
session_ptr,
|
||||
allocator,
|
||||
_extras: extras,
|
||||
_environment: env
|
||||
_extras: extras
|
||||
}),
|
||||
inputs,
|
||||
outputs
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
use std::{
|
||||
use alloc::{borrow::Cow, ffi::CString, rc::Rc, sync::Arc, vec::Vec};
|
||||
use core::{
|
||||
any::Any,
|
||||
borrow::Cow,
|
||||
ffi::{CString, c_char},
|
||||
path::Path,
|
||||
ptr,
|
||||
rc::Rc,
|
||||
sync::Arc
|
||||
ffi::{c_char, c_void},
|
||||
ptr
|
||||
};
|
||||
#[cfg(feature = "std")]
|
||||
use std::path::Path;
|
||||
|
||||
use super::SessionBuilder;
|
||||
#[cfg(feature = "std")]
|
||||
use crate::util::path_to_os_char;
|
||||
use crate::{
|
||||
AsPointer,
|
||||
environment::{self, ThreadManager},
|
||||
@@ -17,7 +18,6 @@ use crate::{
|
||||
memory::MemoryInfo,
|
||||
operator::OperatorDomain,
|
||||
ortsys,
|
||||
util::path_to_os_char,
|
||||
value::DynValue
|
||||
};
|
||||
|
||||
@@ -89,6 +89,8 @@ impl SessionBuilder {
|
||||
/// newly optimized model to the given path (for 'offline' graph optimization).
|
||||
///
|
||||
/// Note that the file will only be created after the model is committed.
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn with_optimized_model_path<S: AsRef<Path>>(mut self, path: S) -> Result<Self> {
|
||||
let path = crate::util::path_to_os_char(path);
|
||||
ortsys![unsafe SetOptimizedModelFilePath(self.ptr_mut(), path.as_ptr())?];
|
||||
@@ -99,6 +101,8 @@ impl SessionBuilder {
|
||||
/// See [`Session::end_profiling`].
|
||||
///
|
||||
/// [`Session::end_profiling`]: crate::session::Session::end_profiling
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn with_profiling<S: AsRef<Path>>(mut self, profiling_file: S) -> Result<Self> {
|
||||
let profiling_file = crate::util::path_to_os_char(profiling_file);
|
||||
ortsys![unsafe EnableProfiling(self.ptr_mut(), profiling_file.as_ptr())?];
|
||||
@@ -124,6 +128,8 @@ impl SessionBuilder {
|
||||
}
|
||||
|
||||
/// Registers a custom operator library at the given library path.
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn with_operator_library(mut self, lib_path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path_cstr = path_to_os_char(lib_path);
|
||||
ortsys![unsafe RegisterCustomOpsLibrary_V2(self.ptr_mut(), path_cstr.as_ptr())?];
|
||||
@@ -160,12 +166,21 @@ impl SessionBuilder {
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn with_external_initializer_file(mut self, file_name: impl AsRef<Path>, buffer: Cow<'static, [u8]>) -> Result<Self> {
|
||||
pub fn with_external_initializer_file_in_memory(mut self, file_name: impl AsRef<str>, buffer: Cow<'static, [u8]>) -> Result<Self> {
|
||||
// We need to hold onto `buffer` until the session is actually committed. This means `buffer` must outlive 'self (if
|
||||
// SessionBuilder were to have a lifetime). Adding a lifetime to SessionBuilder would be breaking, so right now we
|
||||
// either accept a &'static [u8] or Vec<u8> via Cow<'_, [u8]>, which still allows users to use include_bytes!.
|
||||
// either accept a &'static [u8] or Vec<u8> via Cow<'static, [u8]>, which still allows users to use include_bytes!.
|
||||
|
||||
let file_name = crate::util::path_to_os_char(file_name);
|
||||
#[cfg(target_family = "windows")]
|
||||
let file_name: Vec<u16> = file_name.as_ref().encode_utf16().chain(core::iter::once(0)).collect();
|
||||
#[cfg(not(target_family = "windows"))]
|
||||
let file_name: Vec<core::ffi::c_char> = file_name
|
||||
.as_ref()
|
||||
.as_bytes()
|
||||
.into_iter()
|
||||
.map(|c| *c as _)
|
||||
.chain(core::iter::once(0))
|
||||
.collect();
|
||||
let sizes = [buffer.len()];
|
||||
ortsys![unsafe AddExternalInitializersFromMemory(self.ptr_mut(), &file_name.as_ptr(), &buffer.as_ptr().cast::<c_char>().cast_mut(), sizes.as_ptr(), 1)?];
|
||||
self.external_initializer_buffers.push(buffer);
|
||||
@@ -204,7 +219,7 @@ impl SessionBuilder {
|
||||
|
||||
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
|
||||
let manager = Rc::new(manager);
|
||||
ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut std::ffi::c_void)?];
|
||||
ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut c_void)?];
|
||||
ortsys![unsafe SessionOptionsSetCustomCreateThreadFn(self.ptr_mut(), Some(environment::thread_create::<T>))?];
|
||||
ortsys![unsafe SessionOptionsSetCustomJoinThreadFn(self.ptr_mut(), Some(environment::thread_join::<T>))?];
|
||||
self.thread_manager = Some(manager as Rc<dyn Any>);
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
use std::{
|
||||
use alloc::{borrow::Cow, ffi::CString, rc::Rc, sync::Arc, vec::Vec};
|
||||
use core::{
|
||||
any::Any,
|
||||
borrow::Cow,
|
||||
ffi::CString,
|
||||
ptr::{self, NonNull},
|
||||
rc::Rc,
|
||||
sync::Arc
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -89,7 +86,7 @@ impl SessionBuilder {
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut session_options_ptr: *mut ort_sys::OrtSessionOptions = std::ptr::null_mut();
|
||||
let mut session_options_ptr: *mut ort_sys::OrtSessionOptions = ptr::null_mut();
|
||||
ortsys![unsafe CreateSessionOptions(&mut session_options_ptr)?; nonNull(session_options_ptr)];
|
||||
|
||||
Ok(Self {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::{borrow::Cow, collections::HashMap, ops::Deref};
|
||||
use alloc::{borrow::Cow, vec::Vec};
|
||||
use core::ops::Deref;
|
||||
|
||||
use crate::value::{DynValueTypeMarker, Value, ValueRef, ValueRefMut, ValueTypeMarker};
|
||||
|
||||
@@ -50,8 +51,10 @@ pub enum SessionInputs<'i, 'v, const N: usize = 0> {
|
||||
ValueArray([SessionInputValue<'v>; N])
|
||||
}
|
||||
|
||||
impl<'i, 'v, K: Into<Cow<'i, str>>, V: Into<SessionInputValue<'v>>> From<HashMap<K, V>> for SessionInputs<'i, 'v> {
|
||||
fn from(val: HashMap<K, V>) -> Self {
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
impl<'i, 'v, K: Into<Cow<'i, str>>, V: Into<SessionInputValue<'v>>> From<std::collections::HashMap<K, V>> for SessionInputs<'i, 'v> {
|
||||
fn from(val: std::collections::HashMap<K, V>) -> Self {
|
||||
SessionInputs::ValueMap(val.into_iter().map(|(k, v)| (k.into(), v.into())).collect())
|
||||
}
|
||||
}
|
||||
@@ -108,10 +111,10 @@ impl<'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs<'_,
|
||||
#[macro_export]
|
||||
macro_rules! inputs {
|
||||
($($v:expr),+ $(,)?) => (
|
||||
[$(::std::convert::Into::<$crate::session::SessionInputValue<'_>>::into($v)),+]
|
||||
[$($crate::__private::core::convert::Into::<$crate::session::SessionInputValue<'_>>::into($v)),+]
|
||||
);
|
||||
($($n:expr => $v:expr),+ $(,)?) => (
|
||||
vec![$((::std::borrow::Cow::<str>::from($n), $crate::session::SessionInputValue::<'_>::from($v)),)+]
|
||||
vec![$(($crate::__private::alloc::borrow::Cow::<str>::from($n), $crate::session::SessionInputValue::<'_>::from($v)),)+]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -123,6 +126,7 @@ mod tests {
|
||||
use crate::value::{DynTensor, Tensor};
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "std")]
|
||||
fn test_hashmap_static_keys() -> crate::Result<()> {
|
||||
let v: Vec<f32> = vec![1., 2., 3., 4., 5.];
|
||||
let shape = vec![v.len() as i64];
|
||||
@@ -135,6 +139,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "std")]
|
||||
fn test_hashmap_string_keys() -> crate::Result<()> {
|
||||
let v: Vec<f32> = vec![1., 2., 3., 4., 5.];
|
||||
let shape = vec![v.len() as i64];
|
||||
|
||||
@@ -10,19 +10,18 @@
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use std::{
|
||||
use alloc::{boxed::Box, ffi::CString, format, string::String, sync::Arc, vec::Vec};
|
||||
use core::{
|
||||
any::Any,
|
||||
ffi::{CStr, CString},
|
||||
ffi::{CStr, c_char},
|
||||
iter,
|
||||
marker::PhantomData,
|
||||
ops::Deref,
|
||||
os::raw::c_char,
|
||||
ptr::NonNull,
|
||||
sync::Arc
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
|
||||
use crate::{
|
||||
AsPointer, char_p_to_string,
|
||||
environment::Environment,
|
||||
error::{Error, ErrorCode, Result, assert_non_null_pointer, status_to_result},
|
||||
io_binding::IoBinding,
|
||||
memory::Allocator,
|
||||
@@ -31,21 +30,22 @@ use crate::{
|
||||
value::{Value, ValueType}
|
||||
};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
mod r#async;
|
||||
pub mod builder;
|
||||
pub mod input;
|
||||
pub mod output;
|
||||
pub mod run_options;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::r#async::InferenceFut;
|
||||
#[cfg(feature = "std")]
|
||||
use self::r#async::{AsyncInferenceContext, InferenceFutInner, RunOptionsRef};
|
||||
use self::builder::SessionBuilder;
|
||||
pub use self::{
|
||||
r#async::InferenceFut,
|
||||
input::{SessionInputValue, SessionInputs},
|
||||
output::SessionOutputs,
|
||||
run_options::{HasSelectedOutputs, NoSelectedOutputs, RunOptions, SelectedOutputMarker}
|
||||
};
|
||||
use self::{
|
||||
r#async::{AsyncInferenceContext, InferenceFutInner, RunOptionsRef},
|
||||
builder::SessionBuilder
|
||||
};
|
||||
|
||||
/// Holds onto an [`ort_sys::OrtSession`] pointer and its associated allocator.
|
||||
///
|
||||
@@ -57,8 +57,7 @@ pub struct SharedSessionInner {
|
||||
pub(crate) allocator: Allocator,
|
||||
/// Additional things we may need to hold onto for the duration of this session, like `OperatorDomain`s and
|
||||
/// DLL handles for operator libraries.
|
||||
_extras: Vec<Box<dyn Any>>,
|
||||
_environment: Arc<Environment>
|
||||
_extras: Vec<Box<dyn Any>>
|
||||
}
|
||||
|
||||
unsafe impl Send for SharedSessionInner {}
|
||||
@@ -168,10 +167,10 @@ impl Session {
|
||||
let allocator = Allocator::default();
|
||||
(0..size)
|
||||
.map(|i| {
|
||||
let mut name: *mut c_char = std::ptr::null_mut();
|
||||
let mut name: *mut c_char = ptr::null_mut();
|
||||
ortsys![unsafe SessionGetOverridableInitializerName(self.ptr(), i, allocator.ptr().cast_mut(), &mut name).expect("infallible")];
|
||||
let name = unsafe { CStr::from_ptr(name) }.to_string_lossy().into_owned();
|
||||
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
|
||||
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
|
||||
ortsys![unsafe SessionGetOverridableInitializerTypeInfo(self.ptr(), i, &mut typeinfo_ptr).expect("infallible")];
|
||||
let dtype = ValueType::from_type_info(typeinfo_ptr);
|
||||
OverridableInitializer { name, dtype }
|
||||
@@ -269,7 +268,7 @@ impl Session {
|
||||
|
||||
let (output_names, mut output_tensors) = match run_options {
|
||||
Some(r) => r.outputs.resolve_outputs(&self.outputs),
|
||||
None => (self.outputs.iter().map(|o| o.name.as_str()).collect(), std::iter::repeat_with(|| None).take(self.outputs.len()).collect())
|
||||
None => (self.outputs.iter().map(|o| o.name.as_str()).collect(), iter::repeat_with(|| None).take(self.outputs.len()).collect())
|
||||
};
|
||||
let output_names_ptr: Vec<*const c_char> = output_names
|
||||
.iter()
|
||||
@@ -280,7 +279,7 @@ impl Session {
|
||||
.iter_mut()
|
||||
.map(|c| match c {
|
||||
Some(v) => v.ptr_mut(),
|
||||
None => std::ptr::null_mut()
|
||||
None => ptr::null_mut()
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -297,7 +296,7 @@ impl Session {
|
||||
));
|
||||
}
|
||||
|
||||
let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr() } else { std::ptr::null() };
|
||||
let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr() } else { ptr::null() };
|
||||
|
||||
ortsys![
|
||||
unsafe Run(
|
||||
@@ -353,6 +352,8 @@ impl Session {
|
||||
/// # Ok(())
|
||||
/// # }) }
|
||||
/// ```
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))] // TODO: parking_lot
|
||||
pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>(
|
||||
&'s self,
|
||||
input_values: impl Into<SessionInputs<'i, 'v, N>>
|
||||
@@ -370,6 +371,8 @@ impl Session {
|
||||
|
||||
/// Asynchronously run input data through the ONNX graph, performing inference, with the given [`RunOptions`].
|
||||
/// See [`Session::run_with_options`] and [`Session::run_async`] for more details.
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))] // TODO: parking_lot
|
||||
pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, O: SelectedOutputMarker, const N: usize>(
|
||||
&'s self,
|
||||
input_values: impl Into<SessionInputs<'i, 'v, N>>,
|
||||
@@ -388,6 +391,7 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
fn run_inner_async<'s, 'v: 's, 'r, O: SelectedOutputMarker>(
|
||||
&'s self,
|
||||
input_names: &[String],
|
||||
@@ -400,7 +404,7 @@ impl Session {
|
||||
// (performance-wise) for routines involving `tokio::select!` or timeouts
|
||||
None => RunOptionsRef::Arc(Arc::new(unsafe {
|
||||
// SAFETY: transmuting from `RunOptions<NoSelectedOutputs>` to `RunOptions<O>`; safe because its just a marker
|
||||
std::mem::transmute::<RunOptions<NoSelectedOutputs>, RunOptions<O>>(RunOptions::new()?)
|
||||
core::mem::transmute::<RunOptions<NoSelectedOutputs>, RunOptions<O>>(RunOptions::new()?)
|
||||
}))
|
||||
};
|
||||
|
||||
@@ -416,7 +420,7 @@ impl Session {
|
||||
.map(|n| n.into_raw().cast_const())
|
||||
.collect();
|
||||
|
||||
let output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()];
|
||||
let output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![ptr::null_mut(); self.outputs.len()];
|
||||
|
||||
let input_values: Vec<_> = input_values.collect();
|
||||
let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.iter().map(|input_array_ort| input_array_ort.ptr()).collect();
|
||||
@@ -456,7 +460,7 @@ impl Session {
|
||||
|
||||
/// Gets the session model metadata. See [`ModelMetadata`] for more info.
|
||||
pub fn metadata(&self) -> Result<ModelMetadata<'_>> {
|
||||
let mut metadata_ptr: *mut ort_sys::OrtModelMetadata = std::ptr::null_mut();
|
||||
let mut metadata_ptr: *mut ort_sys::OrtModelMetadata = ptr::null_mut();
|
||||
ortsys![unsafe SessionGetModelMetadata(self.inner.session_ptr.as_ptr(), &mut metadata_ptr)?; nonNull(metadata_ptr)];
|
||||
Ok(ModelMetadata::new(unsafe { NonNull::new_unchecked(metadata_ptr) }, &self.inner.allocator))
|
||||
}
|
||||
@@ -465,7 +469,7 @@ impl Session {
|
||||
///
|
||||
/// Note that this must be explicitly called at the end of profiling, otherwise the profiling file will be empty.
|
||||
pub fn end_profiling(&self) -> Result<String> {
|
||||
let mut profiling_name: *mut c_char = std::ptr::null_mut();
|
||||
let mut profiling_name: *mut c_char = ptr::null_mut();
|
||||
|
||||
ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), self.inner.allocator.ptr().cast_mut(), &mut profiling_name)];
|
||||
assert_non_null_pointer(profiling_name, "ProfilingName")?;
|
||||
@@ -594,7 +598,7 @@ mod dangerous {
|
||||
allocator: &Allocator,
|
||||
i: usize
|
||||
) -> Result<String> {
|
||||
let mut name_bytes: *mut c_char = std::ptr::null_mut();
|
||||
let mut name_bytes: *mut c_char = ptr::null_mut();
|
||||
|
||||
let status = unsafe { f(session_ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes) };
|
||||
unsafe { status_to_result(status) }?;
|
||||
@@ -622,7 +626,7 @@ mod dangerous {
|
||||
session_ptr: NonNull<ort_sys::OrtSession>,
|
||||
i: usize
|
||||
) -> Result<ValueType> {
|
||||
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
|
||||
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
|
||||
|
||||
let status = unsafe { f(session_ptr.as_ptr(), i, &mut typeinfo_ptr) };
|
||||
unsafe { status_to_result(status) }?;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::{
|
||||
use alloc::{string::String, vec::Vec};
|
||||
use core::{
|
||||
ffi::c_void,
|
||||
iter::FusedIterator,
|
||||
mem::ManuallyDrop,
|
||||
@@ -252,7 +253,7 @@ impl IndexMut<usize> for SessionOutputs<'_, '_> {
|
||||
}
|
||||
|
||||
pub struct Keys<'x, 'r> {
|
||||
iter: std::slice::Iter<'x, &'r str>,
|
||||
iter: core::slice::Iter<'x, &'r str>,
|
||||
effective_len: usize
|
||||
}
|
||||
|
||||
@@ -281,8 +282,8 @@ impl ExactSizeIterator for Keys<'_, '_> {}
|
||||
impl FusedIterator for Keys<'_, '_> {}
|
||||
|
||||
pub struct Values<'x, 'k> {
|
||||
value_iter: std::slice::Iter<'x, DynValue>,
|
||||
key_iter: std::slice::Iter<'x, &'k str>,
|
||||
value_iter: core::slice::Iter<'x, DynValue>,
|
||||
key_iter: core::slice::Iter<'x, &'k str>,
|
||||
effective_len: usize
|
||||
}
|
||||
|
||||
@@ -311,8 +312,8 @@ impl ExactSizeIterator for Values<'_, '_> {}
|
||||
impl FusedIterator for Values<'_, '_> {}
|
||||
|
||||
pub struct ValuesMut<'x, 'k> {
|
||||
value_iter: std::slice::IterMut<'x, DynValue>,
|
||||
key_iter: std::slice::Iter<'x, &'k str>,
|
||||
value_iter: core::slice::IterMut<'x, DynValue>,
|
||||
key_iter: core::slice::Iter<'x, &'k str>,
|
||||
effective_len: usize
|
||||
}
|
||||
|
||||
@@ -341,8 +342,8 @@ impl ExactSizeIterator for ValuesMut<'_, '_> {}
|
||||
impl FusedIterator for ValuesMut<'_, '_> {}
|
||||
|
||||
pub struct Iter<'x, 'k> {
|
||||
value_iter: std::slice::Iter<'x, DynValue>,
|
||||
key_iter: std::slice::Iter<'x, &'k str>,
|
||||
value_iter: core::slice::Iter<'x, DynValue>,
|
||||
key_iter: core::slice::Iter<'x, &'k str>,
|
||||
effective_len: usize
|
||||
}
|
||||
|
||||
@@ -371,8 +372,8 @@ impl ExactSizeIterator for Iter<'_, '_> {}
|
||||
impl FusedIterator for Iter<'_, '_> {}
|
||||
|
||||
pub struct IterMut<'x, 'k> {
|
||||
value_iter: std::slice::IterMut<'x, DynValue>,
|
||||
key_iter: std::slice::Iter<'x, &'k str>,
|
||||
value_iter: core::slice::IterMut<'x, DynValue>,
|
||||
key_iter: core::slice::Iter<'x, &'k str>,
|
||||
effective_len: usize
|
||||
}
|
||||
|
||||
@@ -401,8 +402,8 @@ impl ExactSizeIterator for IterMut<'_, '_> {}
|
||||
impl FusedIterator for IterMut<'_, '_> {}
|
||||
|
||||
pub struct IntoIter<'r, 's> {
|
||||
keys: std::vec::IntoIter<&'r str>,
|
||||
values: std::vec::IntoIter<DynValue>,
|
||||
keys: alloc::vec::IntoIter<&'r str>,
|
||||
values: alloc::vec::IntoIter<DynValue>,
|
||||
effective_len: usize,
|
||||
backing_ptr: Option<(&'s Allocator, *mut c_void)>
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
ffi::{CStr, CString, c_char},
|
||||
use alloc::{ffi::CString, string::String, sync::Arc, vec::Vec};
|
||||
use core::{
|
||||
ffi::{CStr, c_char},
|
||||
marker::PhantomData,
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
mem,
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -12,6 +12,7 @@ use crate::{
|
||||
error::Result,
|
||||
ortsys,
|
||||
session::Output,
|
||||
util::MiniMap,
|
||||
value::{DynValue, Value, ValueTypeMarker}
|
||||
};
|
||||
|
||||
@@ -46,7 +47,7 @@ pub struct OutputSelector {
|
||||
use_defaults: bool,
|
||||
default_blocklist: Vec<String>,
|
||||
allowlist: Vec<String>,
|
||||
preallocated_outputs: HashMap<String, Value>
|
||||
preallocated_outputs: MiniMap<String, Value>
|
||||
}
|
||||
|
||||
impl Default for OutputSelector {
|
||||
@@ -57,7 +58,7 @@ impl Default for OutputSelector {
|
||||
use_defaults: true,
|
||||
allowlist: Vec::new(),
|
||||
default_blocklist: Vec::new(),
|
||||
preallocated_outputs: HashMap::new()
|
||||
preallocated_outputs: MiniMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -178,7 +179,7 @@ unsafe impl Sync for RunOptions<NoSelectedOutputs> {}
|
||||
impl RunOptions {
|
||||
/// Creates a new [`RunOptions`] struct.
|
||||
pub fn new() -> Result<RunOptions<NoSelectedOutputs>> {
|
||||
let mut run_options_ptr: *mut ort_sys::OrtRunOptions = std::ptr::null_mut();
|
||||
let mut run_options_ptr: *mut ort_sys::OrtRunOptions = ptr::null_mut();
|
||||
ortsys![unsafe CreateRunOptions(&mut run_options_ptr)?; nonNull(run_options_ptr)];
|
||||
Ok(RunOptions {
|
||||
run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) },
|
||||
@@ -218,7 +219,7 @@ impl<O: SelectedOutputMarker> RunOptions<O> {
|
||||
/// ```
|
||||
pub fn with_outputs(mut self, outputs: OutputSelector) -> RunOptions<HasSelectedOutputs> {
|
||||
self.outputs = outputs;
|
||||
unsafe { std::mem::transmute(self) }
|
||||
unsafe { mem::transmute(self) }
|
||||
}
|
||||
|
||||
/// Sets a tag to identify this run in logs.
|
||||
|
||||
@@ -1,44 +1,27 @@
|
||||
//! Helper traits to extend [`ndarray`] functionality.
|
||||
|
||||
use core::ops::{DivAssign, SubAssign};
|
||||
|
||||
use ndarray::{Array, ArrayBase};
|
||||
|
||||
/// Trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)
|
||||
/// with useful tensor operations.
|
||||
///
|
||||
/// # Generic
|
||||
///
|
||||
/// The trait is generic over:
|
||||
/// * `S`: [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)'s data container
|
||||
/// * `T`: Type contained inside the tensor (for example `f32`)
|
||||
/// * `D`: Tensor's dimension ([`ndarray::Dimension`](https://docs.rs/ndarray/latest/ndarray/trait.Dimension.html))
|
||||
pub trait ArrayExtensions<S, T, D> {
|
||||
/// Calculate the [softmax](https://en.wikipedia.org/wiki/Softmax_function) of the tensor along a given axis
|
||||
///
|
||||
/// # Trait Bounds
|
||||
///
|
||||
/// The function is generic and thus has some trait bounds:
|
||||
/// * `D: ndarray::RemoveAxis`: The summation over an axis reduces the dimension of the tensor. A 0-D tensor thus
|
||||
/// cannot have a softmax calculated.
|
||||
/// * `S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>`: The storage of the tensor can be an owned
|
||||
/// array ([`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)) or an array view
|
||||
/// ([`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)).
|
||||
/// * `<S as ndarray::RawData>::Elem: std::clone::Clone`: The elements of the tensor must be `Clone`.
|
||||
/// * `T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign`: The elements of the tensor must be workable
|
||||
/// as floats and must support `-=` and `/=` operations.
|
||||
/// Calculate the [softmax](https://en.wikipedia.org/wiki/Softmax_function) of the tensor along a given axis.
|
||||
fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
|
||||
where
|
||||
D: ndarray::RemoveAxis,
|
||||
S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>,
|
||||
<S as ndarray::RawData>::Elem: std::clone::Clone,
|
||||
T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign;
|
||||
<S as ndarray::RawData>::Elem: Clone,
|
||||
T: ndarray::NdFloat + SubAssign + DivAssign;
|
||||
}
|
||||
|
||||
impl<S, T, D> ArrayExtensions<S, T, D> for ArrayBase<S, D>
|
||||
where
|
||||
D: ndarray::RemoveAxis,
|
||||
S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>,
|
||||
<S as ndarray::RawData>::Elem: std::clone::Clone,
|
||||
T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign
|
||||
<S as ndarray::RawData>::Elem: Clone,
|
||||
T: ndarray::NdFloat + SubAssign + DivAssign
|
||||
{
|
||||
fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> {
|
||||
let mut new_array: Array<T, D> = self.to_owned();
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::fmt;
|
||||
use alloc::string::String;
|
||||
use core::fmt;
|
||||
#[cfg(feature = "ndarray")]
|
||||
use std::ptr;
|
||||
use core::{ffi::c_void, ptr};
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
use crate::{error::Result, ortsys};
|
||||
@@ -194,7 +195,7 @@ pub(crate) fn extract_primitive_array<'t, T>(shape: ndarray::IxDyn, tensor: *con
|
||||
// Get pointer to output tensor values
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(tensor.cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
|
||||
|
||||
let array_view = unsafe { ndarray::ArrayView::from_shape_ptr(shape, output_array_ptr) };
|
||||
@@ -210,7 +211,7 @@ pub(crate) fn extract_primitive_array_mut<'t, T>(shape: ndarray::IxDyn, tensor:
|
||||
// Get pointer to output tensor values
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(tensor, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
|
||||
|
||||
let array_view = unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape, output_array_ptr) };
|
||||
|
||||
275
src/util.rs
275
src/util.rs
@@ -1,18 +1,29 @@
|
||||
#[cfg(not(target_family = "windows"))]
|
||||
use std::os::raw::c_char;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
#[cfg(target_family = "windows")]
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
use std::{ffi::OsString, path::Path};
|
||||
use alloc::vec::Vec;
|
||||
#[cfg(not(feature = "std"))]
|
||||
use core::sync::atomic::Ordering;
|
||||
use core::{
|
||||
borrow::Borrow,
|
||||
cell::UnsafeCell,
|
||||
fmt,
|
||||
marker::PhantomData,
|
||||
mem::{self, MaybeUninit}
|
||||
};
|
||||
|
||||
#[cfg(target_family = "windows")]
|
||||
#[cfg(all(feature = "std", target_family = "windows"))]
|
||||
type OsCharArray = Vec<u16>;
|
||||
#[cfg(not(target_family = "windows"))]
|
||||
type OsCharArray = Vec<c_char>;
|
||||
#[cfg(all(feature = "std", not(target_family = "windows")))]
|
||||
type OsCharArray = Vec<core::ffi::c_char>;
|
||||
|
||||
pub fn path_to_os_char(path: impl AsRef<Path>) -> OsCharArray {
|
||||
let model_path = OsString::from(path.as_ref());
|
||||
#[cfg(feature = "std")]
|
||||
pub fn path_to_os_char(path: impl AsRef<std::path::Path>) -> OsCharArray {
|
||||
#[cfg(not(target_family = "windows"))]
|
||||
use core::ffi::c_char;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
#[cfg(target_family = "windows")]
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
|
||||
let model_path = std::ffi::OsString::from(path.as_ref());
|
||||
#[cfg(target_family = "windows")]
|
||||
let model_path: Vec<u16> = model_path.encode_wide().chain(std::iter::once(0)).collect();
|
||||
#[cfg(not(target_family = "windows"))]
|
||||
@@ -24,3 +35,243 @@ pub fn path_to_os_char(path: impl AsRef<Path>) -> OsCharArray {
|
||||
.collect();
|
||||
model_path
|
||||
}
|
||||
|
||||
// generally as performant or faster than HashMap<K, V> for <50 items. good enough for #[no_std]
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct MiniMap<K, V> {
|
||||
values: Vec<(K, V)>
|
||||
}
|
||||
|
||||
impl<K, V> Default for MiniMap<K, V> {
|
||||
fn default() -> Self {
|
||||
Self { values: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, V> MiniMap<K, V> {
|
||||
pub const fn new() -> Self {
|
||||
Self { values: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Eq, V> MiniMap<K, V> {
|
||||
pub fn get<Q>(&self, key: &Q) -> Option<&V>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Eq + ?Sized
|
||||
{
|
||||
self.values.iter().find(|(k, _)| key.eq(k.borrow())).map(|(_, v)| v)
|
||||
}
|
||||
|
||||
pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Eq + ?Sized
|
||||
{
|
||||
self.values.iter_mut().find(|(k, _)| key.eq(k.borrow())).map(|(_, v)| v)
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: K, value: V) -> Option<V> {
|
||||
match self.get_mut(&key) {
|
||||
Some(v) => Some(mem::replace(v, value)),
|
||||
None => {
|
||||
self.values.push((key, value));
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn drain(&mut self) -> alloc::vec::Drain<(K, V)> {
|
||||
self.values.drain(..)
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.values.len()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> core::slice::Iter<'_, (K, V)> {
|
||||
self.values.iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: fmt::Debug, V: fmt::Debug> fmt::Debug for MiniMap<K, V> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_map().entries(self.values.iter().map(|(k, v)| (k, v))).finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OnceLock<T> {
|
||||
data: UnsafeCell<MaybeUninit<T>>,
|
||||
#[cfg(not(feature = "std"))]
|
||||
status: core::sync::atomic::AtomicU8,
|
||||
#[cfg(feature = "std")]
|
||||
once: std::sync::Once,
|
||||
phantom: PhantomData<T>
|
||||
}
|
||||
|
||||
unsafe impl<T: Send> Send for OnceLock<T> {}
|
||||
unsafe impl<T: Send + Sync> Sync for OnceLock<T> {}
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
const STATUS_UNINITIALIZED: u8 = 0;
|
||||
#[cfg(not(feature = "std"))]
|
||||
const STATUS_RUNNING: u8 = 1;
|
||||
#[cfg(not(feature = "std"))]
|
||||
const STATUS_INITIALIZED: u8 = 2;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
impl<T> OnceLock<T> {
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
data: UnsafeCell::new(MaybeUninit::uninit()),
|
||||
status: core::sync::atomic::AtomicU8::new(STATUS_UNINITIALIZED),
|
||||
phantom: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_or_init<F: FnOnce() -> T>(&self, f: F) -> &T {
|
||||
match self.get_or_try_init(|| Ok::<T, core::convert::Infallible>(f())) {
|
||||
Ok(x) => x,
|
||||
Err(e) => match e {}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get(&self) -> Option<&T> {
|
||||
match self.status.load(Ordering::Acquire) {
|
||||
STATUS_INITIALIZED => Some(unsafe { self.get_unchecked() }),
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub unsafe fn get_unchecked(&self) -> &T {
|
||||
&*(*self.data.get()).as_ptr()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_or_try_init<F: FnOnce() -> Result<T, E>, E>(&self, f: F) -> Result<&T, E> {
|
||||
if let Some(value) = self.get() { Ok(value) } else { self.try_init_inner(f) }
|
||||
}
|
||||
|
||||
#[cold]
|
||||
fn try_init_inner<F: FnOnce() -> Result<T, E>, E>(&self, f: F) -> Result<&T, E> {
|
||||
'a: loop {
|
||||
match self
|
||||
.status
|
||||
.compare_exchange(STATUS_UNINITIALIZED, STATUS_RUNNING, Ordering::Acquire, Ordering::Acquire)
|
||||
{
|
||||
Ok(_) => {
|
||||
struct SetStatusOnPanic<'a> {
|
||||
status: &'a core::sync::atomic::AtomicU8
|
||||
}
|
||||
impl Drop for SetStatusOnPanic<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.status.store(STATUS_UNINITIALIZED, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
let panic_catcher = SetStatusOnPanic { status: &self.status };
|
||||
let val = match f() {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
core::mem::forget(panic_catcher);
|
||||
self.status.store(STATUS_UNINITIALIZED, Ordering::Release);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
unsafe {
|
||||
*(*self.data.get()).as_mut_ptr() = val;
|
||||
};
|
||||
core::mem::forget(panic_catcher);
|
||||
|
||||
self.status.store(STATUS_INITIALIZED, Ordering::Release);
|
||||
|
||||
return Ok(unsafe { self.get_unchecked() });
|
||||
}
|
||||
Err(STATUS_INITIALIZED) => return Ok(unsafe { self.get_unchecked() }),
|
||||
Err(STATUS_RUNNING) => loop {
|
||||
match self.status.load(Ordering::Acquire) {
|
||||
STATUS_RUNNING => core::hint::spin_loop(),
|
||||
STATUS_INITIALIZED => return Ok(unsafe { self.get_unchecked() }),
|
||||
// STATUS_UNINITIALIZED - running thread failed, time for us to step in
|
||||
_ => continue 'a
|
||||
}
|
||||
},
|
||||
_ => continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl<T> OnceLock<T> {
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
data: UnsafeCell::new(MaybeUninit::uninit()),
|
||||
once: std::sync::Once::new(),
|
||||
phantom: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_or_init<F: FnOnce() -> T>(&self, f: F) -> &T {
|
||||
match self.get_or_try_init(|| Ok::<T, core::convert::Infallible>(f())) {
|
||||
Ok(x) => x,
|
||||
Err(e) => match e {}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get(&self) -> Option<&T> {
|
||||
if self.once.is_completed() { Some(unsafe { self.get_unchecked() }) } else { None }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub unsafe fn get_unchecked(&self) -> &T {
|
||||
&*(*self.data.get()).as_ptr()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_or_try_init<F: FnOnce() -> Result<T, E>, E>(&self, f: F) -> Result<&T, E> {
|
||||
if let Some(value) = self.get() { Ok(value) } else { self.try_init_inner(f) }
|
||||
}
|
||||
|
||||
#[cold]
|
||||
fn try_init_inner<F: FnOnce() -> Result<T, E>, E>(&self, f: F) -> Result<&T, E> {
|
||||
let mut res: Result<(), E> = Ok(());
|
||||
let slot = &self.data;
|
||||
self.once.call_once_force(|_| match f() {
|
||||
Ok(value) => unsafe {
|
||||
(*slot.get()).write(value);
|
||||
},
|
||||
Err(e) => {
|
||||
res = Err(e);
|
||||
}
|
||||
});
|
||||
res.map(|_| unsafe { self.get_unchecked() })
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> OnceLock<T> {
|
||||
pub fn try_insert(&self, value: T) -> bool {
|
||||
let mut container = Some(value);
|
||||
self.get_or_init(|| unsafe { container.take().unwrap_unchecked() });
|
||||
container.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for OnceLock<T> {
|
||||
fn drop(&mut self) {
|
||||
#[cfg(not(feature = "std"))]
|
||||
let status = *self.status.get_mut() == STATUS_INITIALIZED;
|
||||
#[cfg(feature = "std")]
|
||||
let status = self.once.is_completed();
|
||||
if status {
|
||||
unsafe {
|
||||
core::ptr::drop_in_place((*self.data.get()).as_mut_ptr());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
format,
|
||||
string::{String, ToString},
|
||||
sync::Arc,
|
||||
vec,
|
||||
vec::Vec
|
||||
};
|
||||
use core::{
|
||||
ffi::c_void,
|
||||
fmt::Debug,
|
||||
hash::Hash,
|
||||
marker::PhantomData,
|
||||
mem,
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
slice
|
||||
};
|
||||
#[cfg(feature = "std")]
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{
|
||||
DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker,
|
||||
@@ -77,7 +88,7 @@ pub type MapRef<'v, K, V> = ValueRef<'v, MapValueType<K, V>>;
|
||||
pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType<K, V>>;
|
||||
|
||||
impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
|
||||
pub fn try_extract_map<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<HashMap<K, V>> {
|
||||
pub fn try_extract_raw_map<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<Vec<(K, V)>> {
|
||||
match self.dtype() {
|
||||
ValueType::Map { key, value } => {
|
||||
let k_type = K::into_tensor_element_type();
|
||||
@@ -112,11 +123,11 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
|
||||
if *ty == K::into_tensor_element_type() {
|
||||
let mut output_array_ptr: *mut K = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
|
||||
|
||||
let len = calculate_tensor_size(dimensions);
|
||||
(dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) })
|
||||
(dimensions, unsafe { slice::from_raw_parts(output_array_ptr, len) })
|
||||
} else {
|
||||
return Err(Error::new_with_code(
|
||||
ErrorCode::InvalidArgument,
|
||||
@@ -146,13 +157,13 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
|
||||
for i in 0..key_tensor_shape[0] as usize {
|
||||
vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
|
||||
}
|
||||
Ok(vec.into_iter().collect())
|
||||
Ok(vec)
|
||||
} else {
|
||||
let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_string_tensor()?;
|
||||
// SAFETY: `IntoTensorElementType` is a private trait, and we only map the `String` type to `TensorElementType::String`,
|
||||
// so at this point, `K` is **always** the `String` type, and this transmute really does nothing but please the type
|
||||
// checker.
|
||||
let key_tensor: Vec<K> = unsafe { std::mem::transmute(key_tensor) };
|
||||
let key_tensor: Vec<K> = unsafe { mem::transmute(key_tensor) };
|
||||
|
||||
let mut value_tensor_ptr = ptr::null_mut();
|
||||
ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr().cast_mut(), &mut value_tensor_ptr)?; nonNull(value_tensor_ptr)];
|
||||
@@ -176,6 +187,12 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn try_extract_map<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<HashMap<K, V>> {
|
||||
self.try_extract_raw_map().map(|c| c.into_iter().collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: PrimitiveTensorElementType + Debug + Clone + Hash + Eq + 'static, V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
|
||||
@@ -267,6 +284,12 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTens
|
||||
}
|
||||
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: PrimitiveTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
|
||||
pub fn extract_raw_map(&self) -> Vec<(K, V)> {
|
||||
self.try_extract_raw_map().expect("Failed to extract map")
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
|
||||
pub fn extract_map(&self) -> HashMap<K, V> {
|
||||
self.try_extract_map().expect("Failed to extract map")
|
||||
}
|
||||
@@ -276,7 +299,7 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementT
|
||||
/// Converts from a strongly-typed [`Map<K, V>`] to a type-erased [`DynMap`].
|
||||
#[inline]
|
||||
pub fn upcast(self) -> DynMap {
|
||||
unsafe { std::mem::transmute(self) }
|
||||
unsafe { mem::transmute(self) }
|
||||
}
|
||||
|
||||
/// Converts from a strongly-typed [`Map<K, V>`] to a reference to a type-erased [`DynMap`].
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
use std::{
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
format,
|
||||
string::{String, ToString},
|
||||
sync::Arc,
|
||||
vec::Vec
|
||||
};
|
||||
use core::{
|
||||
fmt::Debug,
|
||||
marker::PhantomData,
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
mem,
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
|
||||
use super::{DowncastableTarget, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker};
|
||||
@@ -158,7 +165,7 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
|
||||
/// Converts from a strongly-typed [`Sequence<T>`] to a type-erased [`DynSequence`].
|
||||
#[inline]
|
||||
pub fn upcast(self) -> DynSequence {
|
||||
unsafe { std::mem::transmute(self) }
|
||||
unsafe { mem::transmute(self) }
|
||||
}
|
||||
|
||||
/// Converts from a strongly-typed [`Sequence<T>`] to a reference to a type-erased [`DynSequence`].
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use std::{
|
||||
use alloc::{boxed::Box, ffi::CString, format, string::String, sync::Arc, vec, vec::Vec};
|
||||
use core::{
|
||||
any::Any,
|
||||
ffi,
|
||||
ffi::c_void,
|
||||
fmt::Debug,
|
||||
marker::PhantomData,
|
||||
mem::size_of,
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
@@ -59,11 +59,11 @@ impl Tensor<String> {
|
||||
];
|
||||
|
||||
// create null-terminated copies of each string, as per `FillStringTensor` docs
|
||||
let null_terminated_copies: Vec<ffi::CString> = data
|
||||
let null_terminated_copies: Vec<CString> = data
|
||||
.iter()
|
||||
.map(|elt| {
|
||||
let slice = elt.as_utf8_bytes();
|
||||
ffi::CString::new(slice)
|
||||
CString::new(slice)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(Error::wrap)?;
|
||||
@@ -130,7 +130,7 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
|
||||
// do it manually.
|
||||
let memory_info = MemoryInfo::from_value(value_ptr).expect("CreateTensorAsOrtValue returned non-tensor");
|
||||
if memory_info.is_cpu_accessible() {
|
||||
let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut();
|
||||
let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
|
||||
ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];
|
||||
|
||||
unsafe { buffer_ptr.write_bytes(0, calculate_tensor_size(&shape) * size_of::<T>()) };
|
||||
@@ -188,14 +188,14 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
|
||||
let shape_ptr: *const i64 = shape.as_ptr();
|
||||
let shape_len = shape.len();
|
||||
|
||||
let tensor_values_ptr: *mut std::ffi::c_void = ptr.cast();
|
||||
let tensor_values_ptr: *mut c_void = ptr.cast();
|
||||
assert_non_null_pointer(tensor_values_ptr, "TensorValues")?;
|
||||
|
||||
ortsys:
|
||||
@@ -401,7 +401,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
|
||||
let shape_ptr: *const i64 = shape.as_ptr();
|
||||
let shape_len = shape.len();
|
||||
|
||||
let data_len = calculate_tensor_size(&shape) * std::mem::size_of::<T>();
|
||||
let data_len = calculate_tensor_size(&shape) * size_of::<T>();
|
||||
|
||||
ortsys![
|
||||
unsafe CreateTensorWithDataAsOrtValue(
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
use std::{fmt::Debug, ptr, string::FromUtf8Error};
|
||||
use alloc::{
|
||||
format,
|
||||
string::{FromUtf8Error, String},
|
||||
vec,
|
||||
vec::Vec
|
||||
};
|
||||
use core::{ffi::c_void, fmt::Debug, ptr, slice};
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
use ndarray::IxDyn;
|
||||
@@ -107,7 +113,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
|
||||
|
||||
Ok(unsafe { *output_array_ptr })
|
||||
@@ -211,11 +217,11 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
if *ty == T::into_tensor_element_type() {
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
|
||||
|
||||
let len = calculate_tensor_size(dimensions);
|
||||
Ok((dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) }))
|
||||
Ok((dimensions, unsafe { slice::from_raw_parts(output_array_ptr, len) }))
|
||||
} else {
|
||||
Err(Error::new_with_code(
|
||||
ErrorCode::InvalidArgument,
|
||||
@@ -264,11 +270,11 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
if *ty == T::into_tensor_element_type() {
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
|
||||
|
||||
let len = calculate_tensor_size(dimensions);
|
||||
Ok((dimensions, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len) }))
|
||||
Ok((dimensions, unsafe { slice::from_raw_parts_mut(output_array_ptr, len) }))
|
||||
} else {
|
||||
Err(Error::new_with_code(
|
||||
ErrorCode::InvalidArgument,
|
||||
@@ -424,7 +430,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn shape(&self) -> Result<Vec<i64>> {
|
||||
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = ptr::null_mut();
|
||||
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr)?];
|
||||
|
||||
let res = {
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
mod create;
|
||||
mod extract;
|
||||
|
||||
use std::{
|
||||
use alloc::{
|
||||
format,
|
||||
string::{String, ToString},
|
||||
sync::Arc
|
||||
};
|
||||
use core::{
|
||||
fmt::Debug,
|
||||
marker::PhantomData,
|
||||
mem,
|
||||
ops::{Index, IndexMut},
|
||||
sync::Arc
|
||||
ptr
|
||||
};
|
||||
|
||||
pub use self::create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts};
|
||||
@@ -88,7 +94,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn data_ptr_mut(&mut self) -> Result<*mut ort_sys::c_void> {
|
||||
let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut();
|
||||
let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr_mut(), &mut buffer_ptr)?; nonNull(buffer_ptr)];
|
||||
Ok(buffer_ptr)
|
||||
}
|
||||
@@ -111,7 +117,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn data_ptr(&self) -> Result<*const ort_sys::c_void> {
|
||||
let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut();
|
||||
let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), &mut buffer_ptr)?; nonNull(buffer_ptr)];
|
||||
Ok(buffer_ptr)
|
||||
}
|
||||
@@ -157,7 +163,7 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn upcast(self) -> DynTensor {
|
||||
unsafe { std::mem::transmute(self) }
|
||||
unsafe { mem::transmute(self) }
|
||||
}
|
||||
|
||||
/// Creates a type-erased [`DynTensorRef`] from a strongly-typed [`Tensor<T>`].
|
||||
@@ -237,7 +243,7 @@ impl<T: IntoTensorElementType + Clone + Debug, const N: usize> Index<[i64; N]> f
|
||||
panic!("Cannot directly index a tensor which is not allocated on the CPU.");
|
||||
}
|
||||
|
||||
let mut out: *mut ort_sys::c_void = std::ptr::null_mut();
|
||||
let mut out: *mut ort_sys::c_void = ptr::null_mut();
|
||||
ortsys![unsafe TensorAt(self.ptr().cast_mut(), index.as_ptr(), N, &mut out).expect("Failed to index tensor")];
|
||||
unsafe { &*out.cast::<T>() }
|
||||
}
|
||||
@@ -248,7 +254,7 @@ impl<T: IntoTensorElementType + Clone + Debug, const N: usize> IndexMut<[i64; N]
|
||||
panic!("Cannot directly index a tensor which is not allocated on the CPU.");
|
||||
}
|
||||
|
||||
let mut out: *mut ort_sys::c_void = std::ptr::null_mut();
|
||||
let mut out: *mut ort_sys::c_void = ptr::null_mut();
|
||||
ortsys![unsafe TensorAt(self.ptr_mut(), index.as_ptr(), N, &mut out).expect("Failed to index tensor")];
|
||||
unsafe { &mut *out.cast::<T>() }
|
||||
}
|
||||
@@ -267,8 +273,9 @@ pub(crate) fn calculate_tensor_size(shape: &[i64]) -> usize {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
use alloc::sync::Arc;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
use ndarray::{ArcArray1, Array1, CowArray};
|
||||
|
||||
use super::Tensor;
|
||||
|
||||
@@ -16,14 +16,19 @@
|
||||
//!
|
||||
//! ONNX Runtime also supports [`Sequence`]s and [`Map`]s, though they are less commonly used.
|
||||
|
||||
use std::{
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
format,
|
||||
string::{String, ToString},
|
||||
sync::Arc
|
||||
};
|
||||
use core::{
|
||||
any::Any,
|
||||
fmt::Debug,
|
||||
marker::PhantomData,
|
||||
mem::transmute,
|
||||
ops::{Deref, DerefMut},
|
||||
ptr::NonNull,
|
||||
sync::Arc
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
|
||||
mod impl_map;
|
||||
@@ -102,7 +107,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> {
|
||||
pub fn downcast<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(self) -> Result<ValueRef<'v, OtherType>> {
|
||||
let dt = self.dtype();
|
||||
if OtherType::can_downcast(dt) {
|
||||
Ok(unsafe { std::mem::transmute::<ValueRef<'v, Type>, ValueRef<'v, OtherType>>(self) })
|
||||
Ok(unsafe { transmute::<ValueRef<'v, Type>, ValueRef<'v, OtherType>>(self) })
|
||||
} else {
|
||||
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", OtherType::format())))
|
||||
}
|
||||
@@ -118,7 +123,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> {
|
||||
}
|
||||
|
||||
pub fn into_dyn(self) -> ValueRef<'v, DynValueTypeMarker> {
|
||||
unsafe { std::mem::transmute(self) }
|
||||
unsafe { transmute(self) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +160,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> {
|
||||
pub fn downcast<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(self) -> Result<ValueRefMut<'v, OtherType>> {
|
||||
let dt = self.dtype();
|
||||
if OtherType::can_downcast(dt) {
|
||||
Ok(unsafe { std::mem::transmute::<ValueRefMut<'v, Type>, ValueRefMut<'v, OtherType>>(self) })
|
||||
Ok(unsafe { transmute::<ValueRefMut<'v, Type>, ValueRefMut<'v, OtherType>>(self) })
|
||||
} else {
|
||||
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", OtherType::format())))
|
||||
}
|
||||
@@ -171,7 +176,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> {
|
||||
}
|
||||
|
||||
pub fn into_dyn(self) -> ValueRefMut<'v, DynValueTypeMarker> {
|
||||
unsafe { std::mem::transmute(self) }
|
||||
unsafe { transmute(self) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -309,7 +314,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// - `session` must be `Some` for values returned from a session.
|
||||
#[must_use]
|
||||
pub unsafe fn from_ptr(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
|
||||
let mut typeinfo_ptr = std::ptr::null_mut();
|
||||
let mut typeinfo_ptr = ptr::null_mut();
|
||||
ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)];
|
||||
Value {
|
||||
inner: Arc::new(ValueInner {
|
||||
@@ -327,7 +332,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// contexts.
|
||||
#[must_use]
|
||||
pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
|
||||
let mut typeinfo_ptr = std::ptr::null_mut();
|
||||
let mut typeinfo_ptr = ptr::null_mut();
|
||||
ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)];
|
||||
Value {
|
||||
inner: Arc::new(ValueInner {
|
||||
@@ -353,7 +358,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
|
||||
|
||||
/// Converts this value into a type-erased [`DynValue`].
|
||||
pub fn into_dyn(self) -> DynValue {
|
||||
unsafe { std::mem::transmute(self) }
|
||||
unsafe { transmute(self) }
|
||||
}
|
||||
|
||||
pub(crate) fn clone_of(value: &Self) -> Self {
|
||||
@@ -388,7 +393,7 @@ impl Value<DynValueTypeMarker> {
|
||||
pub fn downcast<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(self) -> Result<Value<OtherType>> {
|
||||
let dt = self.dtype();
|
||||
if OtherType::can_downcast(dt) {
|
||||
Ok(unsafe { std::mem::transmute::<Value<DynValueTypeMarker>, Value<OtherType>>(self) })
|
||||
Ok(unsafe { transmute::<Value<DynValueTypeMarker>, Value<OtherType>>(self) })
|
||||
} else {
|
||||
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast {dt} to {}", OtherType::format())))
|
||||
}
|
||||
@@ -482,13 +487,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_sequence_map() -> crate::Result<()> {
|
||||
let map_contents = [("meaning".to_owned(), 42.0), ("pi".to_owned(), std::f32::consts::PI)];
|
||||
let map_contents = [("meaning".to_owned(), 42.0), ("pi".to_owned(), core::f32::consts::PI)];
|
||||
let value = Sequence::new([Map::<String, f32>::new(map_contents)?])?;
|
||||
|
||||
for map in value.extract_sequence(&Allocator::default()) {
|
||||
let map = map.extract_map();
|
||||
let map = map.extract_raw_map().into_iter().collect::<std::collections::HashMap<_, _>>();
|
||||
assert_eq!(map["meaning"], 42.0);
|
||||
assert_eq!(map["pi"], std::f32::consts::PI);
|
||||
assert_eq!(map["pi"], core::f32::consts::PI);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
use std::{
|
||||
ffi::{CStr, CString, c_char},
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
ffi::CString,
|
||||
string::{String, ToString},
|
||||
vec,
|
||||
vec::Vec
|
||||
};
|
||||
use core::{
|
||||
ffi::{CStr, c_char},
|
||||
fmt, ptr
|
||||
};
|
||||
|
||||
@@ -83,15 +90,15 @@ impl ValueType {
|
||||
ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty)]; // infallible
|
||||
let io_type = match ty {
|
||||
ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => {
|
||||
let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = ptr::null_mut();
|
||||
ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr)]; // infallible
|
||||
unsafe { extract_data_type_from_tensor_info(info_ptr) }
|
||||
}
|
||||
ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => {
|
||||
let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut();
|
||||
let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = ptr::null_mut();
|
||||
ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible
|
||||
|
||||
let mut element_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
|
||||
let mut element_type_info: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
|
||||
ortsys![unsafe GetSequenceElementType(info_ptr, &mut element_type_info)]; // infallible
|
||||
|
||||
let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN;
|
||||
@@ -99,13 +106,13 @@ impl ValueType {
|
||||
|
||||
match ty {
|
||||
ort_sys::ONNXType::ONNX_TYPE_TENSOR => {
|
||||
let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = ptr::null_mut();
|
||||
ortsys![unsafe CastTypeInfoToTensorInfo(element_type_info, &mut info_ptr)]; // infallible
|
||||
let ty = unsafe { extract_data_type_from_tensor_info(info_ptr) };
|
||||
ValueType::Sequence(Box::new(ty))
|
||||
}
|
||||
ort_sys::ONNXType::ONNX_TYPE_MAP => {
|
||||
let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut();
|
||||
let mut info_ptr: *const ort_sys::OrtMapTypeInfo = ptr::null_mut();
|
||||
ortsys![unsafe CastTypeInfoToMapTypeInfo(element_type_info, &mut info_ptr)]; // infallible
|
||||
let ty = unsafe { extract_data_type_from_map_info(info_ptr) };
|
||||
ValueType::Sequence(Box::new(ty))
|
||||
@@ -114,15 +121,15 @@ impl ValueType {
|
||||
}
|
||||
}
|
||||
ort_sys::ONNXType::ONNX_TYPE_MAP => {
|
||||
let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut();
|
||||
let mut info_ptr: *const ort_sys::OrtMapTypeInfo = ptr::null_mut();
|
||||
ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible
|
||||
unsafe { extract_data_type_from_map_info(info_ptr) }
|
||||
}
|
||||
ort_sys::ONNXType::ONNX_TYPE_OPTIONAL => {
|
||||
let mut info_ptr: *const ort_sys::OrtOptionalTypeInfo = std::ptr::null_mut();
|
||||
let mut info_ptr: *const ort_sys::OrtOptionalTypeInfo = ptr::null_mut();
|
||||
ortsys![unsafe CastTypeInfoToOptionalTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible
|
||||
|
||||
let mut contained_type: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
|
||||
let mut contained_type: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
|
||||
ortsys![unsafe GetOptionalContainedTypeInfo(info_ptr, &mut contained_type)]; // infallible
|
||||
|
||||
ValueType::Optional(Box::new(ValueType::from_type_info(contained_type)))
|
||||
@@ -273,9 +280,9 @@ unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeIn
|
||||
ortsys![GetMapKeyType(info_ptr, &mut key_type_sys)]; // infallible
|
||||
assert_ne!(key_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
|
||||
let mut value_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
|
||||
let mut value_type_info: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
|
||||
ortsys![GetMapValueType(info_ptr, &mut value_type_info)]; // infallible
|
||||
let mut value_info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
let mut value_info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = ptr::null_mut();
|
||||
ortsys![unsafe CastTypeInfoToTensorInfo(value_type_info, &mut value_info_ptr)]; // infallible
|
||||
let mut value_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
ortsys![GetTensorElementType(value_info_ptr, &mut value_type_sys)]; // infallible
|
||||
|
||||
Reference in New Issue
Block a user