mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: manual device selection
This commit is contained in:
99
src/device.rs
Normal file
99
src/device.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use core::{ffi::CStr, marker::PhantomData, ptr::NonNull};
|
||||
|
||||
use crate::{AsPointer, Error, Result, memory::DeviceType, ortsys};
|
||||
|
||||
pub struct Device<'e> {
|
||||
ptr: NonNull<ort_sys::OrtEpDevice>,
|
||||
hw_ptr: NonNull<ort_sys::OrtHardwareDevice>,
|
||||
_env: PhantomData<&'e ()>
|
||||
}
|
||||
|
||||
impl<'e> Device<'e> {
|
||||
pub(crate) fn new(ptr: NonNull<ort_sys::OrtEpDevice>) -> Self {
|
||||
Self {
|
||||
ptr,
|
||||
hw_ptr: NonNull::new(ortsys![unsafe EpDevice_Device(ptr.as_ptr())].cast_mut()).expect("invalid device"),
|
||||
_env: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [name of the EP](crate::ep::ExecutionProvider::name) this device belongs to.
|
||||
pub fn ep(&self) -> Result<&'e str> {
|
||||
let name = ortsys![unsafe EpDevice_EpName(self.ptr.as_ptr())];
|
||||
unsafe { CStr::from_ptr(name) }.to_str().map_err(Error::from)
|
||||
}
|
||||
|
||||
/// Returns the name of the EP vendor this device belongs to, e.g. `"Microsoft"` for DirectML devices.
|
||||
///
|
||||
/// For the *manufacturer* of the device, see [`Device::vendor`].
|
||||
pub fn ep_vendor(&self) -> Result<&'e str> {
|
||||
let vendor = ortsys![unsafe EpDevice_EpVendor(self.ptr.as_ptr())];
|
||||
unsafe { CStr::from_ptr(vendor) }.to_str().map_err(Error::from)
|
||||
}
|
||||
|
||||
/// Returns the [type](DeviceType) of the device - CPU, GPU, or NPU.
|
||||
pub fn ty(&self) -> DeviceType {
|
||||
match ortsys![unsafe HardwareDevice_Type(self.hw_ptr.as_ptr())] {
|
||||
ort_sys::OrtHardwareDeviceType::OrtHardwareDeviceType_CPU => DeviceType::CPU,
|
||||
ort_sys::OrtHardwareDeviceType::OrtHardwareDeviceType_GPU => DeviceType::GPU,
|
||||
ort_sys::OrtHardwareDeviceType::OrtHardwareDeviceType_NPU => DeviceType::NPU
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the device ID.
|
||||
///
|
||||
/// Appears to be arbitrary and is **not** the same as the device *index* (but is unique per device).
|
||||
pub fn id(&self) -> u32 {
|
||||
ortsys![unsafe HardwareDevice_DeviceId(self.hw_ptr.as_ptr())]
|
||||
}
|
||||
|
||||
/// Returns the name of the manufacturer of the device.
|
||||
pub fn vendor(&self) -> Result<&'e str> {
|
||||
let vendor = ortsys![unsafe HardwareDevice_Vendor(self.hw_ptr.as_ptr())];
|
||||
unsafe { CStr::from_ptr(vendor) }.to_str().map_err(Error::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsPointer for Device<'_> {
|
||||
type Sys = ort_sys::OrtEpDevice;
|
||||
|
||||
fn ptr(&self) -> *const Self::Sys {
|
||||
self.ptr.as_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{Result, memory::DeviceType, session::Session};
|
||||
|
||||
#[test]
|
||||
fn test_device_meta() -> Result<()> {
|
||||
let env = crate::environment::current()?;
|
||||
// CPUExecutionProvider should always be first (for now anyways...)
|
||||
let device = env.devices().next().expect("");
|
||||
assert!(matches!(device.ep(), Ok("CPUExecutionProvider")));
|
||||
assert!(matches!(device.ep_vendor(), Ok("Microsoft")));
|
||||
assert_eq!(device.ty(), DeviceType::CPU);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_devices() -> Result<()> {
|
||||
let env = crate::environment::current()?;
|
||||
|
||||
let _session1 = Session::builder()?
|
||||
.with_devices(env.devices().next(), None)?
|
||||
.commit_from_file("tests/data/upsample.onnx")?;
|
||||
|
||||
let options = vec![
|
||||
("CPUExecutionProvider.use_arena".to_string(), "1".to_string()),
|
||||
("XnnpackExecutionProvider.num_threads".to_string(), "4".to_string()),
|
||||
];
|
||||
let _session2 = Session::builder()?
|
||||
.with_devices(env.devices().next(), Some(&options))?
|
||||
.commit_from_file("tests/data/upsample.onnx")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -150,6 +150,37 @@ impl Environment {
|
||||
Ok(ExecutionProviderLibrary::new(name, self))
|
||||
}
|
||||
|
||||
/// Returns an iterator over all automatically discovered [hardware device](crate::device::Device)s.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::environment::Environment;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let env = Environment::current()?;
|
||||
/// for device in env.devices() {
|
||||
/// println!(
|
||||
/// "{id} ({vendor} {ty:?} - {ep})",
|
||||
/// id = device.id(),
|
||||
/// vendor = device.vendor()?,
|
||||
/// ty = device.ty(),
|
||||
/// ep = device.ep()?
|
||||
/// );
|
||||
/// }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[cfg(feature = "api-22")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
||||
pub fn devices(&self) -> impl DoubleEndedIterator<Item = crate::device::Device<'_>> + '_ {
|
||||
let mut ptrs = ptr::dangling();
|
||||
let mut len = 0;
|
||||
// returns an error in minimal build because its unsupported. ignore & return empty iterator in that case
|
||||
let _ = ortsys![@ort: unsafe GetEpDevices(self.ptr().cast_mut(), &mut ptrs, &mut len) as Result];
|
||||
unsafe { core::slice::from_raw_parts(ptrs, len) }
|
||||
.iter()
|
||||
.filter_map(|c| NonNull::new(c.cast_mut()))
|
||||
.map(crate::device::Device::new)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn has_global_threadpool(&self) -> bool {
|
||||
self.has_global_threadpool
|
||||
|
||||
@@ -30,6 +30,9 @@ pub(crate) mod private;
|
||||
pub mod compiler;
|
||||
#[cfg(feature = "api-22")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
||||
pub mod device;
|
||||
#[cfg(feature = "api-22")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
||||
pub mod editor;
|
||||
pub mod environment;
|
||||
pub mod ep;
|
||||
|
||||
@@ -347,14 +347,14 @@ impl SessionBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Automatically select & register an execution provider according to the given [`policy`](AutoDevicePolicy) based
|
||||
/// on available devices.
|
||||
/// Automatically select & register an execution provider according to the given [`policy`](AutoDevicePolicy), based
|
||||
/// on available hardware devices.
|
||||
///
|
||||
/// For finer control over device selection, and to configure EP options, see [`SessionBuilder::with_devices`].
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use ort::session::{Session, builder::AutoDevicePolicy};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// use std::sync::Arc;
|
||||
///
|
||||
/// let mut session = Session::builder()?
|
||||
/// // moar power!!1!
|
||||
/// .with_auto_device(AutoDevicePolicy::MaxPerformance)?
|
||||
@@ -370,6 +370,92 @@ impl SessionBuilder {
|
||||
Err(e) => Err(e.with_recover(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Use a list of hardware devices automatically discovered by the environment via
|
||||
/// [`Environment::devices`](crate::environment::Environment::devices).
|
||||
///
|
||||
/// `options` can be specified to add EP options. Each EP option must be prefixed with the name of the EP
|
||||
/// (obtained by [`Device::ep`](crate::device::Device::ep)) followed by `.`.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{environment::Environment, session::Session, memory::DeviceType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let env = Environment::current()?;
|
||||
///
|
||||
/// let options = vec![
|
||||
/// ("CPUExecutionProvider.use_arena".to_string(), "1".to_string()),
|
||||
/// ("XnnpackExecutionProvider.num_threads".to_string(), "4".to_string()),
|
||||
/// ];
|
||||
/// let mut session = Session::builder()?
|
||||
/// .with_devices(env.devices().filter(|dev| dev.ty() == DeviceType::CPU), Some(&options))?
|
||||
/// .commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[cfg(feature = "api-22")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
||||
pub fn with_devices<'e>(mut self, devices: impl IntoIterator<Item = crate::device::Device<'e>>, options: Option<&[(String, String)]>) -> BuilderResult {
|
||||
use smallvec::SmallVec;
|
||||
|
||||
use crate::util::{MiniMap, with_cstr_ptr_array};
|
||||
|
||||
#[derive(Default)]
|
||||
struct DeviceGroup<'o> {
|
||||
device_ptrs: SmallVec<[*const ort_sys::OrtEpDevice; 2]>,
|
||||
option_keys: Vec<&'o str>,
|
||||
option_values: Vec<&'o str>
|
||||
}
|
||||
|
||||
let existing_devices: SmallVec<[_; 4]> = self.environment.devices().map(|x| x.ptr()).collect();
|
||||
let mut device_groups = MiniMap::<&str, DeviceGroup<'_>>::new();
|
||||
|
||||
let mut group_prefix = [0u8; 128];
|
||||
for device in devices {
|
||||
let ptr = device.ptr();
|
||||
if !existing_devices.contains(&ptr) {
|
||||
return Err(Error::new("device comes from different environment").with_recover(self));
|
||||
}
|
||||
|
||||
let group = device.ep().expect("invalid utf-8");
|
||||
group_prefix[..group.len()].copy_from_slice(group.as_bytes());
|
||||
group_prefix[group.len()] = b'.';
|
||||
let group_prefix = unsafe { core::str::from_utf8_unchecked(core::slice::from_raw_parts(group_prefix.as_ptr(), group.len() + 1)) };
|
||||
|
||||
let group = device_groups.get_or_insert_with(group, DeviceGroup::default);
|
||||
group.device_ptrs.push(ptr);
|
||||
if let Some(options) = options {
|
||||
for (key, value) in options.iter() {
|
||||
if let Some(real_key) = key.strip_prefix(group_prefix) {
|
||||
group.option_keys.push(real_key);
|
||||
group.option_values.push(value.as_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (_, group) in device_groups.iter() {
|
||||
let ptr = self.ptr_mut();
|
||||
let env_ptr = self.environment.ptr().cast_mut();
|
||||
if let Err(e) = with_cstr_ptr_array(&group.option_keys, &|option_keys| {
|
||||
with_cstr_ptr_array(&group.option_values, &|option_values| {
|
||||
ortsys![unsafe SessionOptionsAppendExecutionProvider_V2(
|
||||
ptr,
|
||||
env_ptr,
|
||||
group.device_ptrs.as_ptr(),
|
||||
group.device_ptrs.len(),
|
||||
option_keys.as_ptr(),
|
||||
option_values.as_ptr(),
|
||||
option_keys.len()
|
||||
)?];
|
||||
Ok(())
|
||||
})
|
||||
}) {
|
||||
return Err(e.with_recover(self));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially
|
||||
|
||||
@@ -37,6 +37,18 @@ impl<K: Eq, V> MiniMap<K, V> {
|
||||
self.values.iter_mut().find(|(k, _)| key.eq(k.borrow())).map(|(_, v)| v)
|
||||
}
|
||||
|
||||
pub fn get_or_insert_with(&mut self, key: K, f: impl FnOnce() -> V) -> &mut V {
|
||||
let idx = match self.values.iter_mut().position(|(k, _)| key.eq(k.borrow())) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
let idx = self.values.len();
|
||||
self.values.push((key, f()));
|
||||
idx
|
||||
}
|
||||
};
|
||||
&mut self.values[idx].1
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: K, value: V) -> Option<V> {
|
||||
match self.get_mut(&key) {
|
||||
Some(v) => Some(mem::replace(v, value)),
|
||||
|
||||
Reference in New Issue
Block a user