mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: sequence iterators
This commit is contained in:
@@ -11,7 +11,7 @@ use crate::{
|
||||
error::{Error, Result},
|
||||
memory::Allocator,
|
||||
ortsys,
|
||||
value::DynValueTypeMarker
|
||||
value::DynValue
|
||||
};
|
||||
|
||||
pub trait SequenceValueTypeMarker: ValueTypeMarker {
|
||||
@@ -74,7 +74,7 @@ pub type SequenceRef<'v, T> = ValueRef<'v, SequenceValueType<T>>;
|
||||
pub type SequenceRefMut<'v, T> = ValueRefMut<'v, SequenceValueType<T>>;
|
||||
|
||||
impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
|
||||
pub fn try_extract_sequence<'s, OtherType: ValueTypeMarker + DowncastableTarget + Debug + Sized>(&'s self) -> Result<Vec<ValueRef<'s, OtherType>>> {
|
||||
pub fn try_extract_sequence<OtherType: ValueTypeMarker + DowncastableTarget + Debug + Sized>(&self) -> Result<Vec<Value<OtherType>>> {
|
||||
match self.dtype() {
|
||||
ValueType::Sequence(_) => {
|
||||
let allocator = Allocator::default();
|
||||
@@ -149,7 +149,7 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized + 'static> Value<Se
|
||||
}
|
||||
|
||||
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValueType<T>> {
|
||||
pub fn extract_sequence<'s>(&'s self) -> Vec<ValueRef<'s, T>> {
|
||||
pub fn extract_sequence(&self) -> Vec<Value<T>> {
|
||||
self.try_extract_sequence().expect("Failed to extract sequence")
|
||||
}
|
||||
|
||||
@@ -162,17 +162,19 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
|
||||
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
let mut len = 0;
|
||||
ortsys![unsafe GetValueCount(self.ptr(), &mut len).expect("infallible")];
|
||||
len == 0
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
pub fn get(&self, index: usize) -> Option<ValueRef<'_, T>> {
|
||||
pub fn get(&self, index: usize) -> Option<Value<T>> {
|
||||
extract_from_sequence(self.ptr(), index, &Allocator::default())
|
||||
.ok()
|
||||
.and_then(|x| x.downcast().ok())
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl ExactSizeIterator<Item = Value<T>> {
|
||||
(0..self.len()).map(|i| self.get(i).expect("infallible"))
|
||||
}
|
||||
|
||||
/// Converts from a strongly-typed [`Sequence<T>`] to a type-erased [`DynSequence`].
|
||||
#[inline]
|
||||
pub fn upcast(self) -> DynSequence {
|
||||
@@ -198,11 +200,38 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_from_sequence<'s>(ptr: *const ort_sys::OrtValue, i: usize, allocator: &Allocator) -> Result<ValueRef<'s, DynValueTypeMarker>> {
|
||||
fn extract_from_sequence(ptr: *const ort_sys::OrtValue, i: usize, allocator: &Allocator) -> Result<DynValue> {
|
||||
let mut value_ptr = ptr::null_mut();
|
||||
ortsys![unsafe GetValue(ptr, i as _, allocator.ptr().cast_mut(), &mut value_ptr)?; nonNull(value_ptr)];
|
||||
|
||||
let mut value = ValueRef::new(unsafe { Value::from_ptr(value_ptr, None) });
|
||||
value.upgradable = false;
|
||||
Ok(value)
|
||||
Ok(unsafe { Value::from_ptr(value_ptr, None) })
|
||||
}
|
||||
|
||||
pub struct IntoIter<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> {
|
||||
value: Value<SequenceValueType<T>>,
|
||||
i: usize
|
||||
}
|
||||
|
||||
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Iterator for IntoIter<T> {
|
||||
type Item = Value<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let val = self.value.get(self.i);
|
||||
self.i += 1;
|
||||
val
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> ExactSizeIterator for IntoIter<T> {
|
||||
fn len(&self) -> usize {
|
||||
self.value.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> IntoIterator for Value<SequenceValueType<T>> {
|
||||
type Item = Value<T>;
|
||||
type IntoIter = IntoIter<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
IntoIter { value: self, i: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -337,7 +337,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
|
||||
&self.inner.dtype
|
||||
}
|
||||
|
||||
/// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer.
|
||||
/// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer. Takes ownership of `ptr`.
|
||||
///
|
||||
/// If the value belongs to a session (i.e. if it is the result of an inference run), you must provide the
|
||||
/// [`SharedSessionInner`] (acquired from [`Session::inner`](crate::session::Session::inner)). This ensures the
|
||||
|
||||
Reference in New Issue
Block a user