diff --git a/src/value/type.rs b/src/value/type.rs index e4d36da..7411439 100644 --- a/src/value/type.rs +++ b/src/value/type.rs @@ -182,8 +182,15 @@ impl ValueType { let _guard = util::run_on_drop(|| ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_type_info)]); ortsys![@editor: unsafe CreateTensorTypeInfo(tensor_type_info, &mut info_ptr)?]; } - Self::Map { .. } => { - todo!(); + Self::Map { key, value } => { + let value_type_info = ValueType::Tensor { + ty: *value, + shape: Shape::new([-1]), + dimension_symbols: SymbolicDimensions::empty(1) + } + .to_type_info()?; + let _guard = util::run_on_drop(|| ortsys![unsafe ReleaseTypeInfo(value_type_info)]); + ortsys![@editor: unsafe CreateMapTypeInfo((*key).into(), value_type_info, &mut info_ptr)?]; } Self::Sequence(ty) => { let el_type = ty.to_type_info()?; @@ -439,7 +446,7 @@ mod tests { }; #[test] - fn test_to_from_tensor_info() -> crate::Result<()> { + fn test_tensor_to_from_tensor_info() -> crate::Result<()> { let ty = ValueType::Tensor { ty: TensorElementType::Float32, shape: Shape::new([-1, 32, 4, 32]), @@ -452,4 +459,18 @@ mod tests { Ok(()) } + + #[test] + #[cfg(feature = "api-22")] + fn test_map_to_from_type_info() -> crate::Result<()> { + let ty = ValueType::Map { + key: TensorElementType::Float32, + value: TensorElementType::String + }; + let ty_ptr = NonNull::new(ty.to_type_info()?).expect(""); + let ty_d = unsafe { ValueType::from_type_info(ty_ptr) }; + assert_eq!(ty, ty_d); + + Ok(()) + } }