refactor: use "proper" MemTypes

This commit is contained in:
Carson M
2023-02-12 23:03:52 -06:00
parent 2364c5d510
commit 3b74b6293a
2 changed files with 6 additions and 5 deletions

View File

@@ -424,7 +424,7 @@ impl SessionBuilder {
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr) -> OrtError::GetAllocator; nonNull(allocator_ptr)];
let memory_info = MemoryInfo::new(AllocatorType::Device, MemType::Default)?;
let memory_info = MemoryInfo::new(AllocatorType::Device, MemType::CPUOutput)?;
// Extract input and output properties
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
@@ -557,9 +557,10 @@ impl Session {
let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()];
// The C API expects pointers for the arrays (pointers to C-arrays)
let input_memory_info = MemoryInfo::new(AllocatorType::Device, MemType::Default)?;
let input_ort_tensors: Vec<InputOrtTensor> = input_arrays
.iter()
.map(|input_tensor| InputOrtTensor::from_input_tensor(&self.memory_info, self.allocator_ptr, input_tensor))
.map(|input_tensor| InputOrtTensor::from_input_tensor(&input_memory_info, self.allocator_ptr, input_tensor))
.collect::<OrtResult<Vec<InputOrtTensor>>>()?;
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors.iter().map(|input_array_ort| input_array_ort.c_ptr()).collect();
@@ -578,7 +579,7 @@ impl Session {
) -> OrtError::SessionRun
];
let memory_info_ref = &self.memory_info;
let output_memory_info = &self.memory_info;
let outputs: OrtResult<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> = output_tensor_ptrs
.into_iter()
.map(|tensor_ptr| {
@@ -595,7 +596,7 @@ impl Session {
})
}?;
Ok(DynOrtTensor::new(tensor_ptr, memory_info_ref, ndarray::IxDyn(&dims), len as _, data_type))
Ok(DynOrtTensor::new(tensor_ptr, output_memory_info, ndarray::IxDyn(&dims), len as _, data_type))
})
.collect();

View File

@@ -56,8 +56,8 @@ pub enum TensorElementDataType {
String = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
/// Boolean, equivalent to Rust's `bool`.
Bool = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
#[cfg(feature = "half")]
/// 16-bit floating point number, equivalent to `half::f16` (requires the `half` crate).
#[cfg(feature = "half")]
Float16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
/// 64-bit floating point number, equivalent to Rust's `f64`. Also known as `double`.
Float64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,