mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor(sys): replace _system macro with extern "system"
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use std::{
|
||||
env, fs, io,
|
||||
env, fs,
|
||||
path::{Path, PathBuf}
|
||||
};
|
||||
|
||||
|
||||
1754
ort-sys/src/lib.rs
1754
ort-sys/src/lib.rs
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,7 @@ use tracing::{Level, debug};
|
||||
|
||||
#[cfg(feature = "load-dynamic")]
|
||||
use crate::G_ORT_DYLIB_PATH;
|
||||
use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, extern_system_fn, ortsys};
|
||||
use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, ortsys};
|
||||
|
||||
struct EnvironmentSingleton {
|
||||
lock: RwLock<Option<Arc<Environment>>>
|
||||
@@ -172,7 +172,7 @@ pub trait ThreadManager {
|
||||
fn join(thread: Self::Thread) -> crate::Result<()>;
|
||||
}
|
||||
|
||||
pub(crate) unsafe extern "C" fn thread_create<T: ThreadManager + Any>(
|
||||
pub(crate) unsafe extern "system" fn thread_create<T: ThreadManager + Any>(
|
||||
ort_custom_thread_creation_options: *mut c_void,
|
||||
ort_thread_worker_fn: ort_sys::OrtThreadWorkerFn,
|
||||
ort_worker_fn_param: *mut c_void
|
||||
@@ -201,7 +201,7 @@ pub(crate) unsafe extern "C" fn thread_create<T: ThreadManager + Any>(
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) unsafe extern "C" fn thread_join<T: ThreadManager + Any>(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) {
|
||||
pub(crate) unsafe extern "system" fn thread_join<T: ThreadManager + Any>(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) {
|
||||
let handle = Box::from_raw(ort_custom_thread_handle.cast_mut().cast::<<T as ThreadManager>::Thread>());
|
||||
if let Err(e) = <T as ThreadManager>::join(*handle) {
|
||||
tracing::error!("Failed to join thread using manager: {e}");
|
||||
@@ -395,29 +395,29 @@ pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
|
||||
EnvironmentBuilder::new()
|
||||
}
|
||||
|
||||
extern_system_fn! {
|
||||
/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
|
||||
pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, _: *const c_char, id: *const c_char, code_location: *const c_char, message: *const c_char) {
|
||||
assert_ne!(code_location, ptr::null());
|
||||
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("<decode error>");
|
||||
assert_ne!(message, ptr::null());
|
||||
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<decode error>");
|
||||
assert_ne!(id, ptr::null());
|
||||
let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or("<decode error>");
|
||||
/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
|
||||
pub(crate) extern "system" fn custom_logger(
|
||||
_params: *mut ffi::c_void,
|
||||
severity: ort_sys::OrtLoggingLevel,
|
||||
_: *const c_char,
|
||||
id: *const c_char,
|
||||
code_location: *const c_char,
|
||||
message: *const c_char
|
||||
) {
|
||||
assert_ne!(code_location, ptr::null());
|
||||
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("<decode error>");
|
||||
assert_ne!(message, ptr::null());
|
||||
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<decode error>");
|
||||
assert_ne!(id, ptr::null());
|
||||
let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or("<decode error>");
|
||||
|
||||
let span = tracing::span!(
|
||||
Level::TRACE,
|
||||
"ort",
|
||||
id = id,
|
||||
location = code_location
|
||||
);
|
||||
let span = tracing::span!(Level::TRACE, "ort", id = id, location = code_location);
|
||||
|
||||
match severity {
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, Level::TRACE, "{message}"),
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, Level::DEBUG, "{message}"),
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, Level::INFO, "{message}"),
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, Level::WARN, "{message}"),
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL=> tracing::event!(parent: &span, Level::ERROR, "{message}")
|
||||
}
|
||||
match severity {
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, Level::TRACE, "{message}"),
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, Level::DEBUG, "{message}"),
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, Level::INFO, "{message}"),
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, Level::WARN, "{message}"),
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => tracing::event!(parent: &span, Level::ERROR, "{message}")
|
||||
}
|
||||
}
|
||||
|
||||
23
src/lib.rs
23
src/lib.rs
@@ -49,23 +49,6 @@ pub use self::{
|
||||
error::{Error, ErrorCode, Result}
|
||||
};
|
||||
|
||||
#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
|
||||
macro_rules! extern_system_fn {
|
||||
($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
|
||||
($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "C" fn $($tt)*);
|
||||
($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "C" fn $($tt)*);
|
||||
($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "C" fn $($tt)*);
|
||||
}
|
||||
#[cfg(all(target_arch = "x86", target_os = "windows"))]
|
||||
macro_rules! extern_system_fn {
|
||||
($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "stdcall" fn $($tt)*);
|
||||
($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "stdcall" fn $($tt)*);
|
||||
($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "stdcall" fn $($tt)*);
|
||||
($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "stdcall" fn $($tt)*);
|
||||
}
|
||||
|
||||
pub(crate) use extern_system_fn;
|
||||
|
||||
/// The minor version of ONNX Runtime used by this version of `ort`.
|
||||
pub const MINOR_VERSION: u32 = ort_sys::ORT_API_VERSION;
|
||||
|
||||
@@ -167,7 +150,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
let base: *const ort_sys::OrtApiBase = base_getter();
|
||||
assert_ne!(base, ptr::null());
|
||||
|
||||
let get_version_string: extern_system_fn! { unsafe fn () -> *const c_char } =
|
||||
let get_version_string: unsafe extern "system" fn() -> *const c_char =
|
||||
(*base).GetVersionString.expect("`GetVersionString` must be present in `OrtApiBase`");
|
||||
let version_string = get_version_string();
|
||||
let version_string = CStr::from_ptr(version_string).to_string_lossy();
|
||||
@@ -187,7 +170,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
),
|
||||
std::cmp::Ordering::Equal => {}
|
||||
};
|
||||
let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } =
|
||||
let get_api: unsafe extern "system" fn(u32) -> *const ort_sys::OrtApi =
|
||||
(*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
|
||||
let api: *const ort_sys::OrtApi = get_api(ort_sys::ORT_API_VERSION);
|
||||
ApiPointer(NonNull::new(api.cast_mut()).expect("Failed to initialize ORT API"))
|
||||
@@ -196,7 +179,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
|
||||
unsafe {
|
||||
let base: *const ort_sys::OrtApiBase = ort_sys::OrtGetApiBase();
|
||||
assert_ne!(base, ptr::null());
|
||||
let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } =
|
||||
let get_api: unsafe extern "system" fn(u32) -> *const ort_sys::OrtApi =
|
||||
(*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
|
||||
let api: *const ort_sys::OrtApi = get_api(ort_sys::ORT_API_VERSION);
|
||||
ApiPointer(NonNull::new(api.cast_mut()).expect("Failed to initialize ORT API"))
|
||||
|
||||
@@ -5,7 +5,7 @@ use super::{
|
||||
io::{self, InputOutputCharacteristic},
|
||||
kernel::{Kernel, KernelAttributes, KernelContext}
|
||||
};
|
||||
use crate::{Result, error::IntoStatus, extern_system_fn};
|
||||
use crate::{Result, error::IntoStatus};
|
||||
|
||||
#[repr(C)] // <- important! a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later
|
||||
pub(crate) struct BoundOperator {
|
||||
@@ -62,168 +62,140 @@ impl BoundOperator {
|
||||
})
|
||||
}
|
||||
|
||||
unsafe fn safe<'a>(op: *const ort_sys::OrtCustomOp) -> &'a BoundOperator {
|
||||
&*op.cast()
|
||||
fn safe<'a>(op: *const ort_sys::OrtCustomOp) -> &'a BoundOperator {
|
||||
unsafe { &*op.cast() }
|
||||
}
|
||||
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn create_kernel(
|
||||
op: *const ort_sys::OrtCustomOp,
|
||||
_: *const ort_sys::OrtApi,
|
||||
info: *const ort_sys::OrtKernelInfo,
|
||||
kernel_ptr: *mut *mut ort_sys::c_void
|
||||
) -> *mut ort_sys::OrtStatus {
|
||||
let safe = Self::safe(op);
|
||||
let kernel = match safe.operator.create_kernel(&KernelAttributes::new(info)) {
|
||||
Ok(kernel) => kernel,
|
||||
e => return e.into_status()
|
||||
};
|
||||
*kernel_ptr = (Box::leak(Box::new(kernel)) as *mut Box<dyn Kernel>).cast();
|
||||
Ok(()).into_status()
|
||||
}
|
||||
pub(crate) extern "system" fn create_kernel(
|
||||
op: *const ort_sys::OrtCustomOp,
|
||||
_: *const ort_sys::OrtApi,
|
||||
info: *const ort_sys::OrtKernelInfo,
|
||||
kernel_ptr: *mut *mut ort_sys::c_void
|
||||
) -> *mut ort_sys::OrtStatus {
|
||||
let safe = Self::safe(op);
|
||||
let kernel = match safe.operator.create_kernel(&KernelAttributes::new(info)) {
|
||||
Ok(kernel) => kernel,
|
||||
e => return e.into_status()
|
||||
};
|
||||
unsafe { *kernel_ptr = (Box::leak(Box::new(kernel)) as *mut Box<dyn Kernel>).cast() };
|
||||
Ok(()).into_status()
|
||||
}
|
||||
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn compute_kernel(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus {
|
||||
let context = KernelContext::new(context);
|
||||
unsafe { &mut *kernel_ptr.cast::<Box<dyn Kernel>>() }.compute(&context).into_status()
|
||||
}
|
||||
pub(crate) extern "system" fn compute_kernel(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus {
|
||||
let context = KernelContext::new(context);
|
||||
unsafe { &mut *kernel_ptr.cast::<Box<dyn Kernel>>() }.compute(&context).into_status()
|
||||
}
|
||||
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn destroy_kernel(op_kernel: *mut ort_sys::c_void) {
|
||||
drop(Box::from_raw(op_kernel.cast::<Box<dyn Kernel>>()));
|
||||
}
|
||||
pub(crate) extern "system" fn destroy_kernel(op_kernel: *mut ort_sys::c_void) {
|
||||
drop(unsafe { Box::from_raw(op_kernel.cast::<Box<dyn Kernel>>()) });
|
||||
}
|
||||
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_name(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
|
||||
let safe = Self::safe(op);
|
||||
safe.name.as_ptr()
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_execution_provider_type(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
|
||||
let safe = Self::safe(op);
|
||||
safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
|
||||
}
|
||||
pub(crate) extern "system" fn get_name(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
|
||||
let safe = Self::safe(op);
|
||||
safe.name.as_ptr()
|
||||
}
|
||||
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_min_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.operator.min_version() as _
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_max_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.operator.max_version() as _
|
||||
}
|
||||
pub(crate) extern "system" fn get_execution_provider_type(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
|
||||
let safe = Self::safe(op);
|
||||
safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
|
||||
}
|
||||
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_input_memory_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtMemType {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs[index].memory_type.into()
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_input_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs[index].characteristic.into()
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_output_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs[index].characteristic.into()
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_input_type_count(op: *const ort_sys::OrtCustomOp) -> usize {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs.len()
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_output_type_count(op: *const ort_sys::OrtCustomOp) -> usize {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs.len()
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_input_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs[index]
|
||||
.r#type
|
||||
.map(|c| c.into())
|
||||
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_output_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs[index]
|
||||
.r#type
|
||||
.map(|c| c.into())
|
||||
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_variadic_input_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs
|
||||
.iter()
|
||||
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
|
||||
.and_then(|c| c.variadic_min_arity)
|
||||
.unwrap_or(1)
|
||||
.try_into()
|
||||
.expect("input minimum arity overflows i32")
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_variadic_input_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs
|
||||
.iter()
|
||||
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
|
||||
.and_then(|c| c.variadic_homogeneity)
|
||||
.unwrap_or(false)
|
||||
.into()
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_variadic_output_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs
|
||||
.iter()
|
||||
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
|
||||
.and_then(|c| c.variadic_min_arity)
|
||||
.unwrap_or(1)
|
||||
.try_into()
|
||||
.expect("output minimum arity overflows i32")
|
||||
}
|
||||
}
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn get_variadic_output_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs
|
||||
.iter()
|
||||
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
|
||||
.and_then(|c| c.variadic_homogeneity)
|
||||
.unwrap_or(false)
|
||||
.into()
|
||||
}
|
||||
pub(crate) extern "system" fn get_min_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.operator.min_version() as _
|
||||
}
|
||||
|
||||
extern_system_fn! {
|
||||
pub(crate) unsafe fn infer_output_shape(op: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
|
||||
let safe = Self::safe(op);
|
||||
let mut ctx = ShapeInferenceContext {
|
||||
ptr: ctx
|
||||
};
|
||||
safe.operator.infer_shape(&mut ctx).into_status()
|
||||
}
|
||||
pub(crate) extern "system" fn get_max_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.operator.max_version() as _
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_input_memory_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtMemType {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs[index].memory_type.into()
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_input_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs[index].characteristic.into()
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_output_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs[index].characteristic.into()
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_input_type_count(op: *const ort_sys::OrtCustomOp) -> usize {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs.len()
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_output_type_count(op: *const ort_sys::OrtCustomOp) -> usize {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs.len()
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_input_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs[index]
|
||||
.r#type
|
||||
.map(|c| c.into())
|
||||
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_output_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs[index]
|
||||
.r#type
|
||||
.map(|c| c.into())
|
||||
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_variadic_input_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs
|
||||
.iter()
|
||||
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
|
||||
.and_then(|c| c.variadic_min_arity)
|
||||
.unwrap_or(1)
|
||||
.try_into()
|
||||
.expect("input minimum arity overflows i32")
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_variadic_input_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.inputs
|
||||
.iter()
|
||||
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
|
||||
.and_then(|c| c.variadic_homogeneity)
|
||||
.unwrap_or(false)
|
||||
.into()
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_variadic_output_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs
|
||||
.iter()
|
||||
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
|
||||
.and_then(|c| c.variadic_min_arity)
|
||||
.unwrap_or(1)
|
||||
.try_into()
|
||||
.expect("output minimum arity overflows i32")
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn get_variadic_output_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
|
||||
let safe = Self::safe(op);
|
||||
safe.outputs
|
||||
.iter()
|
||||
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
|
||||
.and_then(|c| c.variadic_homogeneity)
|
||||
.unwrap_or(false)
|
||||
.into()
|
||||
}
|
||||
|
||||
pub(crate) extern "system" fn infer_output_shape(op: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
|
||||
let safe = Self::safe(op);
|
||||
let mut ctx = ShapeInferenceContext { ptr: ctx };
|
||||
safe.operator.infer_shape(&mut ctx).into_status()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,7 +314,7 @@ impl AsPointer for KernelContext {
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn parallel_for_cb(user_data: *mut c_void, iterator: usize) {
|
||||
extern "system" fn parallel_for_cb(user_data: *mut c_void, iterator: usize) {
|
||||
let executor = unsafe { &*user_data.cast::<Box<dyn Fn(usize) + Sync + Send>>() };
|
||||
executor(iterator)
|
||||
}
|
||||
|
||||
@@ -133,30 +133,28 @@ pub(crate) struct AsyncInferenceContext<'r, 's> {
|
||||
pub(crate) output_value_ptrs: Vec<*mut ort_sys::OrtValue>
|
||||
}
|
||||
|
||||
crate::extern_system_fn! {
|
||||
pub(crate) fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: *mut OrtStatus) {
|
||||
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_>>()) };
|
||||
pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: *mut OrtStatus) {
|
||||
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_>>()) };
|
||||
|
||||
// Reconvert name ptrs to CString so drop impl is called and memory is freed
|
||||
for p in ctx.input_name_ptrs {
|
||||
drop(unsafe { CString::from_raw(p.cast_mut().cast()) });
|
||||
}
|
||||
|
||||
if let Err(e) = crate::error::status_to_result(status) {
|
||||
ctx.inner.emplace_value(Err(e));
|
||||
ctx.inner.wake();
|
||||
return;
|
||||
}
|
||||
|
||||
let outputs: Vec<Value> = ctx
|
||||
.output_value_ptrs
|
||||
.into_iter()
|
||||
.map(|tensor_ptr| unsafe {
|
||||
Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), Some(Arc::clone(ctx.session_inner)))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names, outputs)));
|
||||
ctx.inner.wake();
|
||||
// Reconvert name ptrs to CString so drop impl is called and memory is freed
|
||||
for p in ctx.input_name_ptrs {
|
||||
drop(unsafe { CString::from_raw(p.cast_mut().cast()) });
|
||||
}
|
||||
|
||||
if let Err(e) = crate::error::status_to_result(status) {
|
||||
ctx.inner.emplace_value(Err(e));
|
||||
ctx.inner.wake();
|
||||
return;
|
||||
}
|
||||
|
||||
let outputs: Vec<Value> = ctx
|
||||
.output_value_ptrs
|
||||
.into_iter()
|
||||
.map(|tensor_ptr| unsafe {
|
||||
Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), Some(Arc::clone(ctx.session_inner)))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names, outputs)));
|
||||
ctx.inner.wake();
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ use crate::{
|
||||
AsPointer, char_p_to_string,
|
||||
environment::Environment,
|
||||
error::{Error, ErrorCode, Result, assert_non_null_pointer, status_to_result},
|
||||
extern_system_fn,
|
||||
io_binding::IoBinding,
|
||||
memory::Allocator,
|
||||
metadata::ModelMetadata,
|
||||
@@ -558,7 +557,7 @@ mod dangerous {
|
||||
}
|
||||
|
||||
fn extract_io_count(
|
||||
f: extern_system_fn! { unsafe fn(*const ort_sys::OrtSession, *mut usize) -> *mut ort_sys::OrtStatus },
|
||||
f: unsafe extern "system" fn(*const ort_sys::OrtSession, *mut usize) -> *mut ort_sys::OrtStatus,
|
||||
session_ptr: NonNull<ort_sys::OrtSession>
|
||||
) -> Result<usize> {
|
||||
let mut num_nodes = 0;
|
||||
@@ -590,12 +589,7 @@ mod dangerous {
|
||||
}
|
||||
|
||||
fn extract_io_name(
|
||||
f: extern_system_fn! { unsafe fn(
|
||||
*const ort_sys::OrtSession,
|
||||
usize,
|
||||
*mut ort_sys::OrtAllocator,
|
||||
*mut *mut c_char,
|
||||
) -> *mut ort_sys::OrtStatus },
|
||||
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut ort_sys::OrtAllocator, *mut *mut c_char) -> *mut ort_sys::OrtStatus,
|
||||
session_ptr: NonNull<ort_sys::OrtSession>,
|
||||
allocator: &Allocator,
|
||||
i: usize
|
||||
@@ -624,11 +618,7 @@ mod dangerous {
|
||||
}
|
||||
|
||||
fn extract_io(
|
||||
f: extern_system_fn! { unsafe fn(
|
||||
*const ort_sys::OrtSession,
|
||||
usize,
|
||||
*mut *mut ort_sys::OrtTypeInfo,
|
||||
) -> *mut ort_sys::OrtStatus },
|
||||
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut *mut ort_sys::OrtTypeInfo) -> *mut ort_sys::OrtStatus,
|
||||
session_ptr: NonNull<ort_sys::OrtSession>,
|
||||
i: usize
|
||||
) -> Result<ValueType> {
|
||||
|
||||
Reference in New Issue
Block a user