mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: DirectML
This commit is contained in:
@@ -12,6 +12,8 @@ extern "C" {
|
||||
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_Dnnl(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr;
|
||||
#[cfg(feature = "coreml")]
|
||||
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_CoreML(options: *mut sys::OrtSessionOptions, flags: u32) -> sys::OrtStatusPtr;
|
||||
#[cfg(feature = "directml")]
|
||||
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_DML(options: *mut sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> sys::OrtStatusPtr;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -63,7 +65,8 @@ impl ExecutionProvider {
|
||||
acl = "AclExecutionProvider",
|
||||
dnnl = "DnnlExecutionProvider",
|
||||
onednn = "DnnlExecutionProvider",
|
||||
coreml = "CoreMLExecutionProvider"
|
||||
coreml = "CoreMLExecutionProvider",
|
||||
directml = "DmlExecutionProvider"
|
||||
}
|
||||
|
||||
pub fn is_available(&self) -> bool {
|
||||
@@ -98,6 +101,10 @@ impl ExecutionProvider {
|
||||
///
|
||||
/// Supported backends: CPU, ACL, oneDNN
|
||||
pub fn with_use_arena(bool) = use_arena;
|
||||
/// The device ID to initialize the execution provider on.
|
||||
///
|
||||
/// Supported backends: DirectML
|
||||
pub fn with_device_id(i32) = device_id;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,6 +194,15 @@ pub(crate) fn apply_execution_providers(options: *mut sys::OrtSessionOptions, ex
|
||||
return; // EP found
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "directml")]
|
||||
"DmlExecutionProvider" => {
|
||||
let device_id = init_args.get("device_id").map_or(0, |s| s.parse::<i32>().unwrap_or(0));
|
||||
// TODO: extended options with OrtSessionOptionsAppendExecutionProviderEx_DML
|
||||
let status = unsafe { OrtSessionOptionsAppendExecutionProvider_DML(options, device_id.into()) };
|
||||
if status_to_result_and_log("DirectML", status).is_ok() {
|
||||
return; // EP found
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user