mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor: use "proper" MemTypes
This commit is contained in:
@@ -424,7 +424,7 @@ impl SessionBuilder {
|
||||
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
|
||||
ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr) -> OrtError::GetAllocator; nonNull(allocator_ptr)];
|
||||
|
||||
let memory_info = MemoryInfo::new(AllocatorType::Device, MemType::Default)?;
|
||||
let memory_info = MemoryInfo::new(AllocatorType::Device, MemType::CPUOutput)?;
|
||||
|
||||
// Extract input and output properties
|
||||
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
|
||||
@@ -557,9 +557,10 @@ impl Session {
|
||||
let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()];
|
||||
|
||||
// The C API expects pointers for the arrays (pointers to C-arrays)
|
||||
let input_memory_info = MemoryInfo::new(AllocatorType::Device, MemType::Default)?;
|
||||
let input_ort_tensors: Vec<InputOrtTensor> = input_arrays
|
||||
.iter()
|
||||
.map(|input_tensor| InputOrtTensor::from_input_tensor(&self.memory_info, self.allocator_ptr, input_tensor))
|
||||
.map(|input_tensor| InputOrtTensor::from_input_tensor(&input_memory_info, self.allocator_ptr, input_tensor))
|
||||
.collect::<OrtResult<Vec<InputOrtTensor>>>()?;
|
||||
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors.iter().map(|input_array_ort| input_array_ort.c_ptr()).collect();
|
||||
|
||||
@@ -578,7 +579,7 @@ impl Session {
|
||||
) -> OrtError::SessionRun
|
||||
];
|
||||
|
||||
let memory_info_ref = &self.memory_info;
|
||||
let output_memory_info = &self.memory_info;
|
||||
let outputs: OrtResult<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> = output_tensor_ptrs
|
||||
.into_iter()
|
||||
.map(|tensor_ptr| {
|
||||
@@ -595,7 +596,7 @@ impl Session {
|
||||
})
|
||||
}?;
|
||||
|
||||
Ok(DynOrtTensor::new(tensor_ptr, memory_info_ref, ndarray::IxDyn(&dims), len as _, data_type))
|
||||
Ok(DynOrtTensor::new(tensor_ptr, output_memory_info, ndarray::IxDyn(&dims), len as _, data_type))
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -56,8 +56,8 @@ pub enum TensorElementDataType {
|
||||
String = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
|
||||
/// Boolean, equivalent to Rust's `bool`.
|
||||
Bool = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
|
||||
#[cfg(feature = "half")]
|
||||
/// 16-bit floating point number, equivalent to `half::f16` (requires the `half` crate).
|
||||
#[cfg(feature = "half")]
|
||||
Float16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
|
||||
/// 64-bit floating point number, equivalent to Rust's `f64`. Also known as `double`.
|
||||
Float64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
|
||||
|
||||
Reference in New Issue
Block a user