fix(io-binding): take output values by (mut) reference

This commit is contained in:
Carson M.
2023-11-06 15:15:14 -06:00
parent 53904f5cff
commit 87a124e462

View File

@@ -21,31 +21,35 @@ use crate::{memory::MemoryInfo, ortfree, ortsys, session::output::SessionOutputs
pub struct IoBinding<'s> {
pub(crate) ptr: *mut ort_sys::OrtIoBinding,
session: &'s Session,
values: Vec<Value>
// Hold onto input values until we run the `IoBinding`.
input_values: Vec<Value>
}
impl<'s> IoBinding<'s> {
pub(crate) fn new(session: &'s Session) -> Result<Self> {
let mut ptr: *mut ort_sys::OrtIoBinding = ptr::null_mut();
ortsys![unsafe CreateIoBinding(session.inner.session_ptr, &mut ptr) -> Error::CreateIoBinding; nonNull(ptr)];
Ok(Self { ptr, session, values: Vec::new() })
Ok(Self {
ptr,
session,
input_values: Vec::new()
})
}
/// Bind a [`Value`] to a session input.
pub fn bind_input<'a, 'b: 'a, S: AsRef<str> + Clone + Debug>(&'a mut self, name: S, ort_value: Value) -> Result<()> {
pub fn bind_input<S: AsRef<str> + Clone + Debug>(&mut self, name: S, ort_value: Value) -> Result<()> {
let name = name.as_ref();
let cname = CString::new(name)?;
ortsys![unsafe BindInput(self.ptr, cname.as_ptr(), ort_value.ptr()) -> Error::BindInput];
self.values.push(ort_value);
self.input_values.push(ort_value);
Ok(())
}
/// Bind a session output to a pre-allocated [`Value`].
pub fn bind_output<'a, 'b: 'a, S: AsRef<str> + Clone + Debug>(&'a mut self, name: S, ort_value: Value) -> Result<()> {
pub fn bind_output<'a, 'b: 'a, S: AsRef<str> + Clone + Debug>(&'a mut self, name: S, ort_value: &'b mut Value) -> Result<()> {
let name = name.as_ref();
let cname = CString::new(name)?;
ortsys![unsafe BindOutput(self.ptr, cname.as_ptr(), ort_value.ptr()) -> Error::BindOutput];
self.values.push(ort_value);
Ok(())
}
@@ -60,12 +64,12 @@ impl<'s> IoBinding<'s> {
pub fn run(&mut self) -> Result<SessionOutputs> {
let run_options_ptr: *const ort_sys::OrtRunOptions = std::ptr::null();
ortsys![unsafe RunWithBinding(self.session.inner.session_ptr, run_options_ptr, self.ptr) -> Error::SessionRunWithIoBinding];
self.values.clear();
self.input_values.clear();
self.outputs()
}
pub fn clean(&mut self) {
self.values.clear();
self.input_values.clear();
}
pub fn outputs(&self) -> Result<SessionOutputs> {