feat: sequence iterators

This commit is contained in:
Carson M.
2026-03-11 14:49:56 -05:00
parent 82e6d652e1
commit 4df4b618b9
2 changed files with 42 additions and 13 deletions

View File

@@ -11,7 +11,7 @@ use crate::{
error::{Error, Result},
memory::Allocator,
ortsys,
value::DynValueTypeMarker
value::DynValue
};
pub trait SequenceValueTypeMarker: ValueTypeMarker {
@@ -74,7 +74,7 @@ pub type SequenceRef<'v, T> = ValueRef<'v, SequenceValueType<T>>;
pub type SequenceRefMut<'v, T> = ValueRefMut<'v, SequenceValueType<T>>;
impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
pub fn try_extract_sequence<'s, OtherType: ValueTypeMarker + DowncastableTarget + Debug + Sized>(&'s self) -> Result<Vec<ValueRef<'s, OtherType>>> {
pub fn try_extract_sequence<OtherType: ValueTypeMarker + DowncastableTarget + Debug + Sized>(&self) -> Result<Vec<Value<OtherType>>> {
match self.dtype() {
ValueType::Sequence(_) => {
let allocator = Allocator::default();
@@ -149,7 +149,7 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized + 'static> Value<Se
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValueType<T>> {
pub fn extract_sequence<'s>(&'s self) -> Vec<ValueRef<'s, T>> {
pub fn extract_sequence(&self) -> Vec<Value<T>> {
self.try_extract_sequence().expect("Failed to extract sequence")
}
@@ -162,17 +162,19 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
#[inline]
pub fn is_empty(&self) -> bool {
let mut len = 0;
ortsys![unsafe GetValueCount(self.ptr(), &mut len).expect("infallible")];
len == 0
self.len() == 0
}
pub fn get(&self, index: usize) -> Option<ValueRef<'_, T>> {
pub fn get(&self, index: usize) -> Option<Value<T>> {
extract_from_sequence(self.ptr(), index, &Allocator::default())
.ok()
.and_then(|x| x.downcast().ok())
}
pub fn iter(&self) -> impl ExactSizeIterator<Item = Value<T>> {
(0..self.len()).map(|i| self.get(i).expect("infallible"))
}
/// Converts from a strongly-typed [`Sequence<T>`] to a type-erased [`DynSequence`].
#[inline]
pub fn upcast(self) -> DynSequence {
@@ -198,11 +200,38 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
}
}
fn extract_from_sequence<'s>(ptr: *const ort_sys::OrtValue, i: usize, allocator: &Allocator) -> Result<ValueRef<'s, DynValueTypeMarker>> {
fn extract_from_sequence(ptr: *const ort_sys::OrtValue, i: usize, allocator: &Allocator) -> Result<DynValue> {
let mut value_ptr = ptr::null_mut();
ortsys![unsafe GetValue(ptr, i as _, allocator.ptr().cast_mut(), &mut value_ptr)?; nonNull(value_ptr)];
let mut value = ValueRef::new(unsafe { Value::from_ptr(value_ptr, None) });
value.upgradable = false;
Ok(value)
Ok(unsafe { Value::from_ptr(value_ptr, None) })
}
pub struct IntoIter<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> {
value: Value<SequenceValueType<T>>,
i: usize
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Iterator for IntoIter<T> {
type Item = Value<T>;
fn next(&mut self) -> Option<Self::Item> {
let val = self.value.get(self.i);
self.i += 1;
val
}
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> ExactSizeIterator for IntoIter<T> {
fn len(&self) -> usize {
self.value.len()
}
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> IntoIterator for Value<SequenceValueType<T>> {
type Item = Value<T>;
type IntoIter = IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
IntoIter { value: self, i: 0 }
}
}

View File

@@ -337,7 +337,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
&self.inner.dtype
}
/// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer.
/// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer. Takes ownership of `ptr`.
///
/// If the value belongs to a session (i.e. if it is the result of an inference run), you must provide the
/// [`SharedSessionInner`] (acquired from [`Session::inner`](crate::session::Session::inner)). This ensures the