mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
fix: implement to_type_info for maps, closes #553
This commit is contained in:
@@ -182,8 +182,15 @@ impl ValueType {
|
||||
let _guard = util::run_on_drop(|| ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_type_info)]);
|
||||
ortsys![@editor: unsafe CreateTensorTypeInfo(tensor_type_info, &mut info_ptr)?];
|
||||
}
|
||||
Self::Map { .. } => {
|
||||
todo!();
|
||||
Self::Map { key, value } => {
|
||||
let value_type_info = ValueType::Tensor {
|
||||
ty: *value,
|
||||
shape: Shape::new([-1]),
|
||||
dimension_symbols: SymbolicDimensions::empty(1)
|
||||
}
|
||||
.to_type_info()?;
|
||||
let _guard = util::run_on_drop(|| ortsys![unsafe ReleaseTypeInfo(value_type_info)]);
|
||||
ortsys![@editor: unsafe CreateMapTypeInfo((*key).into(), value_type_info, &mut info_ptr)?];
|
||||
}
|
||||
Self::Sequence(ty) => {
|
||||
let el_type = ty.to_type_info()?;
|
||||
@@ -439,7 +446,7 @@ mod tests {
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_to_from_tensor_info() -> crate::Result<()> {
|
||||
fn test_tensor_to_from_tensor_info() -> crate::Result<()> {
|
||||
let ty = ValueType::Tensor {
|
||||
ty: TensorElementType::Float32,
|
||||
shape: Shape::new([-1, 32, 4, 32]),
|
||||
@@ -452,4 +459,18 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "api-22")]
|
||||
fn test_map_to_from_type_info() -> crate::Result<()> {
|
||||
let ty = ValueType::Map {
|
||||
key: TensorElementType::Float32,
|
||||
value: TensorElementType::String
|
||||
};
|
||||
let ty_ptr = NonNull::new(ty.to_type_info()?).expect("");
|
||||
let ty_d = unsafe { ValueType::from_type_info(ty_ptr) };
|
||||
assert_eq!(ty, ty_d);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user