refactor: make Outlet wrap OrtValueInfo

This commit is contained in:
Carson M.
2026-03-02 18:27:17 -06:00
parent a02122dd66
commit 771e1a5c4a
2 changed files with 62 additions and 9 deletions

View File

@@ -160,14 +160,14 @@ impl Graph {
}
pub fn set_inputs(&mut self, inputs: impl IntoIterator<Item = Outlet>) -> Result<()> {
let inputs: SmallVec<[NonNull<ort_sys::OrtValueInfo>; 4]> = inputs.into_iter().map(|input| input.into_editor_value_info()).collect::<Result<_>>()?;
let inputs: SmallVec<[NonNull<ort_sys::OrtValueInfo>; 4]> = inputs.into_iter().filter_map(|input| input.into_value_info_ptr()).collect();
// this takes ownership of the OrtValueInfos so no need to free those
ortsys![@editor: unsafe SetGraphInputs(self.0, inputs.as_ptr() as *mut _, inputs.len())?];
Ok(())
}
pub fn set_outputs(&mut self, outputs: impl IntoIterator<Item = Outlet>) -> Result<()> {
let outputs: SmallVec<[NonNull<ort_sys::OrtValueInfo>; 4]> = outputs.into_iter().map(|input| input.into_editor_value_info()).collect::<Result<_>>()?;
let outputs: SmallVec<[NonNull<ort_sys::OrtValueInfo>; 4]> = outputs.into_iter().filter_map(|input| input.into_value_info_ptr()).collect();
// this takes ownership of the OrtValueInfos so no need to free those
ortsys![@editor: unsafe SetGraphOutputs(self.0, outputs.as_ptr() as *mut _, outputs.len())?];
Ok(())

View File

@@ -294,12 +294,44 @@ impl fmt::Display for ValueType {
#[derive(Debug)]
pub struct Outlet {
name: String,
dtype: ValueType
dtype: ValueType,
// Outlet is used for many things, but a ValueInfo can only be created if the Model Editor API is available, which it sometimes may not be.
value_info: Option<NonNull<ort_sys::OrtValueInfo>>,
drop: bool
}
impl Outlet {
pub fn new<S: Into<String>>(name: S, dtype: ValueType) -> Self {
Self { name: name.into(), dtype }
let name = name.into();
let value_info = Self::make_value_info(&name, &dtype);
Self {
name,
dtype,
value_info,
drop: value_info.is_some()
}
}
#[cfg(feature = "api-22")]
pub(crate) unsafe fn from_raw(raw: NonNull<ort_sys::OrtValueInfo>, drop: bool) -> Result<Self> {
let mut name = ptr::null();
ortsys![unsafe GetValueInfoName(raw.as_ptr(), &mut name)?];
let name = if !name.is_null() {
unsafe { CStr::from_ptr(name) }.to_str().map_or_else(|_| String::new(), str::to_string)
} else {
String::new()
};
let mut type_info = ptr::null();
ortsys![unsafe GetValueInfoTypeInfo(raw.as_ptr(), &mut type_info)?; nonNull(type_info)];
let dtype = unsafe { ValueType::from_type_info(type_info) };
Ok(Self {
name,
dtype,
value_info: Some(raw),
drop
})
}
#[inline]
@@ -313,16 +345,37 @@ impl Outlet {
}
#[cfg(feature = "api-22")]
pub(crate) fn into_editor_value_info(self) -> Result<NonNull<ort_sys::OrtValueInfo>> {
let type_info = self.dtype.to_type_info()?;
pub(crate) fn make_value_info(name: &str, dtype: &ValueType) -> Option<NonNull<ort_sys::OrtValueInfo>> {
let type_info = dtype.to_type_info().ok()?;
let _guard = run_on_drop(|| ortsys![unsafe ReleaseTypeInfo(type_info)]);
let ptr = with_cstr(self.name.as_bytes(), &|name| {
with_cstr(name.as_bytes(), &|name| {
let mut ptr: *mut ort_sys::OrtValueInfo = ptr::null_mut();
ortsys![@editor: unsafe CreateValueInfo(name.as_ptr(), type_info, &mut ptr)?; nonNull(ptr)];
Ok(ptr)
})?;
Ok(ptr)
})
.ok()
}
#[cfg(not(feature = "api-22"))]
pub(crate) fn make_value_info(_name: &str, _dtype: &ValueType) -> Option<NonNull<ort_sys::OrtValueInfo>> {
None
}
#[inline]
pub(crate) fn into_value_info_ptr(mut self) -> Option<NonNull<ort_sys::OrtValueInfo>> {
let value_info = self.value_info;
self.drop = false;
value_info
}
}
impl Drop for Outlet {
fn drop(&mut self) {
#[cfg(feature = "api-22")]
if self.drop {
ortsys![unsafe ReleaseValueInfo(self.value_info.expect("OrtValueInfo should not be null").as_ptr())];
}
}
}