refactor(sys): replace _system macro with extern "system"

This commit is contained in:
Carson M.
2024-11-29 14:01:00 -06:00
parent 393711eb56
commit 25881a35a7
8 changed files with 1047 additions and 1118 deletions

View File

@@ -1,5 +1,5 @@
use std::{
env, fs, io,
env, fs,
path::{Path, PathBuf}
};

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,7 @@ use tracing::{Level, debug};
#[cfg(feature = "load-dynamic")]
use crate::G_ORT_DYLIB_PATH;
use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, extern_system_fn, ortsys};
use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, ortsys};
struct EnvironmentSingleton {
lock: RwLock<Option<Arc<Environment>>>
@@ -172,7 +172,7 @@ pub trait ThreadManager {
fn join(thread: Self::Thread) -> crate::Result<()>;
}
pub(crate) unsafe extern "C" fn thread_create<T: ThreadManager + Any>(
pub(crate) unsafe extern "system" fn thread_create<T: ThreadManager + Any>(
ort_custom_thread_creation_options: *mut c_void,
ort_thread_worker_fn: ort_sys::OrtThreadWorkerFn,
ort_worker_fn_param: *mut c_void
@@ -201,7 +201,7 @@ pub(crate) unsafe extern "C" fn thread_create<T: ThreadManager + Any>(
}
}
pub(crate) unsafe extern "C" fn thread_join<T: ThreadManager + Any>(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) {
pub(crate) unsafe extern "system" fn thread_join<T: ThreadManager + Any>(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) {
let handle = Box::from_raw(ort_custom_thread_handle.cast_mut().cast::<<T as ThreadManager>::Thread>());
if let Err(e) = <T as ThreadManager>::join(*handle) {
tracing::error!("Failed to join thread using manager: {e}");
@@ -395,29 +395,29 @@ pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
EnvironmentBuilder::new()
}
extern_system_fn! {
/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, _: *const c_char, id: *const c_char, code_location: *const c_char, message: *const c_char) {
assert_ne!(code_location, ptr::null());
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("<decode error>");
assert_ne!(message, ptr::null());
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<decode error>");
assert_ne!(id, ptr::null());
let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or("<decode error>");
/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
pub(crate) extern "system" fn custom_logger(
_params: *mut ffi::c_void,
severity: ort_sys::OrtLoggingLevel,
_: *const c_char,
id: *const c_char,
code_location: *const c_char,
message: *const c_char
) {
assert_ne!(code_location, ptr::null());
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("<decode error>");
assert_ne!(message, ptr::null());
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<decode error>");
assert_ne!(id, ptr::null());
let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or("<decode error>");
let span = tracing::span!(
Level::TRACE,
"ort",
id = id,
location = code_location
);
let span = tracing::span!(Level::TRACE, "ort", id = id, location = code_location);
match severity {
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, Level::TRACE, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, Level::DEBUG, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, Level::INFO, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, Level::WARN, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL=> tracing::event!(parent: &span, Level::ERROR, "{message}")
}
match severity {
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, Level::TRACE, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, Level::DEBUG, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, Level::INFO, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, Level::WARN, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => tracing::event!(parent: &span, Level::ERROR, "{message}")
}
}

View File

@@ -49,23 +49,6 @@ pub use self::{
error::{Error, ErrorCode, Result}
};
#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
macro_rules! extern_system_fn {
($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "C" fn $($tt)*);
($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "C" fn $($tt)*);
($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "C" fn $($tt)*);
}
#[cfg(all(target_arch = "x86", target_os = "windows"))]
macro_rules! extern_system_fn {
($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "stdcall" fn $($tt)*);
($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "stdcall" fn $($tt)*);
($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "stdcall" fn $($tt)*);
($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "stdcall" fn $($tt)*);
}
pub(crate) use extern_system_fn;
/// The minor version of ONNX Runtime used by this version of `ort`.
pub const MINOR_VERSION: u32 = ort_sys::ORT_API_VERSION;
@@ -167,7 +150,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
let base: *const ort_sys::OrtApiBase = base_getter();
assert_ne!(base, ptr::null());
let get_version_string: extern_system_fn! { unsafe fn () -> *const c_char } =
let get_version_string: unsafe extern "system" fn() -> *const c_char =
(*base).GetVersionString.expect("`GetVersionString` must be present in `OrtApiBase`");
let version_string = get_version_string();
let version_string = CStr::from_ptr(version_string).to_string_lossy();
@@ -187,7 +170,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
),
std::cmp::Ordering::Equal => {}
};
let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } =
let get_api: unsafe extern "system" fn(u32) -> *const ort_sys::OrtApi =
(*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
let api: *const ort_sys::OrtApi = get_api(ort_sys::ORT_API_VERSION);
ApiPointer(NonNull::new(api.cast_mut()).expect("Failed to initialize ORT API"))
@@ -196,7 +179,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
unsafe {
let base: *const ort_sys::OrtApiBase = ort_sys::OrtGetApiBase();
assert_ne!(base, ptr::null());
let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } =
let get_api: unsafe extern "system" fn(u32) -> *const ort_sys::OrtApi =
(*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
let api: *const ort_sys::OrtApi = get_api(ort_sys::ORT_API_VERSION);
ApiPointer(NonNull::new(api.cast_mut()).expect("Failed to initialize ORT API"))

View File

@@ -5,7 +5,7 @@ use super::{
io::{self, InputOutputCharacteristic},
kernel::{Kernel, KernelAttributes, KernelContext}
};
use crate::{Result, error::IntoStatus, extern_system_fn};
use crate::{Result, error::IntoStatus};
#[repr(C)] // <- important! a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later
pub(crate) struct BoundOperator {
@@ -62,168 +62,140 @@ impl BoundOperator {
})
}
unsafe fn safe<'a>(op: *const ort_sys::OrtCustomOp) -> &'a BoundOperator {
&*op.cast()
fn safe<'a>(op: *const ort_sys::OrtCustomOp) -> &'a BoundOperator {
unsafe { &*op.cast() }
}
extern_system_fn! {
pub(crate) unsafe fn create_kernel(
op: *const ort_sys::OrtCustomOp,
_: *const ort_sys::OrtApi,
info: *const ort_sys::OrtKernelInfo,
kernel_ptr: *mut *mut ort_sys::c_void
) -> *mut ort_sys::OrtStatus {
let safe = Self::safe(op);
let kernel = match safe.operator.create_kernel(&KernelAttributes::new(info)) {
Ok(kernel) => kernel,
e => return e.into_status()
};
*kernel_ptr = (Box::leak(Box::new(kernel)) as *mut Box<dyn Kernel>).cast();
Ok(()).into_status()
}
pub(crate) extern "system" fn create_kernel(
op: *const ort_sys::OrtCustomOp,
_: *const ort_sys::OrtApi,
info: *const ort_sys::OrtKernelInfo,
kernel_ptr: *mut *mut ort_sys::c_void
) -> *mut ort_sys::OrtStatus {
let safe = Self::safe(op);
let kernel = match safe.operator.create_kernel(&KernelAttributes::new(info)) {
Ok(kernel) => kernel,
e => return e.into_status()
};
unsafe { *kernel_ptr = (Box::leak(Box::new(kernel)) as *mut Box<dyn Kernel>).cast() };
Ok(()).into_status()
}
extern_system_fn! {
pub(crate) unsafe fn compute_kernel(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus {
let context = KernelContext::new(context);
unsafe { &mut *kernel_ptr.cast::<Box<dyn Kernel>>() }.compute(&context).into_status()
}
pub(crate) extern "system" fn compute_kernel(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus {
let context = KernelContext::new(context);
unsafe { &mut *kernel_ptr.cast::<Box<dyn Kernel>>() }.compute(&context).into_status()
}
extern_system_fn! {
pub(crate) unsafe fn destroy_kernel(op_kernel: *mut ort_sys::c_void) {
drop(Box::from_raw(op_kernel.cast::<Box<dyn Kernel>>()));
}
pub(crate) extern "system" fn destroy_kernel(op_kernel: *mut ort_sys::c_void) {
drop(unsafe { Box::from_raw(op_kernel.cast::<Box<dyn Kernel>>()) });
}
extern_system_fn! {
pub(crate) unsafe fn get_name(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.name.as_ptr()
}
}
extern_system_fn! {
pub(crate) unsafe fn get_execution_provider_type(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
}
pub(crate) extern "system" fn get_name(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.name.as_ptr()
}
extern_system_fn! {
pub(crate) unsafe fn get_min_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.operator.min_version() as _
}
}
extern_system_fn! {
pub(crate) unsafe fn get_max_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.operator.max_version() as _
}
pub(crate) extern "system" fn get_execution_provider_type(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
}
extern_system_fn! {
pub(crate) unsafe fn get_input_memory_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtMemType {
let safe = Self::safe(op);
safe.inputs[index].memory_type.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn get_input_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
let safe = Self::safe(op);
safe.inputs[index].characteristic.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn get_output_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
let safe = Self::safe(op);
safe.outputs[index].characteristic.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn get_input_type_count(op: *const ort_sys::OrtCustomOp) -> usize {
let safe = Self::safe(op);
safe.inputs.len()
}
}
extern_system_fn! {
pub(crate) unsafe fn get_output_type_count(op: *const ort_sys::OrtCustomOp) -> usize {
let safe = Self::safe(op);
safe.outputs.len()
}
}
extern_system_fn! {
pub(crate) unsafe fn get_input_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
let safe = Self::safe(op);
safe.inputs[index]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
}
extern_system_fn! {
pub(crate) unsafe fn get_output_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
let safe = Self::safe(op);
safe.outputs[index]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
}
extern_system_fn! {
pub(crate) unsafe fn get_variadic_input_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.inputs
.iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("input minimum arity overflows i32")
}
}
extern_system_fn! {
pub(crate) unsafe fn get_variadic_input_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.inputs
.iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn get_variadic_output_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.outputs
.iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("output minimum arity overflows i32")
}
}
extern_system_fn! {
pub(crate) unsafe fn get_variadic_output_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.outputs
.iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
pub(crate) extern "system" fn get_min_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.operator.min_version() as _
}
extern_system_fn! {
pub(crate) unsafe fn infer_output_shape(op: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
let safe = Self::safe(op);
let mut ctx = ShapeInferenceContext {
ptr: ctx
};
safe.operator.infer_shape(&mut ctx).into_status()
}
pub(crate) extern "system" fn get_max_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.operator.max_version() as _
}
pub(crate) extern "system" fn get_input_memory_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtMemType {
let safe = Self::safe(op);
safe.inputs[index].memory_type.into()
}
pub(crate) extern "system" fn get_input_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
let safe = Self::safe(op);
safe.inputs[index].characteristic.into()
}
pub(crate) extern "system" fn get_output_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
let safe = Self::safe(op);
safe.outputs[index].characteristic.into()
}
pub(crate) extern "system" fn get_input_type_count(op: *const ort_sys::OrtCustomOp) -> usize {
let safe = Self::safe(op);
safe.inputs.len()
}
pub(crate) extern "system" fn get_output_type_count(op: *const ort_sys::OrtCustomOp) -> usize {
let safe = Self::safe(op);
safe.outputs.len()
}
pub(crate) extern "system" fn get_input_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
let safe = Self::safe(op);
safe.inputs[index]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
pub(crate) extern "system" fn get_output_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
let safe = Self::safe(op);
safe.outputs[index]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
pub(crate) extern "system" fn get_variadic_input_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.inputs
.iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("input minimum arity overflows i32")
}
pub(crate) extern "system" fn get_variadic_input_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.inputs
.iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
pub(crate) extern "system" fn get_variadic_output_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.outputs
.iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("output minimum arity overflows i32")
}
pub(crate) extern "system" fn get_variadic_output_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
let safe = Self::safe(op);
safe.outputs
.iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
pub(crate) extern "system" fn infer_output_shape(op: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
let safe = Self::safe(op);
let mut ctx = ShapeInferenceContext { ptr: ctx };
safe.operator.infer_shape(&mut ctx).into_status()
}
}

View File

@@ -314,7 +314,7 @@ impl AsPointer for KernelContext {
}
}
extern "C" fn parallel_for_cb(user_data: *mut c_void, iterator: usize) {
extern "system" fn parallel_for_cb(user_data: *mut c_void, iterator: usize) {
let executor = unsafe { &*user_data.cast::<Box<dyn Fn(usize) + Sync + Send>>() };
executor(iterator)
}

View File

@@ -133,30 +133,28 @@ pub(crate) struct AsyncInferenceContext<'r, 's> {
pub(crate) output_value_ptrs: Vec<*mut ort_sys::OrtValue>
}
crate::extern_system_fn! {
pub(crate) fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: *mut OrtStatus) {
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_>>()) };
pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: *mut OrtStatus) {
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_>>()) };
// Reconvert name ptrs to CString so drop impl is called and memory is freed
for p in ctx.input_name_ptrs {
drop(unsafe { CString::from_raw(p.cast_mut().cast()) });
}
if let Err(e) = crate::error::status_to_result(status) {
ctx.inner.emplace_value(Err(e));
ctx.inner.wake();
return;
}
let outputs: Vec<Value> = ctx
.output_value_ptrs
.into_iter()
.map(|tensor_ptr| unsafe {
Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), Some(Arc::clone(ctx.session_inner)))
})
.collect();
ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names, outputs)));
ctx.inner.wake();
// Reconvert name ptrs to CString so drop impl is called and memory is freed
for p in ctx.input_name_ptrs {
drop(unsafe { CString::from_raw(p.cast_mut().cast()) });
}
if let Err(e) = crate::error::status_to_result(status) {
ctx.inner.emplace_value(Err(e));
ctx.inner.wake();
return;
}
let outputs: Vec<Value> = ctx
.output_value_ptrs
.into_iter()
.map(|tensor_ptr| unsafe {
Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), Some(Arc::clone(ctx.session_inner)))
})
.collect();
ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names, outputs)));
ctx.inner.wake();
}

View File

@@ -24,7 +24,6 @@ use crate::{
AsPointer, char_p_to_string,
environment::Environment,
error::{Error, ErrorCode, Result, assert_non_null_pointer, status_to_result},
extern_system_fn,
io_binding::IoBinding,
memory::Allocator,
metadata::ModelMetadata,
@@ -558,7 +557,7 @@ mod dangerous {
}
fn extract_io_count(
f: extern_system_fn! { unsafe fn(*const ort_sys::OrtSession, *mut usize) -> *mut ort_sys::OrtStatus },
f: unsafe extern "system" fn(*const ort_sys::OrtSession, *mut usize) -> *mut ort_sys::OrtStatus,
session_ptr: NonNull<ort_sys::OrtSession>
) -> Result<usize> {
let mut num_nodes = 0;
@@ -590,12 +589,7 @@ mod dangerous {
}
fn extract_io_name(
f: extern_system_fn! { unsafe fn(
*const ort_sys::OrtSession,
usize,
*mut ort_sys::OrtAllocator,
*mut *mut c_char,
) -> *mut ort_sys::OrtStatus },
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut ort_sys::OrtAllocator, *mut *mut c_char) -> *mut ort_sys::OrtStatus,
session_ptr: NonNull<ort_sys::OrtSession>,
allocator: &Allocator,
i: usize
@@ -624,11 +618,7 @@ mod dangerous {
}
fn extract_io(
f: extern_system_fn! { unsafe fn(
*const ort_sys::OrtSession,
usize,
*mut *mut ort_sys::OrtTypeInfo,
) -> *mut ort_sys::OrtStatus },
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut *mut ort_sys::OrtTypeInfo) -> *mut ort_sys::OrtStatus,
session_ptr: NonNull<ort_sys::OrtSession>,
i: usize
) -> Result<ValueType> {