diff --git a/src/value/impl_sequence.rs b/src/value/impl_sequence.rs index fd3c9c3..c12c6b1 100644 --- a/src/value/impl_sequence.rs +++ b/src/value/impl_sequence.rs @@ -11,7 +11,7 @@ use crate::{ error::{Error, Result}, memory::Allocator, ortsys, - value::DynValueTypeMarker + value::DynValue }; pub trait SequenceValueTypeMarker: ValueTypeMarker { @@ -74,7 +74,7 @@ pub type SequenceRef<'v, T> = ValueRef<'v, SequenceValueType>; pub type SequenceRefMut<'v, T> = ValueRefMut<'v, SequenceValueType>; impl Value { - pub fn try_extract_sequence<'s, OtherType: ValueTypeMarker + DowncastableTarget + Debug + Sized>(&'s self) -> Result>> { + pub fn try_extract_sequence(&self) -> Result>> { match self.dtype() { ValueType::Sequence(_) => { let allocator = Allocator::default(); @@ -149,7 +149,7 @@ impl Value Value> { - pub fn extract_sequence<'s>(&'s self) -> Vec> { + pub fn extract_sequence(&self) -> Vec> { self.try_extract_sequence().expect("Failed to extract sequence") } @@ -162,17 +162,19 @@ impl Value bool { - let mut len = 0; - ortsys![unsafe GetValueCount(self.ptr(), &mut len).expect("infallible")]; - len == 0 + self.len() == 0 } - pub fn get(&self, index: usize) -> Option> { + pub fn get(&self, index: usize) -> Option> { extract_from_sequence(self.ptr(), index, &Allocator::default()) .ok() .and_then(|x| x.downcast().ok()) } + pub fn iter(&self) -> impl ExactSizeIterator> { + (0..self.len()).map(|i| self.get(i).expect("infallible")) + } + /// Converts from a strongly-typed [`Sequence`] to a type-erased [`DynSequence`]. #[inline] pub fn upcast(self) -> DynSequence { @@ -198,11 +200,38 @@ impl Value(ptr: *const ort_sys::OrtValue, i: usize, allocator: &Allocator) -> Result> { +fn extract_from_sequence(ptr: *const ort_sys::OrtValue, i: usize, allocator: &Allocator) -> Result { let mut value_ptr = ptr::null_mut(); ortsys![unsafe GetValue(ptr, i as _, allocator.ptr().cast_mut(), &mut value_ptr)?; nonNull(value_ptr)]; - - let mut value = ValueRef::new(unsafe { Value::from_ptr(value_ptr, None) }); - value.upgradable = false; - Ok(value) + Ok(unsafe { Value::from_ptr(value_ptr, None) }) +} + +pub struct IntoIter { + value: Value>, + i: usize +} + +impl Iterator for IntoIter { + type Item = Value; + + fn next(&mut self) -> Option { + let val = self.value.get(self.i); + self.i += 1; + val + } +} + +impl ExactSizeIterator for IntoIter { + fn len(&self) -> usize { + self.value.len() + } +} + +impl IntoIterator for Value> { + type Item = Value; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter { value: self, i: 0 } + } } diff --git a/src/value/mod.rs b/src/value/mod.rs index e23415a..807bac9 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -337,7 +337,7 @@ impl Value { &self.inner.dtype } - /// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer. + /// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer. Takes ownership of `ptr`. /// /// If the value belongs to a session (i.e. if it is the result of an inference run), you must provide the /// [`SharedSessionInner`] (acquired from [`Session::inner`](crate::session::Session::inner)). This ensures the