test: improve coverage for operator

This commit is contained in:
Carson M.
2026-03-20 19:43:57 -05:00
parent 85f6074a48
commit dc786564b3
4 changed files with 243 additions and 0 deletions

View File

@@ -119,3 +119,207 @@ fn test_custom_ops() -> crate::Result<()> {
Ok(())
}
struct AttrTesterIntFloat;
impl Operator for AttrTesterIntFloat {
fn name(&self) -> &str {
"AttrTesterIntFloat"
}
fn inputs(&self) -> Vec<OperatorInput> {
vec![OperatorInput::required(TensorElementType::Float32)]
}
fn outputs(&self) -> Vec<OperatorOutput> {
vec![OperatorOutput::required(TensorElementType::Float32)]
}
fn infer_shape(&self, ctx: &mut super::ShapeInferenceContext) -> crate::Result<()> {
assert!(matches!(ctx.attr("a_int"), Ok(1_i64)));
assert!(matches!(ctx.attr("a_float"), Ok(2.0_f32)));
assert!(matches!(ctx.attr::<Vec<i64>>("ints").as_deref(), Ok(&[3, 4, 5])));
assert!(matches!(ctx.attr::<Vec<f32>>("floats").as_deref(), Ok(&[6., 7., 8.])));
ctx.set_output(0, &ctx.inputs()[0])?;
Ok(())
}
fn create_kernel(&self, _: &KernelAttributes) -> crate::Result<Box<dyn Kernel>> {
Ok(Box::new(|ctx: &KernelContext| {
let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?;
let (x_shape, x) = x.try_extract_tensor::<f32>()?;
let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?;
let (_, z_ref) = z.try_extract_tensor_mut::<f32>()?;
for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize {
z_ref[i] = x[i] * 2.;
}
Ok(())
}))
}
}
struct AttrTesterString;
impl Operator for AttrTesterString {
fn name(&self) -> &str {
"AttrTesterString"
}
fn inputs(&self) -> Vec<OperatorInput> {
vec![OperatorInput::required(TensorElementType::Float32)]
}
fn outputs(&self) -> Vec<OperatorOutput> {
vec![OperatorOutput::required(TensorElementType::Float32)]
}
fn infer_shape(&self, ctx: &mut super::ShapeInferenceContext) -> crate::Result<()> {
assert!(matches!(ctx.attr::<String>("a_string").as_deref(), Ok("iamastring")));
ctx.set_output(0, &ctx.inputs()[0])?;
Ok(())
}
fn create_kernel(&self, _: &KernelAttributes) -> crate::Result<Box<dyn Kernel>> {
Ok(Box::new(|ctx: &KernelContext| {
let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?;
let (x_shape, x) = x.try_extract_tensor::<f32>()?;
let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?;
let (_, z_ref) = z.try_extract_tensor_mut::<f32>()?;
for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize {
z_ref[i] = x[i] * 3.;
}
Ok(())
}))
}
}
#[test]
fn test_op_attrs() -> crate::Result<()> {
let mut session = Session::builder()?
.with_operators(OperatorDomain::new("test.customop")?.add(AttrTesterIntFloat)?.add(AttrTesterString)?)?
.commit_from_file("tests/data/attr_tester.onnx")?;
let value1 = Tensor::from_array(([5], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0]))?;
let values = session.run(crate::inputs!["input_0" => &value1])?;
assert_eq!(values[0].try_extract_tensor::<f32>()?.1, [6.0, 12.0, 18.0, 24.0, 30.0]);
Ok(())
}
struct CopyTensorArrayAllVariadic;
impl Operator for CopyTensorArrayAllVariadic {
fn name(&self) -> &str {
"CopyTensorArrayAllVariadic"
}
fn inputs(&self) -> Vec<OperatorInput> {
vec![OperatorInput::variadic(1).homogenous(TensorElementType::Float32)]
}
fn outputs(&self) -> Vec<OperatorOutput> {
vec![OperatorOutput::variadic(1).homogenous(TensorElementType::Float32)]
}
fn infer_shape(&self, ctx: &mut super::ShapeInferenceContext) -> crate::Result<()> {
for (i, input) in ctx.inputs().into_iter().enumerate() {
ctx.set_output(i, &input)?;
}
Ok(())
}
fn create_kernel(&self, _: &KernelAttributes) -> crate::Result<Box<dyn Kernel>> {
Ok(Box::new(|ctx: &KernelContext| copy_variadic(0, ctx)))
}
}
fn copy_variadic(start: usize, ctx: &KernelContext) -> crate::Result<()> {
for i in start..ctx.num_inputs()? {
let input = ctx.input(i)?.ok_or_else(|| crate::Error::new("missing input"))?;
let mut output = ctx
.output(i, input.shape().clone())?
.ok_or_else(|| crate::Error::new(format!("failed to allocate output {i}")))?;
output.try_extract_tensor_mut::<f32>()?.1.copy_from_slice(input.try_extract_tensor()?.1);
}
Ok(())
}
struct CopyTensorArrayCombined;
impl Operator for CopyTensorArrayCombined {
fn name(&self) -> &str {
"CopyTensorArrayCombined"
}
fn inputs(&self) -> Vec<OperatorInput> {
vec![OperatorInput::optional(TensorElementType::Float32), OperatorInput::variadic(0).homogenous(TensorElementType::Float32)]
}
fn outputs(&self) -> Vec<OperatorOutput> {
vec![OperatorOutput::optional(TensorElementType::Float32), OperatorOutput::variadic(0).homogenous(TensorElementType::Float32)]
}
fn infer_shape(&self, ctx: &mut super::ShapeInferenceContext) -> crate::Result<()> {
for (i, input) in ctx.inputs().into_iter().enumerate() {
ctx.set_output(i, &input)?;
}
Ok(())
}
fn create_kernel(&self, _: &KernelAttributes) -> crate::Result<Box<dyn Kernel>> {
Ok(Box::new(|ctx: &KernelContext| {
if let Ok(Some(input)) = ctx.input(0) {
let mut output = ctx
.output(0, input.shape().clone())?
.ok_or_else(|| crate::Error::new("failed to allocate output 0"))?;
output.try_extract_tensor_mut::<f32>()?.1.copy_from_slice(input.try_extract_tensor()?.1);
}
copy_variadic(1, ctx)
}))
}
}
#[test]
fn test_variadic_io() -> crate::Result<()> {
let ops = Arc::new(
OperatorDomain::new("test.customop")?
.add(CopyTensorArrayAllVariadic)?
.add(CopyTensorArrayCombined)?
);
let mut session = Session::builder()?
.with_operators(Arc::clone(&ops))?
.commit_from_file("tests/data/copy_2_inputs_2_outputs.onnx")?;
let input0 = Tensor::from_array(([15], vec![1.1_f32, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0, 11.1, 12.2, 13.3, 14.4, 15.5]))?;
let input1 = Tensor::from_array(([15], vec![15.5_f32, 14.4, 13.3, 12.2, 11.1, 10.0, 9.9, 8.8, 7.7, 6.6, 5.5, 4.4, 3.3, 2.2, 1.1]))?;
let values = session.run(crate::inputs!["input_0" => &input0, "input_1" => &input1])?;
assert_eq!(values[0].try_extract_tensor::<f32>()?.1, input0.extract_tensor().1);
assert_eq!(values[1].try_extract_tensor::<f32>()?.1, input1.extract_tensor().1);
let mut session = Session::builder()?
.with_operators(Arc::clone(&ops))?
.commit_from_file("tests/data/copy_3_inputs_3_outputs.onnx")?;
let input2 = Tensor::from_array(([15], vec![6.6_f32, 7.7, 8.8, 9.9, 10.0, 1.1, 2.2, 3.3, 4.4, 5.5, 11.1, 12.2, 13.3, 14.4, 15.5]))?;
let values = session.run(crate::inputs![
"input_0" => &input0,
"input_1" => &input1,
"input_2" => &input2
])?;
assert_eq!(values[0].try_extract_tensor::<f32>()?.1, input0.extract_tensor().1);
assert_eq!(values[1].try_extract_tensor::<f32>()?.1, input1.extract_tensor().1);
assert_eq!(values[2].try_extract_tensor::<f32>()?.1, input2.extract_tensor().1);
Ok(())
}

BIN
tests/data/attr_tester.onnx Normal file

Binary file not shown.

View File

@@ -0,0 +1,16 @@

d
input_0
input_1output_0output_1copy_tensor_array"CopyTensorArrayAllVariadic: test.customopgraphZ
input_0

ÿÿÿÿÿÿÿÿÿZ
input_1

ÿÿÿÿÿÿÿÿÿb
output_0

ÿÿÿÿÿÿÿÿÿb
output_1

ÿÿÿÿÿÿÿÿÿB

View File

@@ -0,0 +1,23 @@

t
input_0
input_1
input_2output_0output_1output_2copy_tensor_array"CopyTensorArrayCombined: test.customopgraphZ
input_0

ÿÿÿÿÿÿÿÿÿZ
input_1

ÿÿÿÿÿÿÿÿÿZ
input_2

ÿÿÿÿÿÿÿÿÿb
output_0

ÿÿÿÿÿÿÿÿÿb
output_1

ÿÿÿÿÿÿÿÿÿb
output_2

ÿÿÿÿÿÿÿÿÿB