Files
ort/src/execution_providers.rs
Carson M 7b3acaf95b 1.13.0
2022-11-26 15:16:30 -06:00

180 lines
5.9 KiB
Rust

#![allow(unused_imports)]
use std::{collections::HashMap, ffi::CString, os::raw::c_char};
use super::{error::status_to_result, ortsys, sys};
// Execution-provider registration functions declared directly against the ONNX Runtime C library,
// rather than being called through the `ortsys!` function table. Each one takes the session
// options to modify and a `use_arena` flag. Presumably a nonzero flag enables ORT's arena
// allocator (TODO: confirm against the ORT provider factory headers). Each returns a status
// pointer, which callers in this file check with `status_to_result`.
extern "C" {
    /// Appends the CPU execution provider to `options`.
    pub(crate) fn OrtSessionOptionsAppendExecutionProvider_CPU(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr;
    /// Appends the ACL execution provider to `options`. Declared only when the `acl` feature is enabled.
    #[cfg(feature = "acl")]
    pub(crate) fn OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr;
}
/// Configuration for one ONNX Runtime execution provider: the provider's name plus string
/// key/value options.
///
/// Create one with [`ExecutionProvider::new`] or a named constructor, add options with
/// [`ExecutionProvider::with`], and register it on a session with `apply_execution_providers`.
#[derive(Debug, Clone)]
pub struct ExecutionProvider {
    /// Provider name, e.g. `"CUDAExecutionProvider"`; selects how the provider is registered.
    provider: String,
    /// Provider options. The CPU/ACL paths read only `use_arena`; the CUDA/TensorRT paths forward
    /// every key/value pair to ONNX Runtime.
    options: HashMap<String, String>
}
/// Generates one argument-less constructor per `fn_name = "ProviderName"` pair.
///
/// Each constructor forwards the provider name to `Self::new`.
macro_rules! ep_providers {
    ($($fn_name:ident = $name:expr),*) => {
        $(
            #[doc = concat!("Creates a new `", $name, "` configuration object.")]
            pub fn $fn_name() -> Self {
                Self::new($name)
            }
        )*
    };
}
/// Generates `fn_name(original): "Label"` constructors.
///
/// Each one builds the provider returned by `Self::original()` and keeps it only when
/// `is_available()` reports it usable; otherwise it returns `Self::cpu()`.
macro_rules! ep_if_available {
    ($($fn_name:ident($original:ident): $name:expr),*) => {
        $(
            #[doc = concat!("Creates a new ", $name, " execution provider if available, otherwise falling back to CPU.")]
            pub fn $fn_name() -> Self {
                match Self::$original() {
                    candidate if candidate.is_available() => candidate,
                    _ => Self::cpu()
                }
            }
        )*
    };
}
/// Generates builder-style option setters.
///
/// `pub fn setter(Type) = key;` expands to a method that stores `v.to_string()` under the
/// literal key name via `self.with(...)`. Attributes (including doc comments) written on the
/// declaration are carried over to the generated method.
macro_rules! ep_options {
    ($(
        $(#[$meta:meta])*
        pub fn $fn_name:ident($opt_type:ty) = $option_name:ident;
    )*) => {
        $(
            $(#[$meta])*
            pub fn $fn_name(self, v: $opt_type) -> Self {
                self.with(stringify!($option_name), v.to_string())
            }
        )*
    };
}
impl ExecutionProvider {
    /// Creates a configuration for the execution provider named `provider`, with no options set.
    ///
    /// The name is what availability checks and session registration match against, so it
    /// should be a provider name such as `"CUDAExecutionProvider"`.
    pub fn new(provider: impl Into<String>) -> Self {
        Self {
            provider: provider.into(),
            options: HashMap::new()
        }
    }

    ep_providers! {
        acl = "AclExecutionProvider",
        cuda = "CUDAExecutionProvider",
        tensorrt = "TensorRTExecutionProvider",
        cpu = "CPUExecutionProvider"
    }

    /// Returns `true` if ONNX Runtime lists this provider among its available providers.
    ///
    /// Returns `false` if the list of available providers cannot be queried.
    pub fn is_available(&self) -> bool {
        // ONNX Runtime reports the ACL and TensorRT providers as "ACLExecutionProvider" and
        // "TensorrtExecutionProvider". Map the spellings used by `acl()`/`tensorrt()` onto the
        // runtime's names so those constructors (and their `*_if_available` variants) can match.
        let wanted = match self.provider.as_str() {
            "AclExecutionProvider" => "ACLExecutionProvider",
            "TensorRTExecutionProvider" => "TensorrtExecutionProvider",
            other => other
        };

        let mut providers: *mut *mut c_char = std::ptr::null_mut();
        let mut num_providers = 0;
        if status_to_result(ortsys![unsafe GetAvailableProviders(&mut providers, &mut num_providers)]).is_err() {
            return false;
        }

        let found = (0..num_providers).any(|i| {
            // SAFETY: after a successful `GetAvailableProviders`, `providers` points to
            // `num_providers` NUL-terminated strings, and `i` lies in `0..num_providers`.
            let avail = unsafe { std::ffi::CStr::from_ptr(*providers.offset(i as isize)) };
            avail.to_string_lossy() == wanted
        });

        // The provider array is allocated by ONNX Runtime and must be handed back to it on every
        // path. A failure here has no useful recovery, so only the lookup result is returned.
        let _ = status_to_result(ortsys![unsafe ReleaseAvailableProviders(providers, num_providers)]);

        found
    }

    ep_if_available! {
        tensorrt_if_available(tensorrt): "TensorRT",
        cuda_if_available(cuda): "CUDA",
        acl_if_available(acl): "ACL"
    }

    /// Configure this execution provider with the given option.
    ///
    /// Setting the same key again replaces the earlier value.
    pub fn with(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
        self.options.insert(k.into(), v.into());
        self
    }

    ep_options! {
        /// Whether or not to use CPU arena allocator.
        pub fn with_use_arena(bool) = use_arena;
    }
}
/// Reads the `use_arena` option as a boolean.
///
/// A missing or unparsable value counts as `false`.
fn option_use_arena(args: &HashMap<String, String>) -> bool {
    args.get("use_arena").and_then(|s| s.parse::<bool>().ok()).unwrap_or(false)
}

/// Converts option key/value pairs into two parallel vectors of C strings.
///
/// The i-th key corresponds to the i-th value. Returns `None` if any key or value contains an
/// interior NUL byte, since such a string cannot be passed to ONNX Runtime.
#[cfg(any(feature = "cuda", feature = "tensorrt"))]
fn options_to_c_strings(args: &HashMap<String, String>) -> Option<(Vec<CString>, Vec<CString>)> {
    let mut keys = Vec::with_capacity(args.len());
    let mut values = Vec::with_capacity(args.len());
    for (k, v) in args {
        keys.push(CString::new(k.as_str()).ok()?);
        values.push(CString::new(v.as_str()).ok()?);
    }
    Some((keys, values))
}

/// Registers the first usable execution provider from `execution_providers` on `options`.
///
/// Providers are tried in order. As soon as one is appended successfully the function returns,
/// so later entries act as fallbacks. A provider is skipped silently if:
/// - its name is not recognised,
/// - its cargo feature is disabled,
/// - its options cannot be converted to C strings, or
/// - ONNX Runtime rejects it.
pub(crate) fn apply_execution_providers(options: *mut sys::OrtSessionOptions, execution_providers: impl AsRef<[ExecutionProvider]>) {
    for ep in execution_providers.as_ref() {
        let init_args = &ep.options;
        match ep.provider.as_str() {
            // ORT's own name for this provider is "ACLExecutionProvider"; accept both spellings.
            #[cfg(feature = "acl")]
            "AclExecutionProvider" | "ACLExecutionProvider" => {
                let use_arena = option_use_arena(init_args);
                // SAFETY: `options` is the session options pointer supplied by the caller.
                let status = unsafe { OrtSessionOptionsAppendExecutionProvider_ACL(options, use_arena.into()) };
                if status_to_result(status).is_ok() {
                    return; // EP found
                }
            }
            "CPUExecutionProvider" => {
                let use_arena = option_use_arena(init_args);
                // SAFETY: `options` is the session options pointer supplied by the caller.
                let status = unsafe { OrtSessionOptionsAppendExecutionProvider_CPU(options, use_arena.into()) };
                if status_to_result(status).is_ok() {
                    return; // EP found
                }
            }
            #[cfg(feature = "cuda")]
            "CUDAExecutionProvider" => {
                let mut cuda_options: *mut sys::OrtCUDAProviderOptionsV2 = std::ptr::null_mut();
                if status_to_result(ortsys![unsafe CreateCUDAProviderOptions(&mut cuda_options)]).is_err() {
                    continue; // next EP
                }
                let appended = match options_to_c_strings(init_args) {
                    Some((keys, values)) => {
                        // The pointer arrays borrow from `keys`/`values`, which outlive both FFI calls.
                        let key_ptrs: Vec<*const c_char> = keys.iter().map(|k| k.as_ptr()).collect();
                        let value_ptrs: Vec<*const c_char> = values.iter().map(|v| v.as_ptr()).collect();
                        let status = ortsys![unsafe UpdateCUDAProviderOptions(cuda_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), keys.len())];
                        status_to_result(status).is_ok()
                            && status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider_CUDA_V2(options, cuda_options)]).is_ok()
                    }
                    // An option containing a NUL byte cannot be passed to ORT; try the next EP.
                    None => false
                };
                // Released on every path, including right after a successful append.
                ortsys![unsafe ReleaseCUDAProviderOptions(cuda_options)];
                if appended {
                    return; // EP found
                }
            }
            // ORT's own name for this provider is "TensorrtExecutionProvider"; accept both spellings.
            #[cfg(feature = "tensorrt")]
            "TensorRTExecutionProvider" | "TensorrtExecutionProvider" => {
                let mut tensorrt_options: *mut sys::OrtTensorRTProviderOptionsV2 = std::ptr::null_mut();
                if status_to_result(ortsys![unsafe CreateTensorRTProviderOptions(&mut tensorrt_options)]).is_err() {
                    continue; // next EP
                }
                let appended = match options_to_c_strings(init_args) {
                    Some((keys, values)) => {
                        // The pointer arrays borrow from `keys`/`values`, which outlive both FFI calls.
                        let key_ptrs: Vec<*const c_char> = keys.iter().map(|k| k.as_ptr()).collect();
                        let value_ptrs: Vec<*const c_char> = values.iter().map(|v| v.as_ptr()).collect();
                        let status = ortsys![unsafe UpdateTensorRTProviderOptions(tensorrt_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), keys.len())];
                        status_to_result(status).is_ok()
                            && status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider_TensorRT_V2(options, tensorrt_options)]).is_ok()
                    }
                    // An option containing a NUL byte cannot be passed to ORT; try the next EP.
                    None => false
                };
                // Released on every path, including right after a successful append.
                ortsys![unsafe ReleaseTensorRTProviderOptions(tensorrt_options)];
                if appended {
                    return; // EP found
                }
            }
            _ => {}
        }
    }
}