mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: threadpool access in operator kernels
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use std::{
|
||||
ffi::{c_char, CString},
|
||||
ffi::{c_char, c_void, CString},
|
||||
ops::{Deref, DerefMut},
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
@@ -249,6 +249,15 @@ impl KernelContext {
|
||||
Ok(NonNull::new(resource_ptr))
|
||||
}
|
||||
|
||||
pub fn par_for<F>(&self, total: usize, max_num_batches: usize, f: F) -> Result<()>
|
||||
where
|
||||
F: Fn(usize) + Sync + Send
|
||||
{
|
||||
let executor = Box::new(f) as Box<dyn Fn(usize) + Sync + Send>;
|
||||
ortsys![unsafe KernelContext_ParallelFor(self.ptr.as_ptr(), Some(parallel_for_cb), total as _, max_num_batches as _, &executor as *const _ as *mut c_void)?];
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: STATUS_ACCESS_VIOLATION inside `KernelContext_GetScratchBuffer`. gonna assume this one is just an internal ONNX
|
||||
// Runtime bug.
|
||||
//
|
||||
@@ -280,3 +289,8 @@ impl KernelContext {
|
||||
Ok(NonNull::new(stream_ptr))
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn parallel_for_cb(user_data: *mut c_void, iterator: ort_sys::size_t) {
|
||||
let executor = unsafe { &*user_data.cast::<Box<dyn Fn(usize) + Sync + Send>>() };
|
||||
executor(iterator as _)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ const IGNORED_SYMBOLS = new Set<string>([
|
||||
'RegisterCustomOpsUsingFunction',
|
||||
'SessionOptionsAppendExecutionProvider_CUDA', // we use V2
|
||||
'SessionOptionsAppendExecutionProvider_TensorRT', // we use V2
|
||||
'GetValueType', // we get value types via GetTypeInfo -> GetOnnxTypeFromTypeInfo, which is equivalent
|
||||
'SetLanguageProjection', // someday we shall have `ORT_PROJECTION_RUST`, but alas, today is not that day...
|
||||
|
||||
// we use allocator APIs directly on the Allocator struct
|
||||
|
||||
Reference in New Issue
Block a user