// Mirror of https://github.com/pykeio/ort, synced 2026-04-25 16:34:55 +02:00.
// Original file: 136 lines, 3.9 KiB, Rust.
use std::os::raw::c_void;
|
|
|
|
use super::ExecutionProvider;
|
|
use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder};
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ROCmExecutionProvider {
|
|
device_id: i32,
|
|
miopen_conv_exhaustive_search: bool,
|
|
gpu_mem_limit: ort_sys::size_t,
|
|
arena_extend_strategy: ArenaExtendStrategy,
|
|
do_copy_in_default_stream: bool,
|
|
user_compute_stream: Option<*mut c_void>,
|
|
default_memory_arena_cfg: Option<*mut ort_sys::OrtArenaCfg>,
|
|
tunable_op_enable: bool,
|
|
tunable_op_tuning_enable: bool,
|
|
tunable_op_max_tuning_duration_ms: i32
|
|
}
|
|
|
|
unsafe impl Send for ROCmExecutionProvider {}
|
|
unsafe impl Sync for ROCmExecutionProvider {}
|
|
|
|
impl Default for ROCmExecutionProvider {
|
|
fn default() -> Self {
|
|
Self {
|
|
device_id: 0,
|
|
miopen_conv_exhaustive_search: false,
|
|
gpu_mem_limit: ort_sys::size_t::MAX,
|
|
arena_extend_strategy: ArenaExtendStrategy::NextPowerOfTwo,
|
|
do_copy_in_default_stream: true,
|
|
user_compute_stream: None,
|
|
default_memory_arena_cfg: None,
|
|
tunable_op_enable: false,
|
|
tunable_op_tuning_enable: false,
|
|
tunable_op_max_tuning_duration_ms: 0
|
|
}
|
|
}
|
|
}
|
|
|
|
impl ROCmExecutionProvider {
|
|
pub fn with_device_id(mut self, device_id: i32) -> Self {
|
|
self.device_id = device_id;
|
|
self
|
|
}
|
|
|
|
pub fn with_exhaustive_conv_search(mut self) -> Self {
|
|
self.miopen_conv_exhaustive_search = true;
|
|
self
|
|
}
|
|
|
|
pub fn with_mem_limit(mut self, limit: usize) -> Self {
|
|
self.gpu_mem_limit = limit as _;
|
|
self
|
|
}
|
|
|
|
pub fn with_arena_extend_strategy(mut self, strategy: ArenaExtendStrategy) -> Self {
|
|
self.arena_extend_strategy = strategy;
|
|
self
|
|
}
|
|
|
|
pub fn with_copy_in_default_stream(mut self, enable: bool) -> Self {
|
|
self.do_copy_in_default_stream = enable;
|
|
self
|
|
}
|
|
|
|
pub fn with_compute_stream(mut self, ptr: *mut c_void) -> Self {
|
|
self.user_compute_stream = Some(ptr);
|
|
self
|
|
}
|
|
|
|
pub fn with_default_memory_arena_cfg(mut self, cfg: *mut ort_sys::OrtArenaCfg) -> Self {
|
|
self.default_memory_arena_cfg = Some(cfg);
|
|
self
|
|
}
|
|
|
|
pub fn with_tunable_op(mut self, enable: bool) -> Self {
|
|
self.tunable_op_enable = enable;
|
|
self
|
|
}
|
|
|
|
pub fn with_tuning(mut self, enable: bool) -> Self {
|
|
self.tunable_op_tuning_enable = enable;
|
|
self
|
|
}
|
|
|
|
pub fn with_max_tuning_duration(mut self, ms: i32) -> Self {
|
|
self.tunable_op_max_tuning_duration_ms = ms;
|
|
self
|
|
}
|
|
|
|
pub fn build(self) -> ExecutionProviderDispatch {
|
|
self.into()
|
|
}
|
|
}
|
|
|
|
impl From<ROCmExecutionProvider> for ExecutionProviderDispatch {
	/// Places a ROCm configuration into the `ROCm` dispatch variant.
	fn from(provider: ROCmExecutionProvider) -> Self {
		Self::ROCm(provider)
	}
}
|
|
|
|
impl ExecutionProvider for ROCmExecutionProvider {
|
|
fn as_str(&self) -> &'static str {
|
|
"ROCmExecutionProvider"
|
|
}
|
|
|
|
#[allow(unused, unreachable_code)]
|
|
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
|
|
#[cfg(any(feature = "load-dynamic", feature = "rocm"))]
|
|
{
|
|
let rocm_options = ort_sys::OrtROCMProviderOptions {
|
|
device_id: self.device_id,
|
|
miopen_conv_exhaustive_search: self.miopen_conv_exhaustive_search.into(),
|
|
gpu_mem_limit: self.gpu_mem_limit as _,
|
|
arena_extend_strategy: match self.arena_extend_strategy {
|
|
ArenaExtendStrategy::NextPowerOfTwo => 0,
|
|
ArenaExtendStrategy::SameAsRequested => 1
|
|
},
|
|
do_copy_in_default_stream: self.do_copy_in_default_stream.into(),
|
|
has_user_compute_stream: self.user_compute_stream.is_some().into(),
|
|
user_compute_stream: self.user_compute_stream.unwrap_or_else(std::ptr::null_mut),
|
|
default_memory_arena_cfg: self.default_memory_arena_cfg.unwrap_or_else(std::ptr::null_mut),
|
|
tunable_op_enable: self.tunable_op_enable.into(),
|
|
tunable_op_tuning_enable: self.tunable_op_tuning_enable.into(),
|
|
tunable_op_max_tuning_duration_ms: self.tunable_op_max_tuning_duration_ms
|
|
};
|
|
return crate::error::status_to_result(
|
|
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_ROCM(session_builder.session_options_ptr, &rocm_options as *const _)]
|
|
)
|
|
.map_err(Error::ExecutionProvider);
|
|
}
|
|
|
|
Err(Error::ExecutionProviderNotRegistered(self.as_str()))
|
|
}
|
|
}
|