mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
fix(io-binding): take output values by (mut) reference
This commit is contained in:
@@ -21,31 +21,35 @@ use crate::{memory::MemoryInfo, ortfree, ortsys, session::output::SessionOutputs
|
||||
pub struct IoBinding<'s> {
|
||||
pub(crate) ptr: *mut ort_sys::OrtIoBinding,
|
||||
session: &'s Session,
|
||||
values: Vec<Value>
|
||||
// Hold onto input values until we run the `IoBinding`.
|
||||
input_values: Vec<Value>
|
||||
}
|
||||
|
||||
impl<'s> IoBinding<'s> {
|
||||
pub(crate) fn new(session: &'s Session) -> Result<Self> {
|
||||
let mut ptr: *mut ort_sys::OrtIoBinding = ptr::null_mut();
|
||||
ortsys![unsafe CreateIoBinding(session.inner.session_ptr, &mut ptr) -> Error::CreateIoBinding; nonNull(ptr)];
|
||||
Ok(Self { ptr, session, values: Vec::new() })
|
||||
Ok(Self {
|
||||
ptr,
|
||||
session,
|
||||
input_values: Vec::new()
|
||||
})
|
||||
}
|
||||
|
||||
/// Bind a [`Value`] to a session input.
|
||||
pub fn bind_input<'a, 'b: 'a, S: AsRef<str> + Clone + Debug>(&'a mut self, name: S, ort_value: Value) -> Result<()> {
|
||||
pub fn bind_input<S: AsRef<str> + Clone + Debug>(&mut self, name: S, ort_value: Value) -> Result<()> {
|
||||
let name = name.as_ref();
|
||||
let cname = CString::new(name)?;
|
||||
ortsys![unsafe BindInput(self.ptr, cname.as_ptr(), ort_value.ptr()) -> Error::BindInput];
|
||||
self.values.push(ort_value);
|
||||
self.input_values.push(ort_value);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Bind a session output to a pre-allocated [`Value`].
|
||||
pub fn bind_output<'a, 'b: 'a, S: AsRef<str> + Clone + Debug>(&'a mut self, name: S, ort_value: Value) -> Result<()> {
|
||||
pub fn bind_output<'a, 'b: 'a, S: AsRef<str> + Clone + Debug>(&'a mut self, name: S, ort_value: &'b mut Value) -> Result<()> {
|
||||
let name = name.as_ref();
|
||||
let cname = CString::new(name)?;
|
||||
ortsys![unsafe BindOutput(self.ptr, cname.as_ptr(), ort_value.ptr()) -> Error::BindOutput];
|
||||
self.values.push(ort_value);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -60,12 +64,12 @@ impl<'s> IoBinding<'s> {
|
||||
pub fn run(&mut self) -> Result<SessionOutputs> {
|
||||
let run_options_ptr: *const ort_sys::OrtRunOptions = std::ptr::null();
|
||||
ortsys![unsafe RunWithBinding(self.session.inner.session_ptr, run_options_ptr, self.ptr) -> Error::SessionRunWithIoBinding];
|
||||
self.values.clear();
|
||||
self.input_values.clear();
|
||||
self.outputs()
|
||||
}
|
||||
|
||||
pub fn clean(&mut self) {
|
||||
self.values.clear();
|
||||
self.input_values.clear();
|
||||
}
|
||||
|
||||
pub fn outputs(&self) -> Result<SessionOutputs> {
|
||||
|
||||
Reference in New Issue
Block a user