refactor: shorten execution_providers

Frankly, working on documentation has made me tired of typing out `execution_providers` and `ExecutionProvider` all the time.
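(Illustrative only, not part of this diff: a rename like `execution_providers` → `ep` can keep the old path compiling with a single re-export. The `ort_like` module, its `CUDA` builder, and the alias below are hypothetical stand-ins, not the real `ort` crate.)

```rust
// Hypothetical sketch of a module rename with a back-compat alias.
mod ort_like {
    pub mod ep {
        /// Builder-pattern EP struct, post-rename: `ep::CUDA` instead of
        /// the old `execution_providers::CUDAExecutionProvider`.
        #[derive(Debug, Default)]
        pub struct CUDA {
            device_id: i32
        }

        impl CUDA {
            pub fn with_device_id(mut self, id: i32) -> Self {
                self.device_id = id;
                self
            }

            // Stand-in for the real `build()`; returns a string so the
            // sketch is self-contained.
            pub fn build(self) -> String {
                format!("CUDAExecutionProvider(device_id={})", self.device_id)
            }
        }
    }

    // One line keeps downstream code using the old module path compiling.
    pub use self::ep as execution_providers;
}

fn main() {
    let via_new = ort_like::ep::CUDA::default().with_device_id(1).build();
    let via_old = ort_like::execution_providers::CUDA::default().with_device_id(1).build();
    // Both paths resolve to the same type, so old code keeps working.
    assert_eq!(via_new, via_old);
}
```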
Carson M.
2025-11-14 23:57:09 -06:00
parent 47e5667d6e
commit 3b408b1b44
43 changed files with 508 additions and 435 deletions


@@ -12,7 +12,7 @@ To configure an `Environment`, you instead use the `ort::init` function, which r
```rust
ort::init()
-.with_execution_providers([CUDAExecutionProvider::default().build()])
+.with_execution_providers([ep::CUDA::default().build()])
.commit();
```
@@ -156,7 +156,7 @@ let l = outputs["latents"].try_extract_array::<f32>()?;
```
## Execution providers
-Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/execution_providers/index.html#reexports) and the [execution providers reference](/perf/execution-providers) for more information.
+Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/ep/index.html#reexports) and the [execution providers reference](/perf/execution-providers) for more information.
```diff
-// v1.x
@@ -165,7 +165,7 @@ Execution provider structs with public fields have been replaced with builder pa
-}))?;
+// v2
+builder = builder.with_execution_providers([
-+ DirectMLExecutionProvider::default()
++ ep::DirectML::default()
+ .with_device_id(1)
+ .build()
+])?;


@@ -52,11 +52,11 @@ ONNX Runtime must be compiled from source with support for each execution provid
In order to configure sessions to use certain execution providers, you must **register** them when creating an environment or session. You can do this via the `SessionBuilder::with_execution_providers` method. For example, to register the CUDA execution provider for a session:
```rust
-use ort::{execution_providers::CUDAExecutionProvider, session::Session};
+use ort::{ep::CUDA, session::Session};
fn main() -> anyhow::Result<()> {
let session = Session::builder()?
-.with_execution_providers([CUDAExecutionProvider::default().build()])?
+.with_execution_providers([CUDA::default().build()])?
.commit_from_file("model.onnx")?;
Ok(())
@@ -66,21 +66,18 @@ fn main() -> anyhow::Result<()> {
You can, of course, specify multiple execution providers. `ort` will register all EPs specified, in order. If an EP does not support a certain operator in a graph, it will fall back to the next successfully registered EP, or to the CPU if all else fails.
```rust
-use ort::{
-execution_providers::{CoreMLExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider, TensorRTExecutionProvider},
-session::Session
-};
+use ort::{ep, session::Session};
fn main() -> anyhow::Result<()> {
let session = Session::builder()?
.with_execution_providers([
// Prefer TensorRT over CUDA.
-TensorRTExecutionProvider::default().build(),
-CUDAExecutionProvider::default().build(),
+ep::TensorRT::default().build(),
+ep::CUDA::default().build(),
// Use DirectML on Windows if NVIDIA EPs are not available
-DirectMLExecutionProvider::default().build(),
+ep::DirectML::default().build(),
// Or use ANE on Apple platforms
-CoreMLExecutionProvider::default().build()
+ep::CoreML::default().build()
])?
.commit_from_file("model.onnx")?;
@@ -89,18 +86,18 @@ fn main() -> anyhow::Result<()> {
```
## Configuring EPs
-EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/execution_providers/index.html#reexports) for the EP structs for more information on which options are supported and what they do.
+EPs have configuration options to control behavior or increase performance. Each execution provider struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/ep/index.html#reexports) for the EP structs for more information on which options are supported and what they do.
```rust
-use ort::{execution_providers::{coreml::CoreMLComputeUnits, CoreMLExecutionProvider}, session::Session};
+use ort::{ep, session::Session};
fn main() -> anyhow::Result<()> {
let session = Session::builder()?
.with_execution_providers([
-CoreMLExecutionProvider::default()
+ep::CoreML::default()
// this model uses control flow operators, so enable CoreML on subgraphs too
.with_subgraphs()
-.with_compute_units(CoreMLComputeUnits::CPUAndNeuralEngine)
+.with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
.build()
])?
.commit_from_file("model.onnx")?;
@@ -114,12 +111,12 @@ fn main() -> anyhow::Result<()> {
You can configure an EP to return an error on failure by adding `.error_on_failure()` after you `.build()` it. In this example, if CUDA doesn't register successfully, the program will exit with an error at `with_execution_providers`:
```rust
-use ort::{execution_providers::CoreMLExecutionProvider, session::Session};
+use ort::{ep, session::Session};
fn main() -> anyhow::Result<()> {
let session = Session::builder()?
.with_execution_providers([
-CUDAExecutionProvider::default().build().error_on_failure()
+ep::CUDA::default().build().error_on_failure()
])?
.commit_from_file("model.onnx")?;
@@ -131,14 +128,14 @@ If you require more complex error handling, you can also manually register execu
```rust
use ort::{
-execution_providers::{CUDAExecutionProvider, ExecutionProvider},
+ep::{self, ExecutionProvider},
session::Session
};
fn main() -> anyhow::Result<()> {
let builder = Session::builder()?;
-let cuda = CUDAExecutionProvider::default();
+let cuda = ep::CUDA::default();
if cuda.register(&builder).is_err() {
eprintln!("Failed to register CUDA!");
std::process::exit(1);
@@ -154,14 +151,14 @@ You can also check whether ONNX Runtime is even compiled with support for the ex
```rust
use ort::{
-execution_providers::{CoreMLExecutionProvider, ExecutionProvider},
+ep::{self, ExecutionProvider},
session::Session
};
fn main() -> anyhow::Result<()> {
let builder = Session::builder()?;
-let coreml = CoreMLExecutionProvider::default();
+let coreml = ep::CoreML::default();
if !coreml.is_available() {
eprintln!("Please compile ONNX Runtime with CoreML!");
std::process::exit(1);
@@ -180,11 +177,11 @@ fn main() -> anyhow::Result<()> {
You can configure `ort` to attempt to register a list of execution providers for all sessions created in an environment.
```rust
-use ort::{execution_providers::CUDAExecutionProvider, session::Session};
+use ort::{ep, session::Session};
fn main() -> anyhow::Result<()> {
ort::init()
-.with_execution_providers([CUDAExecutionProvider::default().build()])
+.with_execution_providers([ep::CUDA::default().build()])
.commit();
let session = Session::builder()?.commit_from_file("model.onnx")?;


@@ -80,10 +80,10 @@ Here is a more complete example of the I/O binding API in a scenario where I/O p
```rs
let mut text_encoder = Session::builder()?
-.with_execution_providers([CUDAExecutionProvider::default().build()])?
+.with_execution_providers([ep::CUDA::default().build()])?
.commit_from_file("text_encoder.onnx")?;
let mut unet = Session::builder()?
-.with_execution_providers([CUDAExecutionProvider::default().build()])?
+.with_execution_providers([ep::CUDA::default().build()])?
.commit_from_file("unet.onnx")?;
let text_condition = {


@@ -1,5 +1,5 @@
#[allow(unused)]
-use ort::execution_providers::*;
+use ort::ep::*;
pub fn init() -> ort::Result<()> {
#[cfg(feature = "backend-candle")]
@@ -11,41 +11,41 @@ pub fn init() -> ort::Result<()> {
ort::init()
.with_execution_providers([
#[cfg(feature = "tensorrt")]
-TensorRTExecutionProvider::default().build(),
+TensorRT::default().build(),
#[cfg(feature = "cuda")]
-CUDAExecutionProvider::default().build(),
+CUDA::default().build(),
#[cfg(feature = "onednn")]
-OneDNNExecutionProvider::default().build(),
+OneDNN::default().build(),
#[cfg(feature = "acl")]
-ACLExecutionProvider::default().build(),
+ACL::default().build(),
#[cfg(feature = "openvino")]
-OpenVINOExecutionProvider::default().build(),
+OpenVINO::default().build(),
#[cfg(feature = "coreml")]
-CoreMLExecutionProvider::default().build(),
+CoreML::default().build(),
#[cfg(feature = "rocm")]
-ROCmExecutionProvider::default().build(),
+ROCm::default().build(),
#[cfg(feature = "cann")]
-CANNExecutionProvider::default().build(),
+CANN::default().build(),
#[cfg(feature = "directml")]
-DirectMLExecutionProvider::default().build(),
+DirectML::default().build(),
#[cfg(feature = "tvm")]
-TVMExecutionProvider::default().build(),
+TVM::default().build(),
#[cfg(feature = "nnapi")]
-NNAPIExecutionProvider::default().build(),
+NNAPI::default().build(),
#[cfg(feature = "qnn")]
-QNNExecutionProvider::default().build(),
+QNN::default().build(),
#[cfg(feature = "xnnpack")]
-XNNPACKExecutionProvider::default().build(),
+XNNPACK::default().build(),
#[cfg(feature = "armnn")]
-ArmNNExecutionProvider::default().build(),
+ArmNN::default().build(),
#[cfg(feature = "migraphx")]
-MIGraphXExecutionProvider::default().build(),
+MIGraphX::default().build(),
#[cfg(feature = "vitis")]
-VitisAIExecutionProvider::default().build(),
+VitisAI::default().build(),
#[cfg(feature = "rknpu")]
-RKNPUExecutionProvider::default().build(),
+RKNPU::default().build(),
#[cfg(feature = "webgpu")]
-WebGPUExecutionProvider::default().build()
+WebGPU::default().build()
])
.commit();


@@ -4,7 +4,7 @@ use cudarc::driver::{CudaDevice, DevicePtr};
use image::{GenericImageView, ImageBuffer, Rgba, imageops::FilterType};
use ndarray::Array;
use ort::{
-execution_providers::CUDAExecutionProvider,
+ep,
memory::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType},
session::Session,
tensor::Shape,
@@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> {
#[rustfmt::skip]
ort::init()
.with_execution_providers([
-CUDAExecutionProvider::default()
+ep::CUDA::default()
.build()
// exit the program with an error if the CUDA EP fails to register
.error_on_failure()


@@ -6,7 +6,7 @@ use std::{
use kdam::BarExt;
use ort::{
-execution_providers::CUDAExecutionProvider,
+ep,
memory::Allocator,
session::{Session, builder::SessionBuilder},
training::{CheckpointStrategy, Trainer, TrainerCallbacks, TrainerControl, TrainerState, TrainingArguments},
@@ -56,7 +56,7 @@ fn main() -> ort::Result<()> {
common::init()?;
let trainer = Trainer::new_from_artifacts(
-SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
+SessionBuilder::new()?.with_execution_providers([ep::CUDA::default().build()])?,
Allocator::default(),
"tools/train-data/mini-clm",
None


@@ -6,7 +6,7 @@ use std::{
use kdam::BarExt;
use ort::{
-execution_providers::CUDAExecutionProvider,
+ep,
memory::Allocator,
session::{Session, builder::SessionBuilder},
training::{Checkpoint, Trainer},
@@ -38,7 +38,7 @@ fn main() -> ort::Result<()> {
let _ = kdam::term::hide_cursor();
let trainer = Trainer::new(
-SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
+SessionBuilder::new()?.with_execution_providers([ep::CUDA::default().build()])?,
Allocator::default(),
Checkpoint::load("tools/train-data/mini-clm/checkpoint")?,
"tools/train-data/mini-clm/training_model.onnx",


@@ -57,8 +57,8 @@ pub extern "C" fn detect_objects(ptr: *const u8, width: u32, height: u32) {
let use_webgpu = true; // TODO: Make `use_webgpu` a parameter of `detect_objects`? Or say in README to change it here.
if use_webgpu {
-use ort::execution_providers::ExecutionProvider;
-let ep = ort::execution_providers::WebGPUExecutionProvider::default();
+use ort::ep::ExecutionProvider;
+let ep = ort::ep::WebGPUExecutionProvider::default();
if ep.is_available().expect("Cannot check for availability of WebGPU ep.") {
ep.register(&mut builder).expect("Cannot register WebGPU ep.");
} else {


@@ -63,14 +63,14 @@ impl Adapter {
/// ```
/// # use ort::{
/// # adapter::Adapter,
-/// # execution_providers::CUDAExecutionProvider,
+/// # ep,
/// # memory::DeviceType,
/// # session::{run_options::RunOptions, Session},
/// # value::Tensor
/// # };
/// # fn main() -> ort::Result<()> {
/// let mut model = Session::builder()?
-/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
+/// .with_execution_providers([ep::CUDA::default().build()])?
/// .commit_from_file("tests/data/lora_model.onnx")?;
///
/// let allocator = model.allocator();
@@ -108,14 +108,14 @@ impl Adapter {
/// ```
/// # use ort::{
/// # adapter::Adapter,
-/// # execution_providers::CUDAExecutionProvider,
+/// # ep,
/// # memory::DeviceType,
/// # session::{run_options::RunOptions, Session},
/// # value::Tensor
/// # };
/// # fn main() -> ort::Result<()> {
/// let mut model = Session::builder()?
-/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
+/// .with_execution_providers([ep::CUDA::default().build()])?
/// .commit_from_file("tests/data/lora_model.onnx")?;
///
/// let bytes = std::fs::read("tests/data/adapter.orl").unwrap();


@@ -2,9 +2,9 @@
//!
//! Environments can be configured via [`ort::init`](init):
//! ```
-//! # use ort::execution_providers::CUDAExecutionProvider;
+//! # use ort::ep;
//! # fn main() -> ort::Result<()> {
-//! ort::init().with_execution_providers([CUDAExecutionProvider::default().build()]).commit();
+//! ort::init().with_execution_providers([ep::CUDA::default().build()]).commit();
//! # Ok(())
//! # }
//! ```
@@ -25,8 +25,8 @@ use smallvec::SmallVec;
use crate::{
AsPointer,
+ep::ExecutionProviderDispatch,
error::Result,
-execution_providers::ExecutionProviderDispatch,
logging::{LogLevel, LoggerFunction},
ortsys,
util::{Mutex, OnceLock, STACK_EXECUTION_PROVIDERS, with_cstr}
@@ -430,9 +430,9 @@ impl EnvironmentBuilder {
/// Creates an ONNX Runtime environment.
///
/// ```
-/// # use ort::execution_providers::CUDAExecutionProvider;
+/// # use ort::ep;
/// # fn main() -> ort::Result<()> {
-/// ort::init().with_execution_providers([CUDAExecutionProvider::default().build()]).commit();
+/// ort::init().with_execution_providers([ep::CUDA::default().build()]).commit();
/// # Ok(())
/// # }
/// ```
@@ -456,11 +456,11 @@ pub fn init() -> EnvironmentBuilder {
/// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded.
///
/// ```no_run
-/// # use ort::execution_providers::CUDAExecutionProvider;
+/// # use ort::ep;
/// # fn main() -> ort::Result<()> {
/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib");
/// ort::init_from(lib_path.join("onnxruntime.dll"))?
-/// .with_execution_providers([CUDAExecutionProvider::default().build()])
+/// .with_execution_providers([ep::CUDA::default().build()])
/// .commit();
/// # Ok(())
/// # }


@@ -4,20 +4,20 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// [Arm Compute Library execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/ACL-ExecutionProvider.html)
/// for ARM platforms.
#[derive(Debug, Default, Clone)]
-pub struct ACLExecutionProvider {
+pub struct ACL {
fast_math: bool
}
-super::impl_ep!(ACLExecutionProvider);
+super::impl_ep!(ACL);
-impl ACLExecutionProvider {
+impl ACL {
/// Enable/disable ACL's fast math mode. Enabling can improve performance at the cost of some accuracy for
/// `MatMul`/`Conv` nodes.
///
/// ```
-/// # use ort::{execution_providers::acl::ACLExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = ACLExecutionProvider::default().with_fast_math(true).build();
+/// let ep = ep::ACL::default().with_fast_math(true).build();
/// # Ok(())
/// # }
/// ```
@@ -28,7 +28,7 @@ impl ACLExecutionProvider {
}
}
-impl ExecutionProvider for ACLExecutionProvider {
+impl ExecutionProvider for ACL {
fn name(&self) -> &'static str {
"ACLExecutionProvider"
}


@@ -4,19 +4,19 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// [Arm NN execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/ArmNN-ExecutionProvider.html)
/// for ARM platforms.
#[derive(Debug, Default, Clone)]
-pub struct ArmNNExecutionProvider {
+pub struct ArmNN {
use_arena: bool
}
-super::impl_ep!(ArmNNExecutionProvider);
+super::impl_ep!(ArmNN);
-impl ArmNNExecutionProvider {
+impl ArmNN {
/// Enable/disable the usage of the arena allocator.
///
/// ```
-/// # use ort::{execution_providers::armnn::ArmNNExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = ArmNNExecutionProvider::default().with_arena_allocator(true).build();
+/// let ep = ep::ArmNN::default().with_arena_allocator(true).build();
/// # Ok(())
/// # }
/// ```
@@ -27,7 +27,7 @@ impl ArmNNExecutionProvider {
}
}
-impl ExecutionProvider for ArmNNExecutionProvider {
+impl ExecutionProvider for ArmNN {
fn name(&self) -> &'static str {
"ArmNNExecutionProvider"
}


@@ -35,12 +35,12 @@ use crate::{error::Result, session::builder::SessionBuilder};
///
/// To use this model in `ort`:
/// ```no_run
-/// # use ort::{execution_providers::azure::AzureExecutionProvider, session::Session, value::Tensor};
+/// # use ort::{ep, session::Session, value::Tensor};
/// # fn main() -> ort::Result<()> {
/// let mut session = Session::builder()?
/// // note: session must be initialized with `onnxruntime-extensions`
/// .with_extensions()?
-/// .with_execution_providers([AzureExecutionProvider::default().build()])?
+/// .with_execution_providers([ep::Azure::default().build()])?
/// .commit_from_file("azure_chat.onnx")?;
///
/// let auth_token = Tensor::from_string_array(([1], &*vec!["..."]))?;
@@ -65,13 +65,13 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// # }
/// ```
#[derive(Debug, Default, Clone)]
-pub struct AzureExecutionProvider {
+pub struct Azure {
options: ExecutionProviderOptions
}
-super::impl_ep!(arbitrary; AzureExecutionProvider);
+super::impl_ep!(arbitrary; Azure);
-impl ExecutionProvider for AzureExecutionProvider {
+impl ExecutionProvider for Azure {
fn name(&self) -> &'static str {
"AzureExecutionProvider"
}


@@ -5,7 +5,7 @@ use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
-pub enum CANNPrecisionMode {
+pub enum PrecisionMode {
/// Convert to float32 first according to operator implementation
ForceFP32,
/// Convert to float16 when float16 and float32 are both supported
@@ -20,7 +20,7 @@ pub enum CANNPrecisionMode {
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
-pub enum CANNImplementationMode {
+pub enum ImplementationMode {
/// Prefer high precision, potentially at the cost of some performance.
HighPrecision,
/// Prefer high performance, potentially with lower accuracy.
@@ -30,19 +30,19 @@ pub enum CANNImplementationMode {
/// [CANN execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/CANN-ExecutionProvider.html)
/// for hardware acceleration using Huawei Ascend AI devices.
#[derive(Default, Debug, Clone)]
-pub struct CANNExecutionProvider {
+pub struct CANN {
options: ExecutionProviderOptions
}
-super::impl_ep!(arbitrary; CANNExecutionProvider);
+super::impl_ep!(arbitrary; CANN);
-impl CANNExecutionProvider {
+impl CANN {
/// Configures which device the EP should use.
///
/// ```
-/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CANNExecutionProvider::default().with_device_id(0).build();
+/// let ep = ep::CANN::default().with_device_id(0).build();
/// # Ok(())
/// # }
/// ```
@@ -56,9 +56,9 @@ impl CANNExecutionProvider {
/// providers arena; the total device memory usage may be higher.
///
/// ```
-/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CANNExecutionProvider::default().with_memory_limit(2 * 1024 * 1024 * 1024).build();
+/// let ep = ep::CANN::default().with_memory_limit(2 * 1024 * 1024 * 1024).build();
/// # Ok(())
/// # }
/// ```
@@ -71,9 +71,9 @@ impl CANNExecutionProvider {
/// Configure the strategy for extending the device's memory arena.
///
/// ```
-/// # use ort::{execution_providers::{cann::CANNExecutionProvider, ArenaExtendStrategy}, session::Session};
+/// # use ort::{ep::{self, ArenaExtendStrategy}, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CANNExecutionProvider::default()
+/// let ep = ep::CANN::default()
/// .with_arena_extend_strategy(ArenaExtendStrategy::SameAsRequested)
/// .build();
/// # Ok(())
@@ -95,9 +95,9 @@ impl CANNExecutionProvider {
/// is `true`. If `false`, it will fall back to the single-operator inference engine.
///
/// ```
-/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CANNExecutionProvider::default().with_cann_graph(true).build();
+/// let ep = ep::CANN::default().with_cann_graph(true).build();
/// # Ok(())
/// # }
/// ```
@@ -110,9 +110,9 @@ impl CANNExecutionProvider {
/// Configure whether to dump the subgraph into ONNX format for analysis of subgraph segmentation.
///
/// ```
-/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CANNExecutionProvider::default().with_dump_graphs(true).build();
+/// let ep = ep::CANN::default().with_dump_graphs(true).build();
/// # Ok(())
/// # }
/// ```
@@ -125,9 +125,9 @@ impl CANNExecutionProvider {
/// Configure whether to dump the offline model to an `.om` file.
///
/// ```
-/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CANNExecutionProvider::default().with_dump_om_model(true).build();
+/// let ep = ep::CANN::default().with_dump_om_model(true).build();
/// # Ok(())
/// # }
/// ```
@@ -137,25 +137,25 @@ impl CANNExecutionProvider {
self
}
-/// Configure the precision mode; see [`CANNPrecisionMode`].
+/// Configure the precision mode; see [`PrecisionMode`].
///
/// ```
-/// # use ort::{execution_providers::cann::{CANNExecutionProvider, CANNPrecisionMode}, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CANNExecutionProvider::default().with_precision_mode(CANNPrecisionMode::ForceFP16).build();
+/// let ep = ep::CANN::default().with_precision_mode(ep::cann::PrecisionMode::ForceFP16).build();
/// # Ok(())
/// # }
/// ```
#[must_use]
-pub fn with_precision_mode(mut self, mode: CANNPrecisionMode) -> Self {
+pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self {
self.options.set(
"precision_mode",
match mode {
-CANNPrecisionMode::ForceFP32 => "force_fp32",
-CANNPrecisionMode::ForceFP16 => "force_fp16",
-CANNPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
-CANNPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
-CANNPrecisionMode::AllowMixedPrecision => "allow_mix_precision"
+PrecisionMode::ForceFP32 => "force_fp32",
+PrecisionMode::ForceFP16 => "force_fp16",
+PrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
+PrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
+PrecisionMode::AllowMixedPrecision => "allow_mix_precision"
}
);
self
@@ -165,33 +165,33 @@ impl CANNExecutionProvider {
/// high-performance implementations.
///
/// ```
-/// # use ort::{execution_providers::cann::{CANNExecutionProvider, CANNImplementationMode}, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CANNExecutionProvider::default()
-/// .with_implementation_mode(CANNImplementationMode::HighPerformance)
+/// let ep = ep::CANN::default()
+/// .with_implementation_mode(ep::cann::ImplementationMode::HighPerformance)
/// .build();
/// # Ok(())
/// # }
/// ```
#[must_use]
-pub fn with_implementation_mode(mut self, mode: CANNImplementationMode) -> Self {
+pub fn with_implementation_mode(mut self, mode: ImplementationMode) -> Self {
self.options.set(
"op_select_impl_mode",
match mode {
-CANNImplementationMode::HighPrecision => "high_precision",
-CANNImplementationMode::HighPerformance => "high_performance"
+ImplementationMode::HighPrecision => "high_precision",
+ImplementationMode::HighPerformance => "high_performance"
}
);
self
}
/// Configure the list of operators which use the mode specified by
-/// [`CANNExecutionProvider::with_implementation_mode`].
+/// [`CANN::with_implementation_mode`].
///
/// ```
-/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CANNExecutionProvider::default().with_implementation_mode_oplist("LayerNorm,Gelu").build();
+/// let ep = ep::CANN::default().with_implementation_mode_oplist("LayerNorm,Gelu").build();
/// # Ok(())
/// # }
/// ```
@@ -202,7 +202,7 @@ impl CANNExecutionProvider {
}
}
-impl ExecutionProvider for CANNExecutionProvider {
+impl ExecutionProvider for CANN {
fn name(&self) -> &'static str {
"CANNExecutionProvider"
}


@@ -4,7 +4,7 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum CoreMLSpecializationStrategy {
+pub enum SpecializationStrategy {
/// The strategy that should work well for most applications.
Default,
/// Prefer the prediction latency at the potential cost of specialization time, memory footprint, and the disk space
@@ -12,7 +12,7 @@ pub enum CoreMLSpecializationStrategy {
FastPrediction
}
-impl CoreMLSpecializationStrategy {
+impl SpecializationStrategy {
pub(crate) fn as_str(&self) -> &'static str {
match self {
Self::Default => "Default",
@@ -22,7 +22,7 @@ impl CoreMLSpecializationStrategy {
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum CoreMLComputeUnits {
+pub enum ComputeUnits {
/// Enable CoreML EP for all compatible Apple devices.
All,
/// Enable CoreML EP for Apple devices with a compatible Neural Engine (ANE).
@@ -33,7 +33,7 @@ pub enum CoreMLComputeUnits {
CPUOnly
}
-impl CoreMLComputeUnits {
+impl ComputeUnits {
pub(crate) fn as_str(&self) -> &'static str {
match self {
Self::All => "ALL",
@@ -45,14 +45,14 @@ impl CoreMLComputeUnits {
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum CoreMLModelFormat {
+pub enum ModelFormat {
/// Requires Core ML 5 or later (iOS 15+ or macOS 12+).
MLProgram,
/// Default; requires Core ML 3 or later (iOS 13+ or macOS 10.15+).
NeuralNetwork
}
-impl CoreMLModelFormat {
+impl ModelFormat {
pub(crate) fn as_str(&self) -> &'static str {
match self {
Self::MLProgram => "MLProgram",
@@ -64,20 +64,20 @@ impl CoreMLModelFormat {
/// [CoreML execution provider](https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html) for hardware
/// acceleration on Apple devices.
#[derive(Debug, Default, Clone)]
-pub struct CoreMLExecutionProvider {
+pub struct CoreML {
options: ExecutionProviderOptions
}
-super::impl_ep!(arbitrary; CoreMLExecutionProvider);
+super::impl_ep!(arbitrary; CoreML);
-impl CoreMLExecutionProvider {
+impl CoreML {
/// Enable CoreML EP to run on a subgraph in the body of a control flow operator (i.e. a `Loop`, `Scan` or `If`
/// operator).
///
/// ```
-/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CoreMLExecutionProvider::default().with_subgraphs(true).build();
+/// let ep = ep::CoreML::default().with_subgraphs(true).build();
/// # Ok(())
/// # }
/// ```
@@ -91,9 +91,9 @@ impl CoreMLExecutionProvider {
/// allow inputs with dynamic shapes, however performance may be negatively impacted by inputs with dynamic shapes.
///
/// ```
-/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CoreMLExecutionProvider::default().with_static_input_shapes(true).build();
+/// let ep = ep::CoreML::default().with_static_input_shapes(true).build();
/// # Ok(())
/// # }
/// ```
@@ -105,19 +105,19 @@ impl CoreMLExecutionProvider {
/// Configures the format of the CoreML model created by the EP.
///
-/// The default format, [NeuralNetwork](`CoreMLModelFormat::NeuralNetwork`), has better compatibility with older
-/// versions of macOS/iOS. The newer [MLProgram](`CoreMLModelFormat::MLProgram`) format supports more operators,
+/// The default format, [NeuralNetwork](`ModelFormat::NeuralNetwork`), has better compatibility with older
+/// versions of macOS/iOS. The newer [MLProgram](`ModelFormat::MLProgram`) format supports more operators,
/// and may be more performant.
///
/// ```
-/// # use ort::{execution_providers::coreml::{CoreMLExecutionProvider, CoreMLModelFormat}, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CoreMLExecutionProvider::default().with_model_format(CoreMLModelFormat::MLProgram).build();
+/// let ep = ep::CoreML::default().with_model_format(ep::coreml::ModelFormat::MLProgram).build();
/// # Ok(())
/// # }
/// ```
#[must_use]
-pub fn with_model_format(mut self, model_format: CoreMLModelFormat) -> Self {
+pub fn with_model_format(mut self, model_format: ModelFormat) -> Self {
self.options.set("ModelFormat", model_format.as_str());
self
}
@@ -129,14 +129,16 @@ impl CoreMLExecutionProvider {
/// model for faster prediction, at the potential cost of session load time and memory footprint.
///
/// ```
-/// # use ort::{execution_providers::coreml::{CoreMLExecutionProvider, CoreMLSpecializationStrategy}, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CoreMLExecutionProvider::default().with_specialization_strategy(CoreMLSpecializationStrategy::FastPrediction).build();
+/// let ep = ep::CoreML::default()
+/// .with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction)
+/// .build();
/// # Ok(())
/// # }
/// ```
#[must_use]
-pub fn with_specialization_strategy(mut self, strategy: CoreMLSpecializationStrategy) -> Self {
+pub fn with_specialization_strategy(mut self, strategy: SpecializationStrategy) -> Self {
self.options.set("SpecializationStrategy", strategy.as_str());
self
}
@@ -144,16 +146,16 @@ impl CoreMLExecutionProvider {
/// Configures what hardware can be used by CoreML for acceleration.
///
/// ```
-/// # use ort::{execution_providers::coreml::{CoreMLExecutionProvider, CoreMLComputeUnits}, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CoreMLExecutionProvider::default()
-/// .with_compute_units(CoreMLComputeUnits::CPUAndNeuralEngine)
+/// let ep = ep::CoreML::default()
+/// .with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
/// .build();
/// # Ok(())
/// # }
/// ```
#[must_use]
-pub fn with_compute_units(mut self, units: CoreMLComputeUnits) -> Self {
+pub fn with_compute_units(mut self, units: ComputeUnits) -> Self {
self.options.set("MLComputeUnits", units.as_str());
self
}
@@ -162,9 +164,9 @@ impl CoreMLExecutionProvider {
/// for debugging unexpected performance with CoreML.
///
/// ```
-/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CoreMLExecutionProvider::default().with_profile_compute_plan(true).build();
+/// let ep = ep::CoreML::default().with_profile_compute_plan(true).build();
/// # Ok(())
/// # }
/// ```
@@ -177,9 +179,9 @@ impl CoreMLExecutionProvider {
/// Configures whether to allow low-precision (fp16) accumulation on GPU.
///
/// ```
-/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
+/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
-/// let ep = CoreMLExecutionProvider::default().with_low_precision_accumulation_on_gpu(true).build();
+/// let ep = ep::CoreML::default().with_low_precision_accumulation_on_gpu(true).build();
/// # Ok(())
/// # }
/// ```
@@ -195,9 +197,9 @@ impl CoreMLExecutionProvider {
/// session. Setting this option allows the compiled model to be reused across session loads.
///
/// ```
/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CoreMLExecutionProvider::default().with_model_cache_dir("/path/to/cache").build();
/// let ep = ep::CoreML::default().with_model_cache_dir("/path/to/cache").build();
/// # Ok(())
/// # }
/// ```
@@ -235,7 +237,7 @@ impl CoreMLExecutionProvider {
}
}
impl ExecutionProvider for CoreMLExecutionProvider {
impl ExecutionProvider for CoreML {
fn name(&self) -> &'static str {
"CoreMLExecutionProvider"
}


@@ -3,19 +3,19 @@ use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder};
/// The default CPU execution provider, powered by MLAS.
#[derive(Debug, Default, Clone)]
pub struct CPUExecutionProvider {
pub struct CPU {
use_arena: bool
}
super::impl_ep!(CPUExecutionProvider);
super::impl_ep!(CPU);
impl CPUExecutionProvider {
impl CPU {
/// Enable/disable the usage of the arena allocator.
///
/// ```
/// # use ort::{execution_providers::cpu::CPUExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CPUExecutionProvider::default().with_arena_allocator(true).build();
/// let ep = ep::CPU::default().with_arena_allocator(true).build();
/// # Ok(())
/// # }
/// ```
@@ -26,7 +26,7 @@ impl CPUExecutionProvider {
}
}
impl ExecutionProvider for CPUExecutionProvider {
impl ExecutionProvider for CPU {
fn name(&self) -> &'static str {
"CPUExecutionProvider"
}


@@ -7,9 +7,9 @@ use crate::{error::Result, session::builder::SessionBuilder};
// https://github.com/microsoft/onnxruntime/blob/ffceed9d44f2f3efb9dd69fa75fea51163c91d91/onnxruntime/contrib_ops/cpu/bert/attention_common.h#L160-L171
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct CUDAAttentionBackend(u32);
pub struct AttentionBackend(u32);
impl CUDAAttentionBackend {
impl AttentionBackend {
pub const FLASH_ATTENTION: Self = Self(1 << 0);
pub const EFFICIENT_ATTENTION: Self = Self(1 << 1);
pub const TRT_FUSED_ATTENTION: Self = Self(1 << 2);
@@ -23,7 +23,7 @@ impl CUDAAttentionBackend {
pub const LEAN_ATTENTION: Self = Self(1 << 8);
pub fn none() -> Self {
CUDAAttentionBackend(0)
AttentionBackend(0)
}
pub fn all() -> Self {
@@ -38,7 +38,7 @@ impl CUDAAttentionBackend {
}
}
impl BitOr for CUDAAttentionBackend {
impl BitOr for AttentionBackend {
type Output = Self;
fn bitor(self, rhs: Self) -> Self::Output {
Self(rhs.0 | self.0)
@@ -47,7 +47,7 @@ impl BitOr for CUDAAttentionBackend {
/// The type of search done for cuDNN convolution algorithms.
#[derive(Debug, Clone, Default)]
pub enum CuDNNConvAlgorithmSearch {
pub enum ConvAlgorithmSearch {
/// Expensive exhaustive benchmarking using [`cudnnFindConvolutionForwardAlgorithmEx`][exhaustive].
/// This function will attempt all possible algorithms for `cudnnConvolutionForward` to find the fastest algorithm.
/// Exhaustive search trades off between memory usage and speed. The first execution of a graph will be slow while
@@ -71,27 +71,27 @@ pub enum CuDNNConvAlgorithmSearch {
/// > search algorithm is actually [`Exhaustive`].
///
/// [fwdalgo]: https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionFwdAlgo_t
/// [`Exhaustive`]: CuDNNConvAlgorithmSearch::Exhaustive
/// [`Heuristic`]: CuDNNConvAlgorithmSearch::Heuristic
/// [`Exhaustive`]: ConvAlgorithmSearch::Exhaustive
/// [`Heuristic`]: ConvAlgorithmSearch::Heuristic
Default
}
/// [CUDA execution provider](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) for NVIDIA
/// CUDA-enabled GPUs.
#[derive(Debug, Default, Clone)]
pub struct CUDAExecutionProvider {
pub struct CUDA {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; CUDAExecutionProvider);
super::impl_ep!(arbitrary; CUDA);
impl CUDAExecutionProvider {
impl CUDA {
/// Configures which device the EP should use.
///
/// ```
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default().with_device_id(0).build();
/// let ep = ep::CUDA::default().with_device_id(0).build();
/// # Ok(())
/// # }
/// ```
@@ -104,12 +104,12 @@ impl CUDAExecutionProvider {
/// Configure the size limit of the device memory arena in bytes.
///
/// This only controls how much memory can be allocated to the *arena* - actual memory usage may be higher due to
/// internal CUDA allocations, like those required for different [`CuDNNConvAlgorithmSearch`] options.
/// internal CUDA allocations, like those required for different [`ConvAlgorithmSearch`] options.
///
/// ```
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default().with_memory_limit(2 * 1024 * 1024 * 1024).build();
/// let ep = ep::CUDA::default().with_memory_limit(2 * 1024 * 1024 * 1024).build();
/// # Ok(())
/// # }
/// ```
@@ -122,9 +122,9 @@ impl CUDAExecutionProvider {
/// Configure the strategy for extending the device's memory arena.
///
/// ```
/// # use ort::{execution_providers::{cuda::CUDAExecutionProvider, ArenaExtendStrategy}, session::Session};
/// # use ort::{ep::{self, ArenaExtendStrategy}, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default()
/// let ep = ep::CUDA::default()
/// .with_arena_extend_strategy(ArenaExtendStrategy::SameAsRequested)
/// .build();
/// # Ok(())
@@ -152,7 +152,7 @@ impl CUDAExecutionProvider {
/// The default search algorithm, [`Exhaustive`][exh], will benchmark all available implementations and use the most
/// performant one. This option is very resource intensive (both computationally on first run and peak-memory-wise),
/// but ensures best performance. It is roughly equivalent to setting `torch.backends.cudnn.benchmark = True` with
/// PyTorch. See also [`CUDAExecutionProvider::with_conv_max_workspace`] to configure how much memory the exhaustive
/// PyTorch. See also [`CUDA::with_conv_max_workspace`] to configure how much memory the exhaustive
/// search can use (the default is unlimited).
///
/// A less resource-intensive option is [`Heuristic`][heu]. Rather than benchmarking every implementation,
@@ -164,41 +164,41 @@ impl CUDAExecutionProvider {
/// is not the *default behavior* (that would be [`Exhaustive`][exh]).
///
/// ```
/// # use ort::{execution_providers::cuda::{CUDAExecutionProvider, CuDNNConvAlgorithmSearch}, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default()
/// .with_conv_algorithm_search(CuDNNConvAlgorithmSearch::Heuristic)
/// let ep = ep::CUDA::default()
/// .with_conv_algorithm_search(ep::cuda::ConvAlgorithmSearch::Heuristic)
/// .build();
/// # Ok(())
/// # }
/// ```
///
/// [exh]: CuDNNConvAlgorithmSearch::Exhaustive
/// [heu]: CuDNNConvAlgorithmSearch::Heuristic
/// [def]: CuDNNConvAlgorithmSearch::Default
/// [exh]: ConvAlgorithmSearch::Exhaustive
/// [heu]: ConvAlgorithmSearch::Heuristic
/// [def]: ConvAlgorithmSearch::Default
#[must_use]
pub fn with_conv_algorithm_search(mut self, search: CuDNNConvAlgorithmSearch) -> Self {
pub fn with_conv_algorithm_search(mut self, search: ConvAlgorithmSearch) -> Self {
self.options.set(
"cudnn_conv_algo_search",
match search {
CuDNNConvAlgorithmSearch::Exhaustive => "EXHAUSTIVE",
CuDNNConvAlgorithmSearch::Heuristic => "HEURISTIC",
CuDNNConvAlgorithmSearch::Default => "DEFAULT"
ConvAlgorithmSearch::Exhaustive => "EXHAUSTIVE",
ConvAlgorithmSearch::Heuristic => "HEURISTIC",
ConvAlgorithmSearch::Default => "DEFAULT"
}
);
self
}
/// Configure whether the [`Exhaustive`][CuDNNConvAlgorithmSearch::Exhaustive] search can use as much memory as it
/// Configure whether the [`Exhaustive`][ConvAlgorithmSearch::Exhaustive] search can use as much memory as it
/// needs.
///
/// The default is `true`. When `false`, the memory used for the search is limited to 32 MB, which will impact its
/// ability to find an optimal convolution algorithm.
///
/// ```
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default().with_conv_max_workspace(false).build();
/// let ep = ep::CUDA::default().with_conv_max_workspace(false).build();
/// # Ok(())
/// # }
/// ```
@@ -218,9 +218,9 @@ impl CUDAExecutionProvider {
/// convolution operations that do not use 3-dimensional input shapes, or the *result* of such operations.
///
/// ```
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default().with_conv1d_pad_to_nc1d(true).build();
/// let ep = ep::CUDA::default().with_conv1d_pad_to_nc1d(true).build();
/// # Ok(())
/// # }
/// ```
@@ -245,9 +245,9 @@ impl CUDAExecutionProvider {
/// Consult the [ONNX Runtime documentation on CUDA graphs](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#using-cuda-graphs-preview) for more information.
///
/// ```
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default().with_cuda_graph(true).build();
/// let ep = ep::CUDA::default().with_cuda_graph(true).build();
/// # Ok(())
/// # }
/// ```
@@ -262,9 +262,9 @@ impl CUDAExecutionProvider {
/// `SkipLayerNorm`'s strict mode trades performance for accuracy. The default is `false` (strict mode disabled).
///
/// ```
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default().with_skip_layer_norm_strict_mode(true).build();
/// let ep = ep::CUDA::default().with_skip_layer_norm_strict_mode(true).build();
/// # Ok(())
/// # }
/// ```
@@ -285,9 +285,9 @@ impl CUDAExecutionProvider {
/// `torch.backends.cuda.matmul.allow_tf32 = True` or `torch.set_float32_matmul_precision("medium")` in PyTorch.
///
/// ```
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default().with_tf32(true).build();
/// let ep = ep::CUDA::default().with_tf32(true).build();
/// # Ok(())
/// # }
/// ```
@@ -303,9 +303,9 @@ impl CUDAExecutionProvider {
/// convolution-heavy models on Tensor core-enabled GPUs may provide a significant performance improvement.
///
/// ```
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default().with_prefer_nhwc(true).build();
/// let ep = ep::CUDA::default().with_prefer_nhwc(true).build();
/// # Ok(())
/// # }
/// ```
@@ -328,16 +328,18 @@ impl CUDAExecutionProvider {
/// Configures the available backends used for `Attention` nodes.
///
/// ```
/// # use ort::{execution_providers::cuda::{CUDAExecutionProvider, CUDAAttentionBackend}, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = CUDAExecutionProvider::default()
/// .with_attention_backend(CUDAAttentionBackend::FLASH_ATTENTION | CUDAAttentionBackend::TRT_FUSED_ATTENTION)
/// let ep = ep::CUDA::default()
/// .with_attention_backend(
/// ep::cuda::AttentionBackend::FLASH_ATTENTION | ep::cuda::AttentionBackend::TRT_FUSED_ATTENTION
/// )
/// .build();
/// # Ok(())
/// # }
/// ```
#[must_use]
pub fn with_attention_backend(mut self, flags: CUDAAttentionBackend) -> Self {
pub fn with_attention_backend(mut self, flags: AttentionBackend) -> Self {
self.options.set("sdpa_kernel", flags.0.to_string());
self
}
@@ -352,7 +354,7 @@ impl CUDAExecutionProvider {
// https://github.com/microsoft/onnxruntime/blob/fe8a10caa40f64a8fbd144e7049cf5b14c65542d/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc#L17
}
impl ExecutionProvider for CUDAExecutionProvider {
impl ExecutionProvider for CUDA {
fn name(&self) -> &'static str {
"CUDAExecutionProvider"
}
@@ -435,16 +437,16 @@ pub const CUDNN_DYLIBS: &[&str] = &[
///
/// ```
/// # use std::path::Path;
/// use ort::execution_providers::cuda;
/// use ort::ep;
///
/// let cuda_root = Path::new(r#"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin"#);
/// let cudnn_root = Path::new(r#"D:\cudnn_9.8.0"#);
///
/// // Load CUDA & cuDNN
/// let _ = cuda::preload_dylibs(Some(cuda_root), Some(cudnn_root));
/// let _ = ep::cuda::preload_dylibs(Some(cuda_root), Some(cudnn_root));
///
/// // Only preload cuDNN
/// let _ = cuda::preload_dylibs(None, Some(cudnn_root));
/// let _ = ep::cuda::preload_dylibs(None, Some(cudnn_root));
/// ```
#[cfg_attr(docsrs, doc(cfg(any(feature = "preload-dylibs", feature = "load-dynamic"))))]
#[cfg(feature = "preload-dylibs")]


@@ -9,10 +9,10 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// with dynamically sized inputs, you can override individual dimensions by constructing the session with
/// [`SessionBuilder::with_dimension_override`]:
/// ```no_run
/// # use ort::{execution_providers::directml::DirectMLExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?
/// .with_execution_providers([DirectMLExecutionProvider::default().build()])?
/// .with_execution_providers([ep::DirectML::default().build()])?
/// .with_dimension_override("batch", 1)?
/// .with_dimension_override("seq_len", 512)?
/// .commit_from_file("gpt2.onnx")?;
@@ -20,19 +20,19 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// # }
/// ```
#[derive(Debug, Default, Clone)]
pub struct DirectMLExecutionProvider {
pub struct DirectML {
device_id: i32
}
super::impl_ep!(DirectMLExecutionProvider);
super::impl_ep!(DirectML);
impl DirectMLExecutionProvider {
impl DirectML {
/// Configures which device the EP should use.
///
/// ```
/// # use ort::{execution_providers::directml::DirectMLExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = DirectMLExecutionProvider::default().with_device_id(1).build();
/// let ep = ep::DirectML::default().with_device_id(1).build();
/// # Ok(())
/// # }
/// ```
@@ -43,7 +43,7 @@ impl DirectMLExecutionProvider {
}
}
impl ExecutionProvider for DirectMLExecutionProvider {
impl ExecutionProvider for DirectML {
fn name(&self) -> &'static str {
"DmlExecutionProvider"
}


@@ -6,7 +6,7 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// [MIGraphX execution provider](https://onnxruntime.ai/docs/execution-providers/MIGraphX-ExecutionProvider.html) for
/// hardware acceleration with AMD GPUs.
#[derive(Debug, Default, Clone)]
pub struct MIGraphXExecutionProvider {
pub struct MIGraphX {
device_id: i32,
enable_fp16: bool,
enable_int8: bool,
@@ -17,15 +17,15 @@ pub struct MIGraphXExecutionProvider {
exhaustive_tune: bool
}
super::impl_ep!(MIGraphXExecutionProvider);
super::impl_ep!(MIGraphX);
impl MIGraphXExecutionProvider {
impl MIGraphX {
/// Configures which device the EP should use.
///
/// ```
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = MIGraphXExecutionProvider::default().with_device_id(0).build();
/// let ep = ep::MIGraphX::default().with_device_id(0).build();
/// # Ok(())
/// # }
/// ```
@@ -38,9 +38,9 @@ impl MIGraphXExecutionProvider {
/// Enable FP16 quantization for the model.
///
/// ```
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = MIGraphXExecutionProvider::default().with_fp16(true).build();
/// let ep = ep::MIGraphX::default().with_fp16(true).build();
/// # Ok(())
/// # }
/// ```
@@ -51,15 +51,12 @@ impl MIGraphXExecutionProvider {
}
/// Enable 8-bit integer quantization for the model. Requires
/// [`MIGraphXExecutionProvider::with_int8_calibration_table`] to be set.
/// [`MIGraphX::with_int8_calibration_table`] to be set.
///
/// ```
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = MIGraphXExecutionProvider::default()
/// .with_int8(true)
/// .with_int8_calibration_table("...", false)
/// .build();
/// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build();
/// # Ok(())
/// # }
/// ```
@@ -75,12 +72,9 @@ impl MIGraphXExecutionProvider {
/// for the JSON dump format.
///
/// ```
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = MIGraphXExecutionProvider::default()
/// .with_int8(true)
/// .with_int8_calibration_table("...", false)
/// .build();
/// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build();
/// # Ok(())
/// # }
/// ```
@@ -93,12 +87,12 @@ impl MIGraphXExecutionProvider {
/// Save the compiled MIGraphX model to the given path.
///
/// The compiled model can then be loaded in subsequent runs with [`MIGraphXExecutionProvider::with_load_model`].
/// The compiled model can then be loaded in subsequent runs with [`MIGraphX::with_load_model`].
///
/// ```
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = MIGraphXExecutionProvider::default().with_save_model("./compiled_model.mxr").build();
/// let ep = ep::MIGraphX::default().with_save_model("./compiled_model.mxr").build();
/// # Ok(())
/// # }
/// ```
@@ -108,13 +102,13 @@ impl MIGraphXExecutionProvider {
self
}
/// Load the compiled MIGraphX model (previously generated by [`MIGraphXExecutionProvider::with_save_model`]) from
/// Load the compiled MIGraphX model (previously generated by [`MIGraphX::with_save_model`]) from
/// the given path.
///
/// ```
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = MIGraphXExecutionProvider::default().with_load_model("./compiled_model.mxr").build();
/// let ep = ep::MIGraphX::default().with_load_model("./compiled_model.mxr").build();
/// # Ok(())
/// # }
/// ```
@@ -127,9 +121,9 @@ impl MIGraphXExecutionProvider {
/// Enable exhaustive tuning; trades loading time for inference performance.
///
/// ```
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = MIGraphXExecutionProvider::default().with_exhaustive_tune(true).build();
/// let ep = ep::MIGraphX::default().with_exhaustive_tune(true).build();
/// # Ok(())
/// # }
/// ```
@@ -140,7 +134,7 @@ impl MIGraphXExecutionProvider {
}
}
impl ExecutionProvider for MIGraphXExecutionProvider {
impl ExecutionProvider for MIGraphX {
fn name(&self) -> &'static str {
"MIGraphXExecutionProvider"
}


@@ -3,11 +3,11 @@
//! Sessions can be configured with execution providers via [`SessionBuilder::with_execution_providers`]:
//!
//! ```no_run
//! use ort::{execution_providers::CUDAExecutionProvider, session::Session};
//! use ort::{ep, session::Session};
//!
//! fn main() -> ort::Result<()> {
//! let session = Session::builder()?
//! .with_execution_providers([CUDAExecutionProvider::default().build()])?
//! .with_execution_providers([ep::CUDA::default().build()])?
//! .commit_from_file("model.onnx")?;
//!
//! Ok(())
@@ -24,47 +24,47 @@ use core::{
use crate::{char_p_to_string, error::Result, ortsys, session::builder::SessionBuilder, util::MiniMap};
pub mod cpu;
pub use self::cpu::CPUExecutionProvider;
pub use self::cpu::CPU;
pub mod cuda;
pub use self::cuda::CUDAExecutionProvider;
pub use self::cuda::CUDA;
pub mod tensorrt;
pub use self::tensorrt::TensorRTExecutionProvider;
pub use self::tensorrt::TensorRT;
pub mod onednn;
pub use self::onednn::OneDNNExecutionProvider;
pub use self::onednn::OneDNN;
pub mod acl;
pub use self::acl::ACLExecutionProvider;
pub use self::acl::ACL;
pub mod openvino;
pub use self::openvino::OpenVINOExecutionProvider;
pub use self::openvino::OpenVINO;
pub mod coreml;
pub use self::coreml::CoreMLExecutionProvider;
pub use self::coreml::CoreML;
pub mod rocm;
pub use self::rocm::ROCmExecutionProvider;
pub use self::rocm::ROCm;
pub mod cann;
pub use self::cann::CANNExecutionProvider;
pub use self::cann::CANN;
pub mod directml;
pub use self::directml::DirectMLExecutionProvider;
pub use self::directml::DirectML;
pub mod tvm;
pub use self::tvm::TVMExecutionProvider;
pub use self::tvm::TVM;
pub mod nnapi;
pub use self::nnapi::NNAPIExecutionProvider;
pub use self::nnapi::NNAPI;
pub mod qnn;
pub use self::qnn::QNNExecutionProvider;
pub use self::qnn::QNN;
pub mod xnnpack;
pub use self::xnnpack::XNNPACKExecutionProvider;
pub use self::xnnpack::XNNPACK;
pub mod armnn;
pub use self::armnn::ArmNNExecutionProvider;
pub use self::armnn::ArmNN;
pub mod migraphx;
pub use self::migraphx::MIGraphXExecutionProvider;
pub use self::migraphx::MIGraphX;
pub mod vitis;
pub use self::vitis::VitisAIExecutionProvider;
pub use self::vitis::Vitis;
pub mod rknpu;
pub use self::rknpu::RKNPUExecutionProvider;
pub use self::rknpu::RKNPU;
pub mod webgpu;
pub use self::webgpu::WebGPUExecutionProvider;
pub use self::webgpu::WebGPU;
pub mod azure;
pub use self::azure::AzureExecutionProvider;
pub use self::azure::Azure;
pub mod nvrtx;
pub use self::nvrtx::NVRTXExecutionProvider;
pub use self::nvrtx::NVRTX;
#[cfg(target_arch = "wasm32")]
pub mod wasm;
#[cfg(target_arch = "wasm32")]
@@ -82,14 +82,14 @@ pub trait ExecutionProvider: Send + Sync {
/// Returns the identifier of this execution provider used internally by ONNX Runtime.
///
/// This is the same as what's used in ONNX Runtime's Python API to register this execution provider, i.e.
/// [`TVMExecutionProvider`]'s identifier is `TvmExecutionProvider`.
/// [`TVM`]'s identifier is `TvmExecutionProvider`.
fn name(&self) -> &'static str;
/// Returns whether this execution provider is supported on this platform.
///
/// For example, the CoreML execution provider implements this as:
/// ```ignore
/// impl ExecutionProvider for CoreMLExecutionProvider {
/// impl ExecutionProvider for CoreML {
/// fn supported_by_platform() -> bool {
/// cfg!(target_vendor = "apple")
/// }
@@ -321,9 +321,9 @@ pub(crate) use define_ep_register;
macro_rules! impl_ep {
(arbitrary; $symbol:ident) => {
$crate::execution_providers::impl_ep!($symbol);
$crate::ep::impl_ep!($symbol);
impl $crate::execution_providers::ArbitrarilyConfigurableExecutionProvider for $symbol {
impl $crate::ep::ArbitrarilyConfigurableExecutionProvider for $symbol {
fn with_arbitrary_config(mut self, key: impl ::alloc::string::ToString, value: impl ::alloc::string::ToString) -> Self {
self.options.set(key.to_string(), value.to_string());
self
@@ -333,14 +333,14 @@ macro_rules! impl_ep {
($symbol:ident) => {
impl $symbol {
#[must_use]
pub fn build(self) -> $crate::execution_providers::ExecutionProviderDispatch {
pub fn build(self) -> $crate::ep::ExecutionProviderDispatch {
self.into()
}
}
impl From<$symbol> for $crate::execution_providers::ExecutionProviderDispatch {
impl From<$symbol> for $crate::ep::ExecutionProviderDispatch {
fn from(value: $symbol) -> Self {
$crate::execution_providers::ExecutionProviderDispatch::new(value)
$crate::ep::ExecutionProviderDispatch::new(value)
}
}
};
@@ -381,3 +381,75 @@ pub(crate) fn apply_execution_providers(session_builder: &mut SessionBuilder, ep
}
Ok(())
}
#[deprecated = "import `ort::ep::ACL` instead"]
#[doc(hidden)]
pub use self::acl::ACL as ACLExecutionProvider;
#[deprecated = "import `ort::ep::ArmNN` instead"]
#[doc(hidden)]
pub use self::armnn::ArmNN as ArmNNExecutionProvider;
#[deprecated = "import `ort::ep::Azure` instead"]
#[doc(hidden)]
pub use self::azure::Azure as AzureExecutionProvider;
#[deprecated = "import `ort::ep::CANN` instead"]
#[doc(hidden)]
pub use self::cann::CANN as CANNExecutionProvider;
#[deprecated = "import `ort::ep::CoreML` instead"]
#[doc(hidden)]
pub use self::coreml::CoreML as CoreMLExecutionProvider;
#[deprecated = "import `ort::ep::CPU` instead"]
#[doc(hidden)]
pub use self::cpu::CPU as CPUExecutionProvider;
#[deprecated = "import `ort::ep::CUDA` instead"]
#[doc(hidden)]
pub use self::cuda::CUDA as CUDAExecutionProvider;
#[deprecated = "import `ort::ep::DirectML` instead"]
#[doc(hidden)]
pub use self::directml::DirectML as DirectMLExecutionProvider;
#[deprecated = "import `ort::ep::MIGraphX` instead"]
#[doc(hidden)]
pub use self::migraphx::MIGraphX as MIGraphXExecutionProvider;
#[deprecated = "import `ort::ep::NNAPI` instead"]
#[doc(hidden)]
pub use self::nnapi::NNAPI as NNAPIExecutionProvider;
#[deprecated = "import `ort::ep::NVRTX` instead"]
#[doc(hidden)]
pub use self::nvrtx::NVRTX as NVRTXExecutionProvider;
#[deprecated = "import `ort::ep::OneDNN` instead"]
#[doc(hidden)]
pub use self::onednn::OneDNN as OneDNNExecutionProvider;
#[deprecated = "import `ort::ep::OpenVINO` instead"]
#[doc(hidden)]
pub use self::openvino::OpenVINO as OpenVINOExecutionProvider;
#[deprecated = "import `ort::ep::QNN` instead"]
#[doc(hidden)]
pub use self::qnn::QNN as QNNExecutionProvider;
#[deprecated = "import `ort::ep::RKNPU` instead"]
#[doc(hidden)]
pub use self::rknpu::RKNPU as RKNPUExecutionProvider;
#[deprecated = "import `ort::ep::ROCm` instead"]
#[doc(hidden)]
pub use self::rocm::ROCm as ROCmExecutionProvider;
#[deprecated = "import `ort::ep::TensorRT` instead"]
#[doc(hidden)]
pub use self::tensorrt::TensorRT as TensorRTExecutionProvider;
#[deprecated = "import `ort::ep::TVM` instead"]
#[doc(hidden)]
pub use self::tvm::TVM as TVMExecutionProvider;
#[deprecated = "import `ort::ep::Vitis` instead"]
#[doc(hidden)]
pub use self::vitis::Vitis as VitisAIExecutionProvider;
#[deprecated = "import `ort::ep::WASM` instead"]
#[doc(hidden)]
#[cfg(target_arch = "wasm32")]
pub use self::wasm::WASM as WASMExecutionProvider;
#[deprecated = "import `ort::ep::WebGPU` instead"]
#[doc(hidden)]
pub use self::webgpu::WebGPU as WebGPUExecutionProvider;
#[deprecated = "import `ort::ep::WebNN` instead"]
#[doc(hidden)]
#[cfg(target_arch = "wasm32")]
pub use self::webnn::WebNN as WebNNExecutionProvider;
#[deprecated = "import `ort::ep::XNNPACK` instead"]
#[doc(hidden)]
pub use self::xnnpack::XNNPACK as XNNPACKExecutionProvider;
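
The deprecated re-exports above keep the old `*ExecutionProvider` import paths compiling while steering users toward the shorter `ep` names. As a standalone sketch of that pattern (with a hypothetical stand-in unit struct, not ort's real API), a `#[deprecated]` + `#[doc(hidden)]` `pub use` alias behaves like this:

```rust
// Standalone sketch of the deprecation-shim pattern used above.
// `CUDA` here is a hypothetical stand-in, not ort's real type.
mod ep {
    #[derive(Debug, Default)]
    pub struct CUDA;

    impl CUDA {
        pub fn build(self) -> String {
            format!("{:?}", self)
        }
    }

    // The old name stays importable, but using it warns at compile time
    // and the alias is hidden from generated documentation.
    #[deprecated = "import `ep::CUDA` instead"]
    #[doc(hidden)]
    pub use self::CUDA as CUDAExecutionProvider;
}

fn main() {
    // New, shorter path:
    assert_eq!(ep::CUDA::default().build(), "CUDA");

    // Old path still compiles (with a deprecation warning):
    #[allow(deprecated)]
    let legacy = ep::CUDAExecutionProvider::default().build();
    assert_eq!(legacy, "CUDA");
}
```

Because the alias names the exact same type, downstream code continues to work unchanged until the old path is removed in a future release.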


@@ -2,16 +2,16 @@ use super::{ExecutionProvider, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Default, Clone)]
pub struct NNAPIExecutionProvider {
pub struct NNAPI {
use_fp16: bool,
use_nchw: bool,
disable_cpu: bool,
cpu_only: bool
}
super::impl_ep!(NNAPIExecutionProvider);
super::impl_ep!(NNAPI);
impl NNAPIExecutionProvider {
impl NNAPI {
/// Use fp16 relaxation in the NNAPI EP. This may improve performance but can also reduce accuracy due to the lower
/// precision.
#[must_use]
@@ -49,7 +49,7 @@ impl NNAPIExecutionProvider {
}
}
impl ExecutionProvider for NNAPIExecutionProvider {
impl ExecutionProvider for NNAPI {
fn name(&self) -> &'static str {
"NnapiExecutionProvider"
}


@@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Default, Clone)]
pub struct NVRTXExecutionProvider {
pub struct NVRTX {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; NVRTXExecutionProvider);
super::impl_ep!(arbitrary; NVRTX);
impl NVRTXExecutionProvider {
impl NVRTX {
pub fn with_device_id(mut self, device_id: u32) -> Self {
self.options.set("ep.nvtensorrtrtxexecutionprovider.device_id", device_id.to_string());
self
@@ -23,7 +23,7 @@ impl NVRTXExecutionProvider {
}
}
impl ExecutionProvider for NVRTXExecutionProvider {
impl ExecutionProvider for NVRTX {
fn name(&self) -> &'static str {
"NvTensorRTRTXExecutionProvider"
}


@@ -4,20 +4,20 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// [oneDNN/DNNL execution provider](https://onnxruntime.ai/docs/execution-providers/oneDNN-ExecutionProvider.html) for
/// Intel CPUs & iGPUs.
#[derive(Debug, Default, Clone)]
#[doc(alias = "DNNLExecutionProvider")]
pub struct OneDNNExecutionProvider {
#[doc(alias = "DNNL")]
pub struct OneDNN {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; OneDNNExecutionProvider);
super::impl_ep!(arbitrary; OneDNN);
impl OneDNNExecutionProvider {
impl OneDNN {
/// Enable/disable the usage of the arena allocator.
///
/// ```
/// # use ort::{execution_providers::onednn::OneDNNExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = OneDNNExecutionProvider::default().with_arena_allocator(true).build();
/// let ep = ep::OneDNN::default().with_arena_allocator(true).build();
/// # Ok(())
/// # }
/// ```
@@ -28,7 +28,7 @@ impl OneDNNExecutionProvider {
}
}
impl ExecutionProvider for OneDNNExecutionProvider {
impl ExecutionProvider for OneDNN {
fn name(&self) -> &'static str {
"DnnlExecutionProvider"
}


@@ -6,22 +6,22 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// [OpenVINO execution provider](https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html) for
/// Intel CPUs/GPUs/NPUs.
#[derive(Default, Debug, Clone)]
pub struct OpenVINOExecutionProvider {
pub struct OpenVINO {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; OpenVINOExecutionProvider);
super::impl_ep!(arbitrary; OpenVINO);
impl OpenVINOExecutionProvider {
impl OpenVINO {
/// Overrides the accelerator hardware type and precision.
///
/// `device_type` should be in the format `CPU`, `NPU`, `GPU`, `GPU.0`, `GPU.1`, etc. Heterogeneous combinations are
/// also supported in the format `HETERO:NPU,GPU`.
///
/// ```
/// # use ort::{execution_providers::openvino::OpenVINOExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = OpenVINOExecutionProvider::default().with_device_type("GPU.0").build();
/// let ep = ep::OpenVINO::default().with_device_type("GPU.0").build();
/// # Ok(())
/// # }
/// ```
@@ -81,7 +81,7 @@ impl OpenVINOExecutionProvider {
}
}
impl ExecutionProvider for OpenVINOExecutionProvider {
impl ExecutionProvider for OpenVINO {
fn name(&self) -> &'static str {
"OpenVINOExecutionProvider"
}


@@ -4,7 +4,7 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QNNPerformanceMode {
pub enum PerformanceMode {
Default,
Burst,
Balanced,
@@ -17,43 +17,43 @@ pub enum QNNPerformanceMode {
SustainedHighPerformance
}
impl QNNPerformanceMode {
impl PerformanceMode {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
QNNPerformanceMode::Default => "default",
QNNPerformanceMode::Burst => "burst",
QNNPerformanceMode::Balanced => "balanced",
QNNPerformanceMode::HighPerformance => "high_performance",
QNNPerformanceMode::HighPowerSaver => "high_power_saver",
QNNPerformanceMode::LowPowerSaver => "low_power_saver",
QNNPerformanceMode::LowBalanced => "low_balanced",
QNNPerformanceMode::PowerSaver => "power_saver",
QNNPerformanceMode::ExtremePowerSaver => "extreme_power_saver",
QNNPerformanceMode::SustainedHighPerformance => "sustained_high_performance"
PerformanceMode::Default => "default",
PerformanceMode::Burst => "burst",
PerformanceMode::Balanced => "balanced",
PerformanceMode::HighPerformance => "high_performance",
PerformanceMode::HighPowerSaver => "high_power_saver",
PerformanceMode::LowPowerSaver => "low_power_saver",
PerformanceMode::LowBalanced => "low_balanced",
PerformanceMode::PowerSaver => "power_saver",
PerformanceMode::ExtremePowerSaver => "extreme_power_saver",
PerformanceMode::SustainedHighPerformance => "sustained_high_performance"
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QNNProfilingLevel {
pub enum ProfilingLevel {
Off,
Basic,
Detailed
}
impl QNNProfilingLevel {
impl ProfilingLevel {
pub fn as_str(&self) -> &'static str {
match self {
QNNProfilingLevel::Off => "off",
QNNProfilingLevel::Basic => "basic",
QNNProfilingLevel::Detailed => "detailed"
ProfilingLevel::Off => "off",
ProfilingLevel::Basic => "basic",
ProfilingLevel::Detailed => "detailed"
}
}
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum QNNContextPriority {
pub enum ContextPriority {
Low,
#[default]
Normal,
@@ -61,25 +61,25 @@ pub enum QNNContextPriority {
High
}
impl QNNContextPriority {
impl ContextPriority {
pub fn as_str(&self) -> &'static str {
match self {
QNNContextPriority::Low => "low",
QNNContextPriority::Normal => "normal",
QNNContextPriority::NormalHigh => "normal_high",
QNNContextPriority::High => "high"
ContextPriority::Low => "low",
ContextPriority::Normal => "normal",
ContextPriority::NormalHigh => "normal_high",
ContextPriority::High => "high"
}
}
}
#[derive(Debug, Default, Clone)]
pub struct QNNExecutionProvider {
pub struct QNN {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; QNNExecutionProvider);
super::impl_ep!(arbitrary; QNN);
impl QNNExecutionProvider {
impl QNN {
/// The file path to the QNN backend library. On Linux/Android, this is `libQnnCpu.so` to use the CPU backend,
/// or `libQnnHtp.so` to use the accelerated backend.
#[must_use]
@@ -89,7 +89,7 @@ impl QNNExecutionProvider {
}
#[must_use]
pub fn with_profiling(mut self, level: QNNProfilingLevel) -> Self {
pub fn with_profiling(mut self, level: ProfilingLevel) -> Self {
self.options.set("profiling_level", level.as_str());
self
}
@@ -114,7 +114,7 @@ impl QNNExecutionProvider {
}
#[must_use]
pub fn with_performance_mode(mut self, mode: QNNPerformanceMode) -> Self {
pub fn with_performance_mode(mut self, mode: PerformanceMode) -> Self {
self.options.set("htp_performance_mode", mode.as_str());
self
}
@@ -126,7 +126,7 @@ impl QNNExecutionProvider {
}
#[must_use]
pub fn with_context_priority(mut self, priority: QNNContextPriority) -> Self {
pub fn with_context_priority(mut self, priority: ContextPriority) -> Self {
self.options.set("qnn_context_priority", priority.as_str());
self
}
@@ -174,7 +174,7 @@ impl QNNExecutionProvider {
}
}
impl ExecutionProvider for QNNExecutionProvider {
impl ExecutionProvider for QNN {
fn name(&self) -> &'static str {
"QNNExecutionProvider"
}
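The hunks above drop the `QNN` prefix from `QNNPerformanceMode`, `QNNProfilingLevel`, and `QNNContextPriority`. A minimal standalone sketch of why the prefix becomes redundant once the enum lives in a backend-specific module (toy types here, not ort's actual API):

```rust
// Module-scoped enums avoid path stutter: callers write
// `qnn::PerformanceMode`, never `qnn::QNNPerformanceMode`.
mod qnn {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum PerformanceMode {
        Default,
        Burst
    }

    impl PerformanceMode {
        // Same string-mapping shape as the renamed `as_str` impls above.
        pub fn as_str(&self) -> &'static str {
            match self {
                PerformanceMode::Default => "default",
                PerformanceMode::Burst => "burst"
            }
        }
    }
}

fn main() {
    let mode = qnn::PerformanceMode::Burst;
    assert_eq!(mode.as_str(), "burst");
    println!("{mode:?} -> {}", mode.as_str());
}
```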

View File

@@ -2,11 +2,11 @@ use super::{ExecutionProvider, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Default, Clone)]
pub struct RKNPUExecutionProvider {}
pub struct RKNPU {}
super::impl_ep!(RKNPUExecutionProvider);
super::impl_ep!(RKNPU);
impl ExecutionProvider for RKNPUExecutionProvider {
impl ExecutionProvider for RKNPU {
fn name(&self) -> &'static str {
"RknpuExecutionProvider"
}

View File

@@ -5,13 +5,13 @@ use super::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderOptions, Re
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Default, Clone)]
pub struct ROCmExecutionProvider {
pub struct ROCm {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; ROCmExecutionProvider);
super::impl_ep!(arbitrary; ROCm);
impl ROCmExecutionProvider {
impl ROCm {
#[must_use]
pub fn with_device_id(mut self, device_id: i32) -> Self {
self.options.set("device_id", device_id.to_string());
@@ -86,7 +86,7 @@ impl ROCmExecutionProvider {
}
}
impl ExecutionProvider for ROCmExecutionProvider {
impl ExecutionProvider for ROCm {
fn name(&self) -> &'static str {
"ROCMExecutionProvider"
}

View File

@@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Default, Clone)]
pub struct TensorRTExecutionProvider {
pub struct TensorRT {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; TensorRTExecutionProvider);
super::impl_ep!(arbitrary; TensorRT);
impl TensorRTExecutionProvider {
impl TensorRT {
#[must_use]
pub fn with_device_id(mut self, device_id: i32) -> Self {
self.options.set("device_id", device_id.to_string());
@@ -254,7 +254,7 @@ impl TensorRTExecutionProvider {
}
}
impl ExecutionProvider for TensorRTExecutionProvider {
impl ExecutionProvider for TensorRT {
fn name(&self) -> &'static str {
"TensorrtExecutionProvider"
}

View File

@@ -4,22 +4,22 @@ use super::{ExecutionProvider, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TVMExecutorType {
pub enum ExecutorType {
GraphExecutor,
VirtualMachine
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TVMTuningType {
pub enum TuningType {
AutoTVM,
Ansor
}
#[derive(Debug, Default, Clone)]
pub struct TVMExecutionProvider {
pub struct TVM {
/// Executor type used by TVM. There is a choice between two types, `GraphExecutor` and `VirtualMachine`. Default is
/// [`TVMExecutorType::VirtualMachine`].
pub executor: Option<TVMExecutorType>,
/// [`ExecutorType::VirtualMachine`].
pub executor: Option<ExecutorType>,
/// Path to folder with set of files (`.ro-`, `.so`/`.dll`-files and weights) obtained after model tuning.
pub so_folder: Option<String>,
/// Whether or not to perform a hash check on the model obtained in the `so_folder`.
@@ -34,7 +34,7 @@ pub struct TVMExecutionProvider {
/// `true` is recommended for best performance and is the default.
pub freeze_weights: Option<bool>,
pub to_nhwc: Option<bool>,
pub tuning_type: Option<TVMTuningType>,
pub tuning_type: Option<TuningType>,
/// Path to AutoTVM or Ansor tuning file which gives specifications for given model and target for the best
/// performance.
pub tuning_file_path: Option<String>,
@@ -42,9 +42,9 @@ pub struct TVMExecutionProvider {
pub input_shapes: Option<String>
}
super::impl_ep!(TVMExecutionProvider);
super::impl_ep!(TVM);
impl ExecutionProvider for TVMExecutionProvider {
impl ExecutionProvider for TVM {
fn name(&self) -> &'static str {
"TvmExecutionProvider"
}
@@ -66,8 +66,8 @@ impl ExecutionProvider for TVMExecutionProvider {
option_string.push(format!(
"executor:{}",
match executor {
TVMExecutorType::GraphExecutor => "graph",
TVMExecutorType::VirtualMachine => "vm"
ExecutorType::GraphExecutor => "graph",
ExecutorType::VirtualMachine => "vm"
}
));
}

View File

@@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Default, Clone)]
pub struct VitisAIExecutionProvider {
pub struct Vitis {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; VitisAIExecutionProvider);
super::impl_ep!(arbitrary; Vitis);
impl VitisAIExecutionProvider {
impl Vitis {
pub fn with_config_file(mut self, config_file: impl ToString) -> Self {
self.options.set("config_file", config_file.to_string());
self
@@ -27,7 +27,7 @@ impl VitisAIExecutionProvider {
}
}
impl ExecutionProvider for VitisAIExecutionProvider {
impl ExecutionProvider for Vitis {
fn name(&self) -> &'static str {
"VitisAIExecutionProvider"
}

View File

@@ -2,13 +2,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder};
#[derive(Debug, Default, Clone)]
pub struct WASMExecutionProvider {
pub struct WASM {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; WASMExecutionProvider);
super::impl_ep!(arbitrary; WASM);
impl ExecutionProvider for WASMExecutionProvider {
impl ExecutionProvider for WASM {
fn name(&self) -> &'static str {
"WASMExecutionProvider"
}

View File

@@ -4,84 +4,84 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebGPUPreferredLayout {
pub enum PreferredLayout {
NCHW,
NHWC
}
impl WebGPUPreferredLayout {
impl PreferredLayout {
pub(crate) fn as_str(&self) -> &'static str {
match self {
WebGPUPreferredLayout::NCHW => "NCHW",
WebGPUPreferredLayout::NHWC => "NHWC"
PreferredLayout::NCHW => "NCHW",
PreferredLayout::NHWC => "NHWC"
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebGPUDawnBackendType {
pub enum DawnBackendType {
Vulkan,
D3D12
}
impl WebGPUDawnBackendType {
impl DawnBackendType {
pub(crate) fn as_str(&self) -> &'static str {
match self {
WebGPUDawnBackendType::Vulkan => "Vulkan",
WebGPUDawnBackendType::D3D12 => "D3D12"
DawnBackendType::Vulkan => "Vulkan",
DawnBackendType::D3D12 => "D3D12"
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebGPUBufferCacheMode {
pub enum BufferCacheMode {
Disabled,
LazyRelease,
Simple,
Bucket
}
impl WebGPUBufferCacheMode {
impl BufferCacheMode {
pub(crate) fn as_str(&self) -> &'static str {
match self {
WebGPUBufferCacheMode::Disabled => "disabled",
WebGPUBufferCacheMode::LazyRelease => "lazyRelease",
WebGPUBufferCacheMode::Simple => "simple",
WebGPUBufferCacheMode::Bucket => "bucket"
BufferCacheMode::Disabled => "disabled",
BufferCacheMode::LazyRelease => "lazyRelease",
BufferCacheMode::Simple => "simple",
BufferCacheMode::Bucket => "bucket"
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebGPUValidationMode {
pub enum ValidationMode {
Disabled,
WgpuOnly,
Basic,
Full
}
impl WebGPUValidationMode {
impl ValidationMode {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
WebGPUValidationMode::Disabled => "disabled",
WebGPUValidationMode::WgpuOnly => "wgpuOnly",
WebGPUValidationMode::Basic => "basic",
WebGPUValidationMode::Full => "full"
ValidationMode::Disabled => "disabled",
ValidationMode::WgpuOnly => "wgpuOnly",
ValidationMode::Basic => "basic",
ValidationMode::Full => "full"
}
}
}
#[derive(Debug, Default, Clone)]
pub struct WebGPUExecutionProvider {
pub struct WebGPU {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; WebGPUExecutionProvider);
super::impl_ep!(arbitrary; WebGPU);
impl WebGPUExecutionProvider {
impl WebGPU {
#[must_use]
pub fn with_preferred_layout(mut self, layout: WebGPUPreferredLayout) -> Self {
pub fn with_preferred_layout(mut self, layout: PreferredLayout) -> Self {
self.options.set("ep.webgpuexecutionprovider.preferredLayout", layout.as_str());
self
}
@@ -100,7 +100,7 @@ impl WebGPUExecutionProvider {
}
#[must_use]
pub fn with_dawn_backend_type(mut self, backend_type: WebGPUDawnBackendType) -> Self {
pub fn with_dawn_backend_type(mut self, backend_type: DawnBackendType) -> Self {
self.options.set("ep.webgpuexecutionprovider.dawnBackendType", backend_type.as_str());
self
}
@@ -112,31 +112,31 @@ impl WebGPUExecutionProvider {
}
#[must_use]
pub fn with_storage_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self {
pub fn with_storage_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
self.options.set("ep.webgpuexecutionprovider.storageBufferCacheMode", mode.as_str());
self
}
#[must_use]
pub fn with_uniform_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self {
pub fn with_uniform_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
self.options.set("ep.webgpuexecutionprovider.uniformBufferCacheMode", mode.as_str());
self
}
#[must_use]
pub fn with_query_resolve_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self {
pub fn with_query_resolve_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
self.options.set("ep.webgpuexecutionprovider.queryResolveBufferCacheMode", mode.as_str());
self
}
#[must_use]
pub fn with_default_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self {
pub fn with_default_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
self.options.set("ep.webgpuexecutionprovider.defaultBufferCacheMode", mode.as_str());
self
}
#[must_use]
pub fn with_validation_mode(mut self, mode: WebGPUValidationMode) -> Self {
pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self {
self.options.set("ep.webgpuexecutionprovider.validationMode", mode.as_str());
self
}
@@ -155,7 +155,7 @@ impl WebGPUExecutionProvider {
}
}
impl ExecutionProvider for WebGPUExecutionProvider {
impl ExecutionProvider for WebGPU {
fn name(&self) -> &'static str {
"WebGpuExecutionProvider"
}

View File

@@ -4,56 +4,56 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder};
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebNNPowerPreference {
pub enum PowerPreference {
#[default]
Default,
HighPerformance,
LowPower
}
impl WebNNPowerPreference {
impl PowerPreference {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
WebNNPowerPreference::Default => "default",
WebNNPowerPreference::HighPerformance => "high-performance",
WebNNPowerPreference::LowPower => "low-power"
PowerPreference::Default => "default",
PowerPreference::HighPerformance => "high-performance",
PowerPreference::LowPower => "low-power"
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebNNDeviceType {
pub enum DeviceType {
CPU,
GPU,
NPU
}
impl WebNNDeviceType {
impl DeviceType {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
WebNNDeviceType::CPU => "cpu",
WebNNDeviceType::GPU => "gpu",
WebNNDeviceType::NPU => "npu"
DeviceType::CPU => "cpu",
DeviceType::GPU => "gpu",
DeviceType::NPU => "npu"
}
}
}
#[derive(Debug, Default, Clone)]
pub struct WebNNExecutionProvider {
pub struct WebNN {
options: ExecutionProviderOptions
}
impl WebNNExecutionProvider {
impl WebNN {
#[must_use]
pub fn with_device_type(mut self, device_type: WebNNDeviceType) -> Self {
pub fn with_device_type(mut self, device_type: DeviceType) -> Self {
self.options.set("deviceType", device_type.as_str());
self
}
#[must_use]
pub fn with_power_preference(mut self, pref: WebNNPowerPreference) -> Self {
pub fn with_power_preference(mut self, pref: PowerPreference) -> Self {
self.options.set("powerPreference", pref.as_str());
self
}
@@ -65,9 +65,9 @@ impl WebNNExecutionProvider {
}
}
super::impl_ep!(arbitrary; WebNNExecutionProvider);
super::impl_ep!(arbitrary; WebNN);
impl ExecutionProvider for WebNNExecutionProvider {
impl ExecutionProvider for WebNN {
fn name(&self) -> &'static str {
"WebNNExecutionProvider"
}
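The `with_device_type`/`with_power_preference` setters above all follow the same by-value builder pattern over a string-keyed options map. A self-contained sketch of that pattern (simplified; `Options` below stands in for ort's internal `ExecutionProviderOptions` and is not its real API):

```rust
use std::collections::BTreeMap;

// Stand-in for ort's internal string-keyed EP options store.
#[derive(Default, Debug)]
struct Options(BTreeMap<&'static str, String>);

impl Options {
    fn set(&mut self, key: &'static str, value: impl Into<String>) {
        self.0.insert(key, value.into());
    }
}

#[derive(Default, Debug)]
struct WebNN {
    options: Options
}

impl WebNN {
    // `with_*` methods take `self` by value and return `Self`, so
    // configuration calls chain fluently off `WebNN::default()`.
    fn with_device_type(mut self, device_type: &str) -> Self {
        self.options.set("deviceType", device_type);
        self
    }

    fn with_power_preference(mut self, pref: &str) -> Self {
        self.options.set("powerPreference", pref);
        self
    }
}

fn main() {
    let ep = WebNN::default().with_device_type("gpu").with_power_preference("low-power");
    assert_eq!(ep.options.0.get("deviceType").map(String::as_str), Some("gpu"));
    assert_eq!(ep.options.0.get("powerPreference").map(String::as_str), Some("low-power"));
    println!("{:?}", ep.options);
}
```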

View File

@@ -13,12 +13,12 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// disable the session intra-op threadpool to reduce contention:
/// ```no_run
/// # use core::num::NonZeroUsize;
/// # use ort::{execution_providers::xnnpack::XNNPACKExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?
/// .with_intra_op_spinning(false)?
/// .with_intra_threads(1)?
/// .with_execution_providers([XNNPACKExecutionProvider::default()
/// .with_execution_providers([ep::XNNPACK::default()
/// .with_intra_op_num_threads(NonZeroUsize::new(4).unwrap())
/// .build()])?
/// .commit_from_file("model.onnx")?;
@@ -26,22 +26,20 @@ use crate::{error::Result, session::builder::SessionBuilder};
/// # }
/// ```
#[derive(Debug, Default, Clone)]
pub struct XNNPACKExecutionProvider {
pub struct XNNPACK {
options: ExecutionProviderOptions
}
super::impl_ep!(arbitrary; XNNPACKExecutionProvider);
super::impl_ep!(arbitrary; XNNPACK);
impl XNNPACKExecutionProvider {
impl XNNPACK {
/// Configures the number of threads to use for XNNPACK's internal intra-op threadpool.
///
/// ```
/// # use core::num::NonZeroUsize;
/// # use ort::{execution_providers::xnnpack::XNNPACKExecutionProvider, session::Session};
/// # use ort::{ep, session::Session};
/// # fn main() -> ort::Result<()> {
/// let ep = XNNPACKExecutionProvider::default()
/// .with_intra_op_num_threads(NonZeroUsize::new(4).unwrap())
/// .build();
/// let ep = ep::XNNPACK::default().with_intra_op_num_threads(NonZeroUsize::new(4).unwrap()).build();
/// # Ok(())
/// # }
/// ```
@@ -52,7 +50,7 @@ impl XNNPACKExecutionProvider {
}
}
impl ExecutionProvider for XNNPACKExecutionProvider {
impl ExecutionProvider for XNNPACK {
fn name(&self) -> &'static str {
"XnnpackExecutionProvider"
}

View File

@@ -37,7 +37,7 @@ use crate::{
///
/// ```no_run
/// # use ort::{
/// # execution_providers::CUDAExecutionProvider,
/// # ep,
/// # io_binding::IoBinding,
/// # memory::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType},
/// # session::Session,
@@ -45,10 +45,10 @@ use crate::{
/// # };
/// # fn main() -> ort::Result<()> {
/// let mut text_encoder = Session::builder()?
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
/// .with_execution_providers([ep::CUDA::default().build()])?
/// .commit_from_file("text_encoder.onnx")?;
/// let mut unet = Session::builder()?
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
/// .with_execution_providers([ep::CUDA::default().build()])?
/// .commit_from_file("unet.onnx")?;
///
/// let text_condition = text_encoder

View File

@@ -3,6 +3,7 @@
#![allow(clippy::tabs_in_doc_comments, clippy::arc_with_non_send_sync)]
#![allow(clippy::macro_metavars_in_unsafe)]
#![warn(clippy::unwrap_used)]
#![deny(clippy::std_instead_of_alloc, clippy::std_instead_of_core)]
#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
//! <div align=center>
@@ -28,8 +29,8 @@ pub mod adapter;
pub mod compiler;
pub mod editor;
pub mod environment;
pub mod ep;
pub mod error;
pub mod execution_providers;
pub mod io_binding;
pub mod logging;
pub mod memory;
@@ -48,6 +49,13 @@ pub mod api {
pub use super::{api as ort, compiler::compile_api as compile, editor::editor_api as editor};
}
#[deprecated = "import execution providers from `ort::ep` instead"]
#[doc(hidden)]
pub mod execution_providers {
#[deprecated = "import execution providers from `ort::ep` instead"]
pub use super::ep::*;
}
use alloc::{borrow::ToOwned, boxed::Box, string::String};
use core::{
ffi::{CStr, c_char},
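The deprecated `execution_providers` shim added above keeps old import paths compiling (with a warning) while steering users toward `ort::ep`. A standalone sketch of the same re-export technique, using toy types rather than ort's real items:

```rust
// New, shorter home for the types.
mod ep {
    #[derive(Default, Debug)]
    pub struct CUDA; // stand-in for the real builder struct

    impl CUDA {
        pub fn name(&self) -> &'static str {
            "CUDAExecutionProvider"
        }
    }
}

// Old module path: a deprecated shim that just re-exports the new module,
// so existing `use crate::execution_providers::*` imports keep working.
#[deprecated = "import execution providers from `ep` instead"]
mod execution_providers {
    pub use super::ep::*;
}

fn main() {
    // New path:
    let new_style = ep::CUDA::default();
    // Old path still compiles, but trips the `deprecated` lint:
    #[allow(deprecated)]
    let old_style = execution_providers::CUDA::default();
    assert_eq!(new_style.name(), old_style.name());
    println!("both paths resolve to {}", new_style.name());
}
```

Because `deprecated` is a warn-by-default lint, downstream crates get a migration nudge without a breaking change.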

View File

@@ -36,7 +36,8 @@ pub trait Operator: Send {
/// Returns the name of the operator.
fn name(&self) -> &str;
/// Returns the execution provider this operator runs on, e.g. `CUDAExecutionProvider`.
/// Returns the internal name of the execution provider this operator runs on, e.g. `"CUDAExecutionProvider"` (see
/// [`ExecutionProvider::name`](crate::ep::ExecutionProvider::name)).
///
/// If the returned type is not `None`, and the execution provider used by the session matches this operator's
/// EP type, the value will not be copied to the CPU and you may use functions like [`Tensor::data_ptr`] to

View File

@@ -6,7 +6,6 @@ use core::{
ffi::c_void,
marker::PhantomData,
mem::replace,
ops::Deref,
ptr::{self, NonNull}
};
#[cfg(feature = "std")]
@@ -23,8 +22,8 @@ use crate::error::{Error, ErrorCode};
use crate::util::OsCharArray;
use crate::{
AsPointer,
ep::apply_execution_providers,
error::Result,
execution_providers::apply_execution_providers,
memory::Allocator,
ortsys,
session::{InMemorySession, Input, Output, Session, SharedSessionInner, dangerous}
@@ -136,7 +135,7 @@ impl SessionBuilder {
}
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
fn commit_from_file_inner(mut self, model_path: &<OsCharArray as Deref>::Target) -> Result<Session> {
fn commit_from_file_inner(mut self, model_path: &<OsCharArray as core::ops::Deref>::Target) -> Result<Session> {
self.pre_commit()?;
let session_ptr = if let Some(prepacked_weights) = self.prepacked_weights.as_ref() {

View File

@@ -13,8 +13,8 @@ use crate::util::path_to_os_char;
use crate::{
AsPointer, Error, ErrorCode,
environment::{self, ThreadManager},
ep::{ExecutionProviderDispatch, apply_execution_providers},
error::Result,
execution_providers::{ExecutionProviderDispatch, apply_execution_providers},
logging::{LogLevel, LoggerFunction},
memory::MemoryInfo,
operator::OperatorDomain,
@@ -34,8 +34,8 @@ impl SessionBuilder {
/// ## Notes
///
/// - **Indiscriminate use of [`SessionBuilder::with_execution_providers`] in a library** (e.g. always enabling
/// `CUDAExecutionProvider`) **is discouraged** unless you allow the user to configure the execution providers by
/// providing a `Vec` of [`ExecutionProviderDispatch`]es.
/// CUDA) **is discouraged** unless you allow the user to configure the execution providers by providing a `Vec`
/// of [`ExecutionProviderDispatch`]es.
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Result<Self> {
apply_execution_providers(&mut self, execution_providers.as_ref(), "session options")?;
Ok(self)

View File

@@ -36,8 +36,8 @@ pub(crate) use self::{
/// absolute path, then any subsequent requests to load a library called `foo` will use the `libfoo.so` we already
/// loaded, instead of searching the system for a `foo` library.
///
/// See also [`crate::execution_providers::cuda::preload_dylibs`], a helper that uses `preload_dylib` to load all
/// required dependencies of the [CUDA execution provider](crate::execution_providers::CUDAExecutionProvider).
/// See also [`crate::ep::cuda::preload_dylibs`], a helper that uses `preload_dylib` to load all required dependencies
/// of the [CUDA execution provider](crate::ep::CUDA).
///
/// ```
/// use std::env;

View File

@@ -6,7 +6,7 @@ use core::ops::{Deref, DerefMut};
use super::DefiniteTensorValueTypeMarker;
use crate::{
Error, OnceLock, Result, execution_providers as ep,
Error, OnceLock, Result, ep,
io_binding::IoBinding,
memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType},
session::{NoSelectedOutputs, RunOptions, Session, builder::GraphOptimizationLevel},
@@ -36,20 +36,20 @@ static IDENTITY_RUN_OPTIONS: OnceLock<RunOptions<NoSelectedOutputs>> = OnceLock:
fn ep_for_device(device: AllocationDevice, device_id: i32) -> Result<ep::ExecutionProviderDispatch> {
Ok(match device {
AllocationDevice::CPU => ep::CPUExecutionProvider::default().with_arena_allocator(false).build(),
AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDAExecutionProvider::default()
AllocationDevice::CPU => ep::CPU::default().with_arena_allocator(false).build(),
AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDA::default()
.with_device_id(device_id)
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
.with_conv_max_workspace(false)
.with_conv_algorithm_search(ep::cuda::CuDNNConvAlgorithmSearch::Default)
.with_conv_algorithm_search(ep::cuda::ConvAlgorithmSearch::Default)
.build(),
AllocationDevice::DIRECTML => ep::DirectMLExecutionProvider::default().with_device_id(device_id).build(),
AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANNExecutionProvider::default()
AllocationDevice::DIRECTML => ep::DirectML::default().with_device_id(device_id).build(),
AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANN::default()
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
.with_cann_graph(false)
.with_device_id(device_id)
.build(),
AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => ep::OpenVINOExecutionProvider::default()
AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => ep::OpenVINO::default()
.with_num_threads(1)
.with_device_type(if device == AllocationDevice::OPENVINO_CPU {
"CPU".to_string()
@@ -57,7 +57,7 @@ fn ep_for_device(device: AllocationDevice, device_id: i32) -> Result<ep::Executi
format!("GPU.{device_id}")
})
.build(),
AllocationDevice::HIP | AllocationDevice::HIP_PINNED => ep::ROCmExecutionProvider::default()
AllocationDevice::HIP | AllocationDevice::HIP_PINNED => ep::ROCm::default()
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
.with_hip_graph(false)
.with_exhaustive_conv_search(false)
@@ -388,13 +388,13 @@ mod tests {
#[cfg(feature = "cuda")]
fn test_copy_into_cuda() -> crate::Result<()> {
use crate::{
execution_providers::CUDAExecutionProvider,
ep,
memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType},
session::Session
};
let dummy_session = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])?
.with_execution_providers([ep::CUDA::default().build()])?
.commit_from_file("tests/data/upsample.ort")?;
let allocator = Allocator::new(&dummy_session, MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?)?;
@@ -412,13 +412,13 @@ mod tests {
#[cfg(feature = "cuda")]
fn test_copy_into_async_cuda() -> crate::Result<()> {
use crate::{
execution_providers::CUDAExecutionProvider,
ep,
memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType},
session::Session
};
let dummy_session = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])?
.with_execution_providers([ep::CUDA::default().build()])?
.commit_from_file("tests/data/upsample.ort")?;
let allocator = Allocator::new(&dummy_session, MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?)?;

View File

@@ -213,7 +213,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> {
/// ```
///
/// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be
/// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout.
/// returned. See [`ndarray::ArrayRef::as_standard_layout`] to convert an array to a contiguous layout.
pub fn from_array_view(input: impl TensorArrayData<T> + 'a) -> Result<TensorRef<'a, T>> {
let (shape, data, guard) = input.ref_parts()?;
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::<T>(), T::into_tensor_element_type(), guard).map(|tensor| {
@@ -252,7 +252,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
/// ```
///
/// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be
/// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout.
/// returned. See [`ndarray::ArrayRef::as_standard_layout`] to convert an array to a contiguous layout.
pub fn from_array_view_mut(mut input: impl TensorArrayDataMut<T>) -> Result<TensorRefMut<'a, T>> {
let (shape, data, guard) = input.ref_parts_mut()?;
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::<T>(), T::into_tensor_element_type(), guard).map(|tensor| {

View File

@@ -1,6 +1,6 @@
use ort::{
adapter::Adapter,
execution_providers::CPUExecutionProvider,
ep,
memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType},
operator::{
Operator, OperatorDomain,
@@ -77,7 +77,7 @@ impl Operator for CustomOpTwo {
}
fn main() -> ort::Result<()> {
ort::init().with_execution_providers([CPUExecutionProvider::default().build()]).commit();
ort::init().with_execution_providers([ep::CPU::default().build()]).commit();
let mut session = Session::builder()?
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?