feat: manual device selection

This commit is contained in:
Carson M.
2026-03-02 18:46:29 -06:00
parent 771e1a5c4a
commit a08efe6147
5 changed files with 235 additions and 4 deletions

99
src/device.rs Normal file
View File

@@ -0,0 +1,99 @@
use core::{ffi::CStr, marker::PhantomData, ptr::NonNull};
use crate::{AsPointer, Error, Result, memory::DeviceType, ortsys};
pub struct Device<'e> {
ptr: NonNull<ort_sys::OrtEpDevice>,
hw_ptr: NonNull<ort_sys::OrtHardwareDevice>,
_env: PhantomData<&'e ()>
}
impl<'e> Device<'e> {
pub(crate) fn new(ptr: NonNull<ort_sys::OrtEpDevice>) -> Self {
Self {
ptr,
hw_ptr: NonNull::new(ortsys![unsafe EpDevice_Device(ptr.as_ptr())].cast_mut()).expect("invalid device"),
_env: PhantomData
}
}
/// Returns the [name of the EP](crate::ep::ExecutionProvider::name) this device belongs to.
pub fn ep(&self) -> Result<&'e str> {
let name = ortsys![unsafe EpDevice_EpName(self.ptr.as_ptr())];
unsafe { CStr::from_ptr(name) }.to_str().map_err(Error::from)
}
/// Returns the name of the EP vendor this device belongs to, e.g. `"Microsoft"` for DirectML devices.
///
/// For the *manufacturer* of the device, see [`Device::vendor`].
pub fn ep_vendor(&self) -> Result<&'e str> {
let vendor = ortsys![unsafe EpDevice_EpVendor(self.ptr.as_ptr())];
unsafe { CStr::from_ptr(vendor) }.to_str().map_err(Error::from)
}
/// Returns the [type](DeviceType) of the device - CPU, GPU, or NPU.
pub fn ty(&self) -> DeviceType {
match ortsys![unsafe HardwareDevice_Type(self.hw_ptr.as_ptr())] {
ort_sys::OrtHardwareDeviceType::OrtHardwareDeviceType_CPU => DeviceType::CPU,
ort_sys::OrtHardwareDeviceType::OrtHardwareDeviceType_GPU => DeviceType::GPU,
ort_sys::OrtHardwareDeviceType::OrtHardwareDeviceType_NPU => DeviceType::NPU
}
}
/// Returns the device ID.
///
/// Appears to be arbitrary and is **not** the same as the device *index* (but is unique per device).
pub fn id(&self) -> u32 {
ortsys![unsafe HardwareDevice_DeviceId(self.hw_ptr.as_ptr())]
}
/// Returns the name of the manufacturer of the device.
pub fn vendor(&self) -> Result<&'e str> {
let vendor = ortsys![unsafe HardwareDevice_Vendor(self.hw_ptr.as_ptr())];
unsafe { CStr::from_ptr(vendor) }.to_str().map_err(Error::from)
}
}
impl AsPointer for Device<'_> {
type Sys = ort_sys::OrtEpDevice;
fn ptr(&self) -> *const Self::Sys {
self.ptr.as_ptr()
}
}
#[cfg(test)]
mod tests {
use crate::{Result, memory::DeviceType, session::Session};
#[test]
fn test_device_meta() -> Result<()> {
let env = crate::environment::current()?;
// CPUExecutionProvider should always be first (for now anyways...)
let device = env.devices().next().expect("");
assert!(matches!(device.ep(), Ok("CPUExecutionProvider")));
assert!(matches!(device.ep_vendor(), Ok("Microsoft")));
assert_eq!(device.ty(), DeviceType::CPU);
Ok(())
}
#[test]
fn test_session_devices() -> Result<()> {
let env = crate::environment::current()?;
let _session1 = Session::builder()?
.with_devices(env.devices().next(), None)?
.commit_from_file("tests/data/upsample.onnx")?;
let options = vec![
("CPUExecutionProvider.use_arena".to_string(), "1".to_string()),
("XnnpackExecutionProvider.num_threads".to_string(), "4".to_string()),
];
let _session2 = Session::builder()?
.with_devices(env.devices().next(), Some(&options))?
.commit_from_file("tests/data/upsample.onnx")?;
Ok(())
}
}

View File

@@ -150,6 +150,37 @@ impl Environment {
Ok(ExecutionProviderLibrary::new(name, self))
}
/// Returns an iterator over all automatically discovered [hardware device](crate::device::Device)s.
///
/// ```
/// # use ort::environment::Environment;
/// # fn main() -> ort::Result<()> {
/// let env = Environment::current()?;
/// for device in env.devices() {
/// println!(
/// "{id} ({vendor} {ty:?} - {ep})",
/// id = device.id(),
/// vendor = device.vendor()?,
/// ty = device.ty(),
/// ep = device.ep()?
/// );
/// }
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub fn devices(&self) -> impl DoubleEndedIterator<Item = crate::device::Device<'_>> + '_ {
let mut ptrs = ptr::dangling();
let mut len = 0;
// returns an error in minimal build because its unsupported. ignore & return empty iterator in that case
let _ = ortsys![@ort: unsafe GetEpDevices(self.ptr().cast_mut(), &mut ptrs, &mut len) as Result];
unsafe { core::slice::from_raw_parts(ptrs, len) }
.iter()
.filter_map(|c| NonNull::new(c.cast_mut()))
.map(crate::device::Device::new)
}
#[inline]
pub(crate) fn has_global_threadpool(&self) -> bool {
self.has_global_threadpool

View File

@@ -30,6 +30,9 @@ pub(crate) mod private;
pub mod compiler;
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub mod device;
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub mod editor;
pub mod environment;
pub mod ep;

View File

@@ -347,14 +347,14 @@ impl SessionBuilder {
}
}
/// Automatically select & register an execution provider according to the given [`policy`](AutoDevicePolicy) based
/// on available devices.
/// Automatically select & register an execution provider according to the given [`policy`](AutoDevicePolicy), based
/// on available hardware devices.
///
/// For finer control over device selection, and to configure EP options, see [`SessionBuilder::with_devices`].
///
/// ```no_run
/// # use ort::session::{Session, builder::AutoDevicePolicy};
/// # fn main() -> ort::Result<()> {
/// use std::sync::Arc;
///
/// let mut session = Session::builder()?
/// // moar power!!1!
/// .with_auto_device(AutoDevicePolicy::MaxPerformance)?
@@ -370,6 +370,92 @@ impl SessionBuilder {
Err(e) => Err(e.with_recover(self))
}
}
/// Use a list of hardware devices automatically discovered by the environment via
/// [`Environment::devices`](crate::environment::Environment::devices).
///
/// `options` can be specified to add EP options. Each EP option must be prefixed with the name of the EP
/// (obtained by [`Device::ep`](crate::device::Device::ep)) followed by `.`.
///
/// ```
/// # use ort::{environment::Environment, session::Session, memory::DeviceType};
/// # fn main() -> ort::Result<()> {
/// let env = Environment::current()?;
///
/// let options = vec![
/// ("CPUExecutionProvider.use_arena".to_string(), "1".to_string()),
/// ("XnnpackExecutionProvider.num_threads".to_string(), "4".to_string()),
/// ];
/// let mut session = Session::builder()?
/// .with_devices(env.devices().filter(|dev| dev.ty() == DeviceType::CPU), Some(&options))?
/// .commit_from_file("tests/data/upsample.onnx")?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub fn with_devices<'e>(mut self, devices: impl IntoIterator<Item = crate::device::Device<'e>>, options: Option<&[(String, String)]>) -> BuilderResult {
use smallvec::SmallVec;
use crate::util::{MiniMap, with_cstr_ptr_array};
#[derive(Default)]
struct DeviceGroup<'o> {
device_ptrs: SmallVec<[*const ort_sys::OrtEpDevice; 2]>,
option_keys: Vec<&'o str>,
option_values: Vec<&'o str>
}
let existing_devices: SmallVec<[_; 4]> = self.environment.devices().map(|x| x.ptr()).collect();
let mut device_groups = MiniMap::<&str, DeviceGroup<'_>>::new();
let mut group_prefix = [0u8; 128];
for device in devices {
let ptr = device.ptr();
if !existing_devices.contains(&ptr) {
return Err(Error::new("device comes from different environment").with_recover(self));
}
let group = device.ep().expect("invalid utf-8");
group_prefix[..group.len()].copy_from_slice(group.as_bytes());
group_prefix[group.len()] = b'.';
let group_prefix = unsafe { core::str::from_utf8_unchecked(core::slice::from_raw_parts(group_prefix.as_ptr(), group.len() + 1)) };
let group = device_groups.get_or_insert_with(group, DeviceGroup::default);
group.device_ptrs.push(ptr);
if let Some(options) = options {
for (key, value) in options.iter() {
if let Some(real_key) = key.strip_prefix(group_prefix) {
group.option_keys.push(real_key);
group.option_values.push(value.as_str());
}
}
}
}
for (_, group) in device_groups.iter() {
let ptr = self.ptr_mut();
let env_ptr = self.environment.ptr().cast_mut();
if let Err(e) = with_cstr_ptr_array(&group.option_keys, &|option_keys| {
with_cstr_ptr_array(&group.option_values, &|option_values| {
ortsys![unsafe SessionOptionsAppendExecutionProvider_V2(
ptr,
env_ptr,
group.device_ptrs.as_ptr(),
group.device_ptrs.len(),
option_keys.as_ptr(),
option_values.as_ptr(),
option_keys.len()
)?];
Ok(())
})
}) {
return Err(e.with_recover(self));
}
}
Ok(self)
}
}
/// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially

View File

@@ -37,6 +37,18 @@ impl<K: Eq, V> MiniMap<K, V> {
self.values.iter_mut().find(|(k, _)| key.eq(k.borrow())).map(|(_, v)| v)
}
pub fn get_or_insert_with(&mut self, key: K, f: impl FnOnce() -> V) -> &mut V {
let idx = match self.values.iter_mut().position(|(k, _)| key.eq(k.borrow())) {
Some(v) => v,
None => {
let idx = self.values.len();
self.values.push((key, f()));
idx
}
};
&mut self.values[idx].1
}
pub fn insert(&mut self, key: K, value: V) -> Option<V> {
match self.get_mut(&key) {
Some(v) => Some(mem::replace(v, value)),