feat: DirectML

This commit is contained in:
Carson M
2023-01-17 19:34:44 -06:00
parent f7a617fa2b
commit 1d4fcd9528

View File

@@ -12,6 +12,8 @@ extern "C" {
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_Dnnl(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr;
#[cfg(feature = "coreml")]
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_CoreML(options: *mut sys::OrtSessionOptions, flags: u32) -> sys::OrtStatusPtr;
#[cfg(feature = "directml")]
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_DML(options: *mut sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> sys::OrtStatusPtr;
}
#[derive(Debug, Clone)]
@@ -63,7 +65,8 @@ impl ExecutionProvider {
acl = "AclExecutionProvider",
dnnl = "DnnlExecutionProvider",
onednn = "DnnlExecutionProvider",
coreml = "CoreMLExecutionProvider"
coreml = "CoreMLExecutionProvider",
directml = "DmlExecutionProvider"
}
pub fn is_available(&self) -> bool {
@@ -98,6 +101,10 @@ impl ExecutionProvider {
///
/// Supported backends: CPU, ACL, oneDNN
pub fn with_use_arena(bool) = use_arena;
/// The device ID to initialize the execution provider on.
///
/// Supported backends: DirectML
pub fn with_device_id(i32) = device_id;
}
}
@@ -187,6 +194,15 @@ pub(crate) fn apply_execution_providers(options: *mut sys::OrtSessionOptions, ex
return; // EP found
}
}
#[cfg(feature = "directml")]
"DmlExecutionProvider" => {
let device_id = init_args.get("device_id").map_or(0, |s| s.parse::<i32>().unwrap_or(0));
// TODO: extended options with OrtSessionOptionsAppendExecutionProviderEx_DML
let status = unsafe { OrtSessionOptionsAppendExecutionProvider_DML(options, device_id.into()) };
if status_to_result_and_log("DirectML", status).is_ok() {
return; // EP found
}
}
_ => {}
};
}