mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor: make Outlet wrap OrtValueInfo
This commit is contained in:
@@ -160,14 +160,14 @@ impl Graph {
|
||||
}
|
||||
|
||||
pub fn set_inputs(&mut self, inputs: impl IntoIterator<Item = Outlet>) -> Result<()> {
|
||||
let inputs: SmallVec<[NonNull<ort_sys::OrtValueInfo>; 4]> = inputs.into_iter().map(|input| input.into_editor_value_info()).collect::<Result<_>>()?;
|
||||
let inputs: SmallVec<[NonNull<ort_sys::OrtValueInfo>; 4]> = inputs.into_iter().filter_map(|input| input.into_value_info_ptr()).collect();
|
||||
// this takes ownership of the OrtValueInfos so no need to free those
|
||||
ortsys![@editor: unsafe SetGraphInputs(self.0, inputs.as_ptr() as *mut _, inputs.len())?];
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_outputs(&mut self, outputs: impl IntoIterator<Item = Outlet>) -> Result<()> {
|
||||
let outputs: SmallVec<[NonNull<ort_sys::OrtValueInfo>; 4]> = outputs.into_iter().map(|input| input.into_editor_value_info()).collect::<Result<_>>()?;
|
||||
let outputs: SmallVec<[NonNull<ort_sys::OrtValueInfo>; 4]> = outputs.into_iter().filter_map(|input| input.into_value_info_ptr()).collect();
|
||||
// this takes ownership of the OrtValueInfos so no need to free those
|
||||
ortsys![@editor: unsafe SetGraphOutputs(self.0, outputs.as_ptr() as *mut _, outputs.len())?];
|
||||
Ok(())
|
||||
|
||||
@@ -294,12 +294,44 @@ impl fmt::Display for ValueType {
|
||||
#[derive(Debug)]
|
||||
pub struct Outlet {
|
||||
name: String,
|
||||
dtype: ValueType
|
||||
dtype: ValueType,
|
||||
// Outlet is used for many things, but a ValueInfo can only be created if the Model Editor API is available, which it sometimes may not be.
|
||||
value_info: Option<NonNull<ort_sys::OrtValueInfo>>,
|
||||
drop: bool
|
||||
}
|
||||
|
||||
impl Outlet {
|
||||
pub fn new<S: Into<String>>(name: S, dtype: ValueType) -> Self {
|
||||
Self { name: name.into(), dtype }
|
||||
let name = name.into();
|
||||
let value_info = Self::make_value_info(&name, &dtype);
|
||||
Self {
|
||||
name,
|
||||
dtype,
|
||||
value_info,
|
||||
drop: value_info.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "api-22")]
|
||||
pub(crate) unsafe fn from_raw(raw: NonNull<ort_sys::OrtValueInfo>, drop: bool) -> Result<Self> {
|
||||
let mut name = ptr::null();
|
||||
ortsys![unsafe GetValueInfoName(raw.as_ptr(), &mut name)?];
|
||||
let name = if !name.is_null() {
|
||||
unsafe { CStr::from_ptr(name) }.to_str().map_or_else(|_| String::new(), str::to_string)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let mut type_info = ptr::null();
|
||||
ortsys![unsafe GetValueInfoTypeInfo(raw.as_ptr(), &mut type_info)?; nonNull(type_info)];
|
||||
let dtype = unsafe { ValueType::from_type_info(type_info) };
|
||||
|
||||
Ok(Self {
|
||||
name,
|
||||
dtype,
|
||||
value_info: Some(raw),
|
||||
drop
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -313,16 +345,37 @@ impl Outlet {
|
||||
}
|
||||
|
||||
#[cfg(feature = "api-22")]
|
||||
pub(crate) fn into_editor_value_info(self) -> Result<NonNull<ort_sys::OrtValueInfo>> {
|
||||
let type_info = self.dtype.to_type_info()?;
|
||||
pub(crate) fn make_value_info(name: &str, dtype: &ValueType) -> Option<NonNull<ort_sys::OrtValueInfo>> {
|
||||
let type_info = dtype.to_type_info().ok()?;
|
||||
let _guard = run_on_drop(|| ortsys![unsafe ReleaseTypeInfo(type_info)]);
|
||||
|
||||
let ptr = with_cstr(self.name.as_bytes(), &|name| {
|
||||
with_cstr(name.as_bytes(), &|name| {
|
||||
let mut ptr: *mut ort_sys::OrtValueInfo = ptr::null_mut();
|
||||
ortsys![@editor: unsafe CreateValueInfo(name.as_ptr(), type_info, &mut ptr)?; nonNull(ptr)];
|
||||
Ok(ptr)
|
||||
})?;
|
||||
Ok(ptr)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "api-22"))]
|
||||
pub(crate) fn make_value_info(_name: &str, _dtype: &ValueType) -> Option<NonNull<ort_sys::OrtValueInfo>> {
|
||||
None
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn into_value_info_ptr(mut self) -> Option<NonNull<ort_sys::OrtValueInfo>> {
|
||||
let value_info = self.value_info;
|
||||
self.drop = false;
|
||||
value_info
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Outlet {
|
||||
fn drop(&mut self) {
|
||||
#[cfg(feature = "api-22")]
|
||||
if self.drop {
|
||||
ortsys![unsafe ReleaseValueInfo(self.value_info.expect("OrtValueInfo should not be null").as_ptr())];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user