mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
100 lines
2.6 KiB
Rust
100 lines
2.6 KiB
Rust
use ort::{
|
|
operator::{
|
|
Operator, OperatorDomain,
|
|
io::{OperatorInput, OperatorOutput},
|
|
kernel::{Kernel, KernelAttributes, KernelContext}
|
|
},
|
|
session::Session,
|
|
value::{Tensor, TensorElementType}
|
|
};
|
|
|
|
struct CustomOpOne;
|
|
|
|
impl Operator for CustomOpOne {
|
|
fn name(&self) -> &str {
|
|
"CustomOpOne"
|
|
}
|
|
|
|
fn inputs(&self) -> Vec<OperatorInput> {
|
|
vec![OperatorInput::required(TensorElementType::Float32), OperatorInput::required(TensorElementType::Float32)]
|
|
}
|
|
|
|
fn outputs(&self) -> Vec<OperatorOutput> {
|
|
vec![OperatorOutput::required(TensorElementType::Float32)]
|
|
}
|
|
|
|
fn create_kernel(&self, _: &KernelAttributes) -> ort::Result<Box<dyn Kernel>> {
|
|
Ok(Box::new(|ctx: &KernelContext| {
|
|
let x = ctx.input(0)?.unwrap();
|
|
let y = ctx.input(1)?.unwrap();
|
|
let (x_shape, x) = x.try_extract_tensor::<f32>()?;
|
|
let (y_shape, y) = y.try_extract_tensor::<f32>()?;
|
|
|
|
let mut z = ctx.output(0, x_shape.to_vec())?.unwrap();
|
|
let (_, z_ref) = z.try_extract_tensor_mut::<f32>()?;
|
|
for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize {
|
|
if i % 2 == 0 {
|
|
z_ref[i] = x[i];
|
|
} else {
|
|
z_ref[i] = y[i];
|
|
}
|
|
}
|
|
Ok(())
|
|
}))
|
|
}
|
|
}
|
|
|
|
struct CustomOpTwo;
|
|
|
|
impl Operator for CustomOpTwo {
|
|
fn name(&self) -> &str {
|
|
"CustomOpTwo"
|
|
}
|
|
fn inputs(&self) -> Vec<OperatorInput> {
|
|
vec![OperatorInput::required(TensorElementType::Float32)]
|
|
}
|
|
|
|
fn outputs(&self) -> Vec<OperatorOutput> {
|
|
vec![OperatorOutput::required(TensorElementType::Int32)]
|
|
}
|
|
|
|
fn create_kernel(&self, _: &KernelAttributes) -> ort::Result<Box<dyn Kernel>> {
|
|
Ok(Box::new(|ctx: &KernelContext| {
|
|
let x = ctx.input(0)?.unwrap();
|
|
let (x_shape, x) = x.try_extract_tensor::<f32>()?;
|
|
let mut z = ctx.output(0, x_shape.to_vec())?.unwrap();
|
|
let (_, z_ref) = z.try_extract_tensor_mut::<i32>()?;
|
|
for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize {
|
|
z_ref[i] = (x[i] * i as f32) as i32;
|
|
}
|
|
Ok(())
|
|
}))
|
|
}
|
|
}
|
|
|
|
fn main() -> ort::Result<()> {
|
|
let mut session = Session::builder()?
|
|
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?
|
|
.commit_from_file("tests/data/custom_op_test.onnx")?;
|
|
|
|
let allocator = session.allocator();
|
|
let mut value1 = Tensor::<f32>::new(allocator, [3_usize, 5])?;
|
|
{
|
|
let (_, data) = value1.extract_tensor_mut();
|
|
for datum in data {
|
|
*datum = 0.;
|
|
}
|
|
}
|
|
let mut value2 = Tensor::<f32>::new(allocator, [3_usize, 5])?;
|
|
{
|
|
let (_, data) = value2.extract_tensor_mut();
|
|
for datum in data {
|
|
*datum = 1.;
|
|
}
|
|
}
|
|
let values = session.run(ort::inputs![&value1, &value2])?;
|
|
println!("{:?}", values[0].try_extract_array::<i32>()?);
|
|
|
|
Ok(())
|
|
}
|