diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index f74035d..61a3cf2 100644 --- a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -10,95 +10,72 @@ use ort::{ }; struct CustomOpOne; -struct CustomOpOneKernel; impl Operator for CustomOpOne { - type Kernel = CustomOpOneKernel; - - fn name() -> &'static str { + fn name(&self) -> &str { "CustomOpOne" } - fn create_kernel(_: &KernelAttributes) -> ort::Result { - Ok(CustomOpOneKernel) - } - - fn inputs() -> Vec { + fn inputs(&self) -> Vec { vec![OperatorInput::required(TensorElementType::Float32), OperatorInput::required(TensorElementType::Float32)] } - fn outputs() -> Vec { + fn outputs(&self) -> Vec { vec![OperatorOutput::required(TensorElementType::Float32)] } - fn get_infer_shape_function() -> Option> { - Some(Box::new(|ctx| { - let inputs = ctx.inputs(); - ctx.set_output(0, &inputs[0])?; + fn create_kernel(&self, _: &KernelAttributes) -> ort::Result> { + Ok(Box::new(|ctx: &KernelContext| { + let x = ctx.input(0)?.unwrap(); + let y = ctx.input(1)?.unwrap(); + let (x_shape, x) = x.try_extract_raw_tensor::()?; + let (y_shape, y) = y.try_extract_raw_tensor::()?; + + let mut z = ctx.output(0, x_shape.to_vec())?.unwrap(); + let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; + for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize { + if i % 2 == 0 { + z_ref[i] = x[i]; + } else { + z_ref[i] = y[i]; + } + } Ok(()) })) } } -impl Kernel for CustomOpOneKernel { - fn compute(&mut self, ctx: &KernelContext) -> ort::Result<()> { - let x = ctx.input(0)?.unwrap(); - let y = ctx.input(1)?.unwrap(); - let (x_shape, x) = x.try_extract_raw_tensor::()?; - let (y_shape, y) = y.try_extract_raw_tensor::()?; - - let mut z = ctx.output(0, x_shape.to_vec())?.unwrap(); - let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize { - if i % 2 == 0 { - z_ref[i] = x[i]; - } else { - z_ref[i] = y[i]; - } - } - Ok(()) - } -} - struct CustomOpTwo; -struct CustomOpTwoKernel; impl Operator for CustomOpTwo { - type Kernel = CustomOpTwoKernel; - - fn name() -> &'static str { + fn name(&self) -> &str { "CustomOpTwo" } - - fn create_kernel(_: &KernelAttributes) -> ort::Result { - Ok(CustomOpTwoKernel) - } - - fn inputs() -> Vec { + fn inputs(&self) -> Vec { vec![OperatorInput::required(TensorElementType::Float32)] } - fn outputs() -> Vec { + fn outputs(&self) -> Vec { vec![OperatorOutput::required(TensorElementType::Int32)] } -} -impl Kernel for CustomOpTwoKernel { - fn compute(&mut self, ctx: &KernelContext) -> ort::Result<()> { - let x = ctx.input(0)?.unwrap(); - let (x_shape, x) = x.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape.to_vec())?.unwrap(); - let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize { - z_ref[i] = (x[i] * i as f32) as i32; - } - Ok(()) + fn create_kernel(&self, _: &KernelAttributes) -> ort::Result> { + Ok(Box::new(|ctx: &KernelContext| { + let x = ctx.input(0)?.unwrap(); + let (x_shape, x) = x.try_extract_raw_tensor::()?; + let mut z = ctx.output(0, x_shape.to_vec())?.unwrap(); + let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; + for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize { + z_ref[i] = (x[i] * i as f32) as i32; + } + Ok(()) + })) } } fn main() -> ort::Result<()> { let session = Session::builder()? - .with_operators(OperatorDomain::new("test.customop")?.add::()?.add::()?)? + .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)? .commit_from_file("tests/data/custom_op_test.onnx")?; let values = session.run(ort::inputs![Array2::::zeros((3, 5)), Array2::::ones((3, 5))]?)?; diff --git a/src/operator/bound.rs b/src/operator/bound.rs index 9f6c737..ba9ba34 100644 --- a/src/operator/bound.rs +++ b/src/operator/bound.rs @@ -1,168 +1,180 @@ -use std::{ - ffi::CString, - marker::PhantomData, - ptr::{self, NonNull} -}; +use std::{ffi::CString, ptr}; use super::{ - DummyOperator, Operator, ShapeInferenceContext, - io::InputOutputCharacteristic, + Operator, ShapeInferenceContext, + io::{self, InputOutputCharacteristic}, kernel::{Kernel, KernelAttributes, KernelContext} }; -use crate::{error::IntoStatus, extern_system_fn}; +use crate::{Result, error::IntoStatus, extern_system_fn}; #[repr(C)] // <- important! a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later -pub(crate) struct BoundOperator { +pub(crate) struct BoundOperator { implementation: ort_sys::OrtCustomOp, name: CString, execution_provider_type: Option, - _operator: PhantomData + inputs: Vec, + outputs: Vec, + operator: Box } +unsafe impl Send for BoundOperator {} + #[allow(non_snake_case, clippy::unnecessary_cast)] -impl BoundOperator { - pub(crate) fn new(name: CString, execution_provider_type: Option) -> Self { - Self { +impl BoundOperator { + pub(crate) fn new(operator: O) -> Result { + let name = CString::new(operator.name())?; + let execution_provider_type = operator.execution_provider_type().map(CString::new).transpose()?; + + Ok(Self { implementation: ort_sys::OrtCustomOp { version: ort_sys::ORT_API_VERSION, - GetStartVersion: Some(BoundOperator::::GetStartVersion), - GetEndVersion: Some(BoundOperator::::GetEndVersion), + GetStartVersion: Some(BoundOperator::get_min_version), + GetEndVersion: Some(BoundOperator::get_max_version), CreateKernel: None, - CreateKernelV2: Some(BoundOperator::::CreateKernelV2), - GetInputCharacteristic: Some(BoundOperator::::GetInputCharacteristic), - GetInputMemoryType: Some(BoundOperator::::GetInputMemoryType), - GetInputType: Some(BoundOperator::::GetInputType), - GetInputTypeCount: Some(BoundOperator::::GetInputTypeCount), - GetName: Some(BoundOperator::::GetName), - GetExecutionProviderType: Some(BoundOperator::::GetExecutionProviderType), - GetOutputCharacteristic: Some(BoundOperator::::GetOutputCharacteristic), - GetOutputType: Some(BoundOperator::::GetOutputType), - GetOutputTypeCount: Some(BoundOperator::::GetOutputTypeCount), - GetVariadicInputHomogeneity: Some(BoundOperator::::GetVariadicInputHomogeneity), - GetVariadicInputMinArity: Some(BoundOperator::::GetVariadicInputMinArity), - GetVariadicOutputHomogeneity: Some(BoundOperator::::GetVariadicOutputHomogeneity), - GetVariadicOutputMinArity: Some(BoundOperator::::GetVariadicOutputMinArity), + CreateKernelV2: Some(BoundOperator::create_kernel), + GetInputCharacteristic: Some(BoundOperator::get_input_characteristic), + GetInputMemoryType: Some(BoundOperator::get_input_memory_type), + GetInputType: Some(BoundOperator::get_input_type), + GetInputTypeCount: Some(BoundOperator::get_input_type_count), + GetName: Some(BoundOperator::get_name), + GetExecutionProviderType: Some(BoundOperator::get_execution_provider_type), + GetOutputCharacteristic: Some(BoundOperator::get_output_characteristic), + GetOutputType: Some(BoundOperator::get_output_type), + GetOutputTypeCount: Some(BoundOperator::get_output_type_count), + GetVariadicInputHomogeneity: Some(BoundOperator::get_variadic_input_homogeneity), + GetVariadicInputMinArity: Some(BoundOperator::get_variadic_input_min_arity), + GetVariadicOutputHomogeneity: Some(BoundOperator::get_variadic_output_homogeneity), + GetVariadicOutputMinArity: Some(BoundOperator::get_variadic_output_min_arity), GetAliasMap: None, ReleaseAliasMap: None, GetMayInplace: None, ReleaseMayInplace: None, - InferOutputShapeFn: if O::get_infer_shape_function().is_some() { - Some(BoundOperator::::InferOutputShapeFn) - } else { - None - }, + InferOutputShapeFn: Some(BoundOperator::infer_output_shape), KernelCompute: None, - KernelComputeV2: Some(BoundOperator::::ComputeKernelV2), - KernelDestroy: Some(BoundOperator::::KernelDestroy) + KernelComputeV2: Some(BoundOperator::compute_kernel), + KernelDestroy: Some(BoundOperator::destroy_kernel) }, name, execution_provider_type, - _operator: PhantomData - } + inputs: operator.inputs(), + outputs: operator.outputs(), + operator: Box::new(operator) + }) } - unsafe fn safe<'a>(op: *const ort_sys::OrtCustomOp) -> &'a BoundOperator { + unsafe fn safe<'a>(op: *const ort_sys::OrtCustomOp) -> &'a BoundOperator { &*op.cast() } extern_system_fn! { - pub(crate) unsafe fn CreateKernelV2( - _: *const ort_sys::OrtCustomOp, + pub(crate) unsafe fn create_kernel( + op: *const ort_sys::OrtCustomOp, _: *const ort_sys::OrtApi, info: *const ort_sys::OrtKernelInfo, kernel_ptr: *mut *mut ort_sys::c_void ) -> *mut ort_sys::OrtStatus { - let kernel = match O::create_kernel(&KernelAttributes::new(info)) { + let safe = Self::safe(op); + let kernel = match safe.operator.create_kernel(&KernelAttributes::new(info)) { Ok(kernel) => kernel, e => return e.into_status() }; - *kernel_ptr = (Box::leak(Box::new(kernel)) as *mut O::Kernel).cast(); + *kernel_ptr = (Box::leak(Box::new(kernel)) as *mut Box).cast(); Ok(()).into_status() } } extern_system_fn! { - pub(crate) unsafe fn ComputeKernelV2(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus { + pub(crate) unsafe fn compute_kernel(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus { let context = KernelContext::new(context); - O::Kernel::compute(unsafe { &mut *kernel_ptr.cast::() }, &context).into_status() + unsafe { &mut *kernel_ptr.cast::>() }.compute(&context).into_status() } } extern_system_fn! { - pub(crate) unsafe fn KernelDestroy(op_kernel: *mut ort_sys::c_void) { - drop(Box::from_raw(op_kernel.cast::())); + pub(crate) unsafe fn destroy_kernel(op_kernel: *mut ort_sys::c_void) { + drop(Box::from_raw(op_kernel.cast::>())); } } extern_system_fn! { - pub(crate) unsafe fn GetName(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { + pub(crate) unsafe fn get_name(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { let safe = Self::safe(op); safe.name.as_ptr() } } extern_system_fn! { - pub(crate) unsafe fn GetExecutionProviderType(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { + pub(crate) unsafe fn get_execution_provider_type(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { let safe = Self::safe(op); safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null) } } extern_system_fn! { - pub(crate) unsafe fn GetStartVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::min_version() + pub(crate) unsafe fn get_min_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + let safe = Self::safe(op); + safe.operator.min_version() as _ } } extern_system_fn! { - pub(crate) unsafe fn GetEndVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::max_version() + pub(crate) unsafe fn get_max_version(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + let safe = Self::safe(op); + safe.operator.max_version() as _ } } extern_system_fn! { - pub(crate) unsafe fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtMemType { - O::inputs()[index].memory_type.into() + pub(crate) unsafe fn get_input_memory_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtMemType { + let safe = Self::safe(op); + safe.inputs[index].memory_type.into() } } extern_system_fn! { - pub(crate) unsafe fn GetInputCharacteristic(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic { - O::inputs()[index].characteristic.into() + pub(crate) unsafe fn get_input_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic { + let safe = Self::safe(op); + safe.inputs[index].characteristic.into() } } extern_system_fn! { - pub(crate) unsafe fn GetOutputCharacteristic(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic { - O::outputs()[index].characteristic.into() + pub(crate) unsafe fn get_output_characteristic(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic { + let safe = Self::safe(op); + safe.outputs[index].characteristic.into() } } extern_system_fn! { - pub(crate) unsafe fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> usize { - O::inputs().len() + pub(crate) unsafe fn get_input_type_count(op: *const ort_sys::OrtCustomOp) -> usize { + let safe = Self::safe(op); + safe.inputs.len() } } extern_system_fn! { - pub(crate) unsafe fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> usize { - O::outputs().len() + pub(crate) unsafe fn get_output_type_count(op: *const ort_sys::OrtCustomOp) -> usize { + let safe = Self::safe(op); + safe.outputs.len() } } extern_system_fn! { - pub(crate) unsafe fn GetInputType(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType { - O::inputs()[index] + pub(crate) unsafe fn get_input_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType { + let safe = Self::safe(op); + safe.inputs[index] .r#type .map(|c| c.into()) .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) } } extern_system_fn! { - pub(crate) unsafe fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType { - O::outputs()[index] + pub(crate) unsafe fn get_output_type(op: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType { + let safe = Self::safe(op); + safe.outputs[index] .r#type .map(|c| c.into()) .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) } } extern_system_fn! { - pub(crate) unsafe fn GetVariadicInputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::inputs() - .into_iter() + pub(crate) unsafe fn get_variadic_input_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + let safe = Self::safe(op); + safe.inputs + .iter() .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) .and_then(|c| c.variadic_min_arity) .unwrap_or(1) @@ -171,9 +183,10 @@ impl BoundOperator { } } extern_system_fn! { - pub(crate) unsafe fn GetVariadicInputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::inputs() - .into_iter() + pub(crate) unsafe fn get_variadic_input_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + let safe = Self::safe(op); + safe.inputs + .iter() .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) .and_then(|c| c.variadic_homogeneity) .unwrap_or(false) @@ -181,9 +194,10 @@ impl BoundOperator { } } extern_system_fn! { - pub(crate) unsafe fn GetVariadicOutputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::outputs() - .into_iter() + pub(crate) unsafe fn get_variadic_output_min_arity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + let safe = Self::safe(op); + safe.outputs + .iter() .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) .and_then(|c| c.variadic_min_arity) .unwrap_or(1) @@ -192,9 +206,10 @@ impl BoundOperator { } } extern_system_fn! { - pub(crate) unsafe fn GetVariadicOutputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::outputs() - .into_iter() + pub(crate) unsafe fn get_variadic_output_homogeneity(op: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + let safe = Self::safe(op); + safe.outputs + .iter() .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) .and_then(|c| c.variadic_homogeneity) .unwrap_or(false) @@ -203,34 +218,12 @@ impl BoundOperator { } extern_system_fn! { - pub(crate) unsafe fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus { + pub(crate) unsafe fn infer_output_shape(op: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus { + let safe = Self::safe(op); let mut ctx = ShapeInferenceContext { ptr: ctx }; - O::get_infer_shape_function().expect("missing infer shape function")(&mut ctx).into_status() + safe.operator.infer_shape(&mut ctx).into_status() } } } - -pub(crate) struct ErasedBoundOperator(NonNull<()>); - -unsafe impl Send for ErasedBoundOperator {} - -impl ErasedBoundOperator { - pub(crate) fn new(bound: BoundOperator) -> Self { - ErasedBoundOperator(NonNull::from(unsafe { - // horrible horrible horrible horrible horrible horrible horrible horrible horrible - &mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ()) - })) - } - - pub(crate) fn op_ptr(&self) -> *mut ort_sys::OrtCustomOp { - self.0.as_ptr().cast() - } -} - -impl Drop for ErasedBoundOperator { - fn drop(&mut self) { - drop(unsafe { Box::from_raw(self.0.as_ptr().cast::>()) }); - } -} diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index c2cbdd7..9221f21 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -17,11 +17,12 @@ pub trait Kernel { fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()>; } -pub(crate) struct DummyKernel; - -impl Kernel for DummyKernel { - fn compute(&mut self, _: &KernelContext) -> crate::Result<()> { - unimplemented!() +impl Kernel for F +where + F: FnMut(&KernelContext) -> crate::Result<()> +{ + fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()> { + self(ctx) } } diff --git a/src/operator/mod.rs b/src/operator/mod.rs index 7f58ddd..7c3c59c 100644 --- a/src/operator/mod.rs +++ b/src/operator/mod.rs @@ -12,9 +12,9 @@ pub mod kernel; mod tests; use self::{ - bound::{BoundOperator, ErasedBoundOperator}, + bound::BoundOperator, io::{OperatorInput, OperatorOutput}, - kernel::{DummyKernel, Kernel, KernelAttributes} + kernel::{Kernel, KernelAttributes} }; use crate::{ AsPointer, Error, @@ -32,10 +32,8 @@ use crate::{ /// [`Operator`] structs can return the same name in [`Operator::name`] so that they are usable as simply /// `my.domain:Sort` in the graph. pub trait Operator: Send { - type Kernel: Kernel; - /// Returns the name of the operator. - fn name() -> &'static str; + fn name(&self) -> &str; /// Returns the execution provider this operator runs on, e.g. `CUDAExecutionProvider`. /// @@ -46,49 +44,28 @@ pub trait Operator: Send { /// /// [`Tensor::data_ptr`]: crate::value::Tensor::data_ptr /// [`KernelContext::compute_stream`]: crate::operator::kernel::KernelContext::compute_stream - fn execution_provider_type() -> Option<&'static str> { + fn execution_provider_type(&self) -> Option<&str> { None } - fn inputs() -> Vec; - fn outputs() -> Vec; + fn inputs(&self) -> Vec; + fn outputs(&self) -> Vec; - fn create_kernel(attributes: &KernelAttributes) -> crate::Result; + fn create_kernel(&self, attributes: &KernelAttributes) -> crate::Result>; - fn min_version() -> ort_sys::c_int { + fn min_version(&self) -> i32 { 1 } - fn max_version() -> ort_sys::c_int { - ort_sys::c_int::MAX + fn max_version(&self) -> i32 { + i32::MAX } - fn get_infer_shape_function() -> Option> { - None + fn infer_shape(&self, ctx: &mut ShapeInferenceContext) -> crate::Result<()> { + let _ = ctx; + Ok(()) } } -/// Dummy type implementing [`Operator`] used by [`ErasedBoundOperator`] to cheat the type system. -struct DummyOperator; - -impl Operator for DummyOperator { - type Kernel = DummyKernel; - - fn name() -> &'static str { - unimplemented!() - } - fn create_kernel(_: &KernelAttributes) -> crate::Result { - unimplemented!() - } - fn inputs() -> Vec { - unimplemented!() - } - fn outputs() -> Vec { - unimplemented!() - } -} - -pub type InferShapeFn = dyn FnMut(&mut ShapeInferenceContext) -> crate::Result<()> + 'static; - pub struct ShapeInferenceContext { ptr: *mut ort_sys::OrtShapeInferContext } @@ -130,7 +107,8 @@ impl AsPointer for ShapeInferenceContext { pub struct OperatorDomain { ptr: NonNull, _name: CString, - operators: Vec + #[allow(clippy::vec_box)] + operators: Vec> } impl OperatorDomain { @@ -146,12 +124,11 @@ impl OperatorDomain { } #[allow(clippy::should_implement_trait)] - pub fn add(mut self) -> Result { - let name = O::name(); - - let bound = BoundOperator::::new(CString::new(name)?, O::execution_provider_type().map(CString::new).transpose()?); - let bound = ErasedBoundOperator::new(bound); - ortsys![unsafe CustomOpDomain_Add(self.ptr.as_ptr(), bound.op_ptr())?]; + pub fn add(mut self, operator: O) -> Result { + // `Box`ing the operator here because we move it into `self` immediately after registering it. Without `Box`, + // the pointer we pass to `CustomOpDomain_Add` would become invalid. + let bound = Box::new(BoundOperator::new(operator)?); + ortsys![unsafe CustomOpDomain_Add(self.ptr.as_ptr(), (&*bound as *const BoundOperator) as *mut _)?]; self.operators.push(bound); diff --git a/src/operator/tests.rs b/src/operator/tests.rs index 4c10cd7..8dd2d5e 100644 --- a/src/operator/tests.rs +++ b/src/operator/tests.rs @@ -12,88 +12,74 @@ use crate::{ }; struct CustomOpOne; -struct CustomOpOneKernel; impl Operator for CustomOpOne { - type Kernel = CustomOpOneKernel; - - fn name() -> &'static str { + fn name(&self) -> &str { "CustomOpOne" } - fn create_kernel(_: &KernelAttributes) -> Result { - Ok(CustomOpOneKernel) - } - - fn inputs() -> Vec { + fn inputs(&self) -> Vec { vec![OperatorInput::required(TensorElementType::Float32), OperatorInput::required(TensorElementType::Float32)] } - fn outputs() -> Vec { + fn outputs(&self) -> Vec { vec![OperatorOutput::required(TensorElementType::Float32)] } -} -impl Kernel for CustomOpOneKernel { - fn compute(&mut self, ctx: &KernelContext) -> Result<()> { - let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?; - let y = ctx.input(1)?.ok_or_else(|| crate::Error::new("missing input"))?; - let (x_shape, x) = x.try_extract_raw_tensor::()?; - let (y_shape, y) = y.try_extract_raw_tensor::()?; + fn create_kernel(&self, _: &KernelAttributes) -> Result> { + Ok(Box::new(|ctx: &KernelContext| { + let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?; + let y = ctx.input(1)?.ok_or_else(|| crate::Error::new("missing input"))?; + let (x_shape, x) = x.try_extract_raw_tensor::()?; + let (y_shape, y) = y.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; - let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { - if i % 2 == 0 { - z_ref[i] = x[i]; - } else { - z_ref[i] = y[i]; + let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; + let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; + for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { + if i % 2 == 0 { + z_ref[i] = x[i]; + } else { + z_ref[i] = y[i]; + } } - } - Ok(()) + Ok(()) + })) } } struct CustomOpTwo; -struct CustomOpTwoKernel; impl Operator for CustomOpTwo { - type Kernel = CustomOpTwoKernel; - - fn name() -> &'static str { + fn name(&self) -> &str { "CustomOpTwo" } - fn create_kernel(_: &KernelAttributes) -> crate::Result { - Ok(CustomOpTwoKernel) - } - - fn inputs() -> Vec { + fn inputs(&self) -> Vec { vec![OperatorInput::required(TensorElementType::Float32)] } - fn outputs() -> Vec { + fn outputs(&self) -> Vec { vec![OperatorOutput::required(TensorElementType::Int32)] } -} -impl Kernel for CustomOpTwoKernel { - fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()> { - let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?; - let (x_shape, x) = x.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; - let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { - z_ref[i] = (x[i] * i as f32) as i32; - } - Ok(()) + fn create_kernel(&self, _: &KernelAttributes) -> crate::Result> { + Ok(Box::new(|ctx: &KernelContext| { + let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?; + let (x_shape, x) = x.try_extract_raw_tensor::()?; + let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; + let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; + for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { + z_ref[i] = (x[i] * i as f32) as i32; + } + Ok(()) + })) } } #[test] fn test_custom_ops() -> crate::Result<()> { let session = Session::builder()? - .with_operators(OperatorDomain::new("test.customop")?.add::()?.add::()?)? + .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)? .commit_from_file("tests/data/custom_op_test.onnx")?; let values = session.run(crate::inputs![Array2::::zeros((3, 5)), Array2::::ones((3, 5))]?)?;