mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor: shorten execution_providers
frankly, working on documentation has made me tired of typing out `execution_providers` and `ExecutionProvider` all the time
This commit is contained in:
@@ -12,7 +12,7 @@ To configure an `Environment`, you instead use the `ort::init` function, which r
|
||||
|
||||
```rust
|
||||
ort::init()
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
.with_execution_providers([ep::CUDA::default().build()])
|
||||
.commit();
|
||||
```
|
||||
|
||||
@@ -156,7 +156,7 @@ let l = outputs["latents"].try_extract_array::<f32>()?;
|
||||
```
|
||||
|
||||
## Execution providers
|
||||
Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/execution_providers/index.html#reexports) and the [execution providers reference](/perf/execution-providers) for more information.
|
||||
Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/ep/index.html#reexports) and the [execution providers reference](/perf/execution-providers) for more information.
|
||||
|
||||
```diff
|
||||
-// v1.x
|
||||
@@ -165,7 +165,7 @@ Execution provider structs with public fields have been replaced with builder pa
|
||||
-}))?;
|
||||
+// v2
|
||||
+builder = builder.with_execution_providers([
|
||||
+ DirectMLExecutionProvider::default()
|
||||
+ ep::DirectML::default()
|
||||
+ .with_device_id(1)
|
||||
+ .build()
|
||||
+])?;
|
||||
|
||||
@@ -52,11 +52,11 @@ ONNX Runtime must be compiled from source with support for each execution provid
|
||||
In order to configure sessions to use certain execution providers, you must **register** them when creating an environment or session. You can do this via the `SessionBuilder::with_execution_providers` method. For example, to register the CUDA execution provider for a session:
|
||||
|
||||
```rust
|
||||
use ort::{execution_providers::CUDAExecutionProvider, session::Session};
|
||||
use ort::{ep::CUDA, session::Session};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
.with_execution_providers([CUDA::default().build()])?
|
||||
.commit_from_file("model.onnx")?;
|
||||
|
||||
Ok(())
|
||||
@@ -66,21 +66,18 @@ fn main() -> anyhow::Result<()> {
|
||||
You can, of course, specify multiple execution providers. `ort` will register all EPs specified, in order. If an EP does not support a certain operator in a graph, it will fall back to the next successfully registered EP, or to the CPU if all else fails.
|
||||
|
||||
```rust
|
||||
use ort::{
|
||||
execution_providers::{CoreMLExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider, TensorRTExecutionProvider},
|
||||
session::Session
|
||||
};
|
||||
use ort::{ep, session::Session};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_execution_providers([
|
||||
// Prefer TensorRT over CUDA.
|
||||
TensorRTExecutionProvider::default().build(),
|
||||
CUDAExecutionProvider::default().build(),
|
||||
ep::TensorRT::default().build(),
|
||||
ep::CUDA::default().build(),
|
||||
// Use DirectML on Windows if NVIDIA EPs are not available
|
||||
DirectMLExecutionProvider::default().build(),
|
||||
ep::DirectML::default().build(),
|
||||
// Or use ANE on Apple platforms
|
||||
CoreMLExecutionProvider::default().build()
|
||||
ep::CoreML::default().build()
|
||||
])?
|
||||
.commit_from_file("model.onnx")?;
|
||||
|
||||
@@ -89,18 +86,18 @@ fn main() -> anyhow::Result<()> {
|
||||
```
|
||||
|
||||
## Configuring EPs
|
||||
EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/execution_providers/index.html#reexports) for the EP structs for more information on which options are supported and what they do.
|
||||
EPs have configuration options to control behavior or increase performance. Each execution provider struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/ep/index.html#reexports) for the EP structs for more information on which options are supported and what they do.
|
||||
|
||||
```rust
|
||||
use ort::{execution_providers::{coreml::CoreMLComputeUnits, CoreMLExecutionProvider}, session::Session};
|
||||
use ort::{ep, session::Session};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_execution_providers([
|
||||
CoreMLExecutionProvider::default()
|
||||
ep::CoreML::default()
|
||||
// this model uses control flow operators, so enable CoreML on subgraphs too
|
||||
.with_subgraphs()
|
||||
.with_compute_units(CoreMLComputeUnits::CPUAndNeuralEngine)
|
||||
.with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
|
||||
.build()
|
||||
])?
|
||||
.commit_from_file("model.onnx")?;
|
||||
@@ -114,12 +111,12 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
You can configure an EP to return an error on failure by adding `.error_on_failure()` after you `.build()` it. In this example, if CUDA doesn't register successfully, the program will exit with an error at `with_execution_providers`:
|
||||
```rust
|
||||
use ort::{execution_providers::CoreMLExecutionProvider, session::Session};
|
||||
use ort::{ep, session::Session};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_execution_providers([
|
||||
CUDAExecutionProvider::default().build().error_on_failure()
|
||||
ep::CUDA::default().build().error_on_failure()
|
||||
])?
|
||||
.commit_from_file("model.onnx")?;
|
||||
|
||||
@@ -131,14 +128,14 @@ If you require more complex error handling, you can also manually register execu
|
||||
|
||||
```rust
|
||||
use ort::{
|
||||
execution_providers::{CUDAExecutionProvider, ExecutionProvider},
|
||||
ep::{self, ExecutionProvider},
|
||||
session::Session
|
||||
};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let builder = Session::builder()?;
|
||||
|
||||
let cuda = CUDAExecutionProvider::default();
|
||||
let cuda = ep::CUDA::default();
|
||||
if cuda.register(&builder).is_err() {
|
||||
eprintln!("Failed to register CUDA!");
|
||||
std::process::exit(1);
|
||||
@@ -154,14 +151,14 @@ You can also check whether ONNX Runtime is even compiled with support for the ex
|
||||
|
||||
```rust
|
||||
use ort::{
|
||||
execution_providers::{CoreMLExecutionProvider, ExecutionProvider},
|
||||
ep::{self, ExecutionProvider},
|
||||
session::Session
|
||||
};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let builder = Session::builder()?;
|
||||
|
||||
let coreml = CoreMLExecutionProvider::default();
|
||||
let coreml = ep::CoreML::default();
|
||||
if !coreml.is_available() {
|
||||
eprintln!("Please compile ONNX Runtime with CoreML!");
|
||||
std::process::exit(1);
|
||||
@@ -180,11 +177,11 @@ fn main() -> anyhow::Result<()> {
|
||||
You can configure `ort` to attempt to register a list of execution providers for all sessions created in an environment.
|
||||
|
||||
```rust
|
||||
use ort::{execution_providers::CUDAExecutionProvider, session::Session};
|
||||
use ort::{ep, session::Session};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
ort::init()
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
.with_execution_providers([ep::CUDA::default().build()])
|
||||
.commit();
|
||||
|
||||
let session = Session::builder()?.commit_from_file("model.onnx")?;
|
||||
|
||||
@@ -80,10 +80,10 @@ Here is a more complete example of the I/O binding API in a scenario where I/O p
|
||||
|
||||
```rs
|
||||
let mut text_encoder = Session::builder()?
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
.with_execution_providers([ep::CUDA::default().build()])?
|
||||
.commit_from_file("text_encoder.onnx")?;
|
||||
let mut unet = Session::builder()?
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
.with_execution_providers([ep::CUDA::default().build()])?
|
||||
.commit_from_file("unet.onnx")?;
|
||||
|
||||
let text_condition = {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#[allow(unused)]
|
||||
use ort::execution_providers::*;
|
||||
use ort::ep::*;
|
||||
|
||||
pub fn init() -> ort::Result<()> {
|
||||
#[cfg(feature = "backend-candle")]
|
||||
@@ -11,41 +11,41 @@ pub fn init() -> ort::Result<()> {
|
||||
ort::init()
|
||||
.with_execution_providers([
|
||||
#[cfg(feature = "tensorrt")]
|
||||
TensorRTExecutionProvider::default().build(),
|
||||
TensorRT::default().build(),
|
||||
#[cfg(feature = "cuda")]
|
||||
CUDAExecutionProvider::default().build(),
|
||||
CUDA::default().build(),
|
||||
#[cfg(feature = "onednn")]
|
||||
OneDNNExecutionProvider::default().build(),
|
||||
OneDNN::default().build(),
|
||||
#[cfg(feature = "acl")]
|
||||
ACLExecutionProvider::default().build(),
|
||||
ACL::default().build(),
|
||||
#[cfg(feature = "openvino")]
|
||||
OpenVINOExecutionProvider::default().build(),
|
||||
OpenVINO::default().build(),
|
||||
#[cfg(feature = "coreml")]
|
||||
CoreMLExecutionProvider::default().build(),
|
||||
CoreML::default().build(),
|
||||
#[cfg(feature = "rocm")]
|
||||
ROCmExecutionProvider::default().build(),
|
||||
ROCm::default().build(),
|
||||
#[cfg(feature = "cann")]
|
||||
CANNExecutionProvider::default().build(),
|
||||
CANN::default().build(),
|
||||
#[cfg(feature = "directml")]
|
||||
DirectMLExecutionProvider::default().build(),
|
||||
DirectML::default().build(),
|
||||
#[cfg(feature = "tvm")]
|
||||
TVMExecutionProvider::default().build(),
|
||||
TVM::default().build(),
|
||||
#[cfg(feature = "nnapi")]
|
||||
NNAPIExecutionProvider::default().build(),
|
||||
NNAPI::default().build(),
|
||||
#[cfg(feature = "qnn")]
|
||||
QNNExecutionProvider::default().build(),
|
||||
QNN::default().build(),
|
||||
#[cfg(feature = "xnnpack")]
|
||||
XNNPACKExecutionProvider::default().build(),
|
||||
XNNPACK::default().build(),
|
||||
#[cfg(feature = "armnn")]
|
||||
ArmNNExecutionProvider::default().build(),
|
||||
ArmNN::default().build(),
|
||||
#[cfg(feature = "migraphx")]
|
||||
MIGraphXExecutionProvider::default().build(),
|
||||
MIGraphX::default().build(),
|
||||
#[cfg(feature = "vitis")]
|
||||
VitisAIExecutionProvider::default().build(),
|
||||
VitisAI::default().build(),
|
||||
#[cfg(feature = "rknpu")]
|
||||
RKNPUExecutionProvider::default().build(),
|
||||
RKNPU::default().build(),
|
||||
#[cfg(feature = "webgpu")]
|
||||
WebGPUExecutionProvider::default().build()
|
||||
WebGPU::default().build()
|
||||
])
|
||||
.commit();
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ use cudarc::driver::{CudaDevice, DevicePtr};
|
||||
use image::{GenericImageView, ImageBuffer, Rgba, imageops::FilterType};
|
||||
use ndarray::Array;
|
||||
use ort::{
|
||||
execution_providers::CUDAExecutionProvider,
|
||||
ep,
|
||||
memory::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType},
|
||||
session::Session,
|
||||
tensor::Shape,
|
||||
@@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> {
|
||||
#[rustfmt::skip]
|
||||
ort::init()
|
||||
.with_execution_providers([
|
||||
CUDAExecutionProvider::default()
|
||||
ep::CUDA::default()
|
||||
.build()
|
||||
// exit the program with an error if the CUDA EP fails to register
|
||||
.error_on_failure()
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::{
|
||||
|
||||
use kdam::BarExt;
|
||||
use ort::{
|
||||
execution_providers::CUDAExecutionProvider,
|
||||
ep,
|
||||
memory::Allocator,
|
||||
session::{Session, builder::SessionBuilder},
|
||||
training::{CheckpointStrategy, Trainer, TrainerCallbacks, TrainerControl, TrainerState, TrainingArguments},
|
||||
@@ -56,7 +56,7 @@ fn main() -> ort::Result<()> {
|
||||
common::init()?;
|
||||
|
||||
let trainer = Trainer::new_from_artifacts(
|
||||
SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
|
||||
SessionBuilder::new()?.with_execution_providers([ep::CUDA::default().build()])?,
|
||||
Allocator::default(),
|
||||
"tools/train-data/mini-clm",
|
||||
None
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::{
|
||||
|
||||
use kdam::BarExt;
|
||||
use ort::{
|
||||
execution_providers::CUDAExecutionProvider,
|
||||
ep,
|
||||
memory::Allocator,
|
||||
session::{Session, builder::SessionBuilder},
|
||||
training::{Checkpoint, Trainer},
|
||||
@@ -38,7 +38,7 @@ fn main() -> ort::Result<()> {
|
||||
let _ = kdam::term::hide_cursor();
|
||||
|
||||
let trainer = Trainer::new(
|
||||
SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
|
||||
SessionBuilder::new()?.with_execution_providers([ep::CUDA::default().build()])?,
|
||||
Allocator::default(),
|
||||
Checkpoint::load("tools/train-data/mini-clm/checkpoint")?,
|
||||
"tools/train-data/mini-clm/training_model.onnx",
|
||||
|
||||
@@ -57,8 +57,8 @@ pub extern "C" fn detect_objects(ptr: *const u8, width: u32, height: u32) {
|
||||
|
||||
let use_webgpu = true; // TODO: Make `use_webgpu` a parameter of `detect_objects`? Or say in README to change it here.
|
||||
if use_webgpu {
|
||||
use ort::execution_providers::ExecutionProvider;
|
||||
let ep = ort::execution_providers::WebGPUExecutionProvider::default();
|
||||
use ort::ep::ExecutionProvider;
|
||||
let ep = ort::ep::WebGPUExecutionProvider::default();
|
||||
if ep.is_available().expect("Cannot check for availability of WebGPU ep.") {
|
||||
ep.register(&mut builder).expect("Cannot register WebGPU ep.");
|
||||
} else {
|
||||
|
||||
@@ -63,14 +63,14 @@ impl Adapter {
|
||||
/// ```
|
||||
/// # use ort::{
|
||||
/// # adapter::Adapter,
|
||||
/// # execution_providers::CUDAExecutionProvider,
|
||||
/// # ep,
|
||||
/// # memory::DeviceType,
|
||||
/// # session::{run_options::RunOptions, Session},
|
||||
/// # value::Tensor
|
||||
/// # };
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mut model = Session::builder()?
|
||||
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
/// .with_execution_providers([ep::CUDA::default().build()])?
|
||||
/// .commit_from_file("tests/data/lora_model.onnx")?;
|
||||
///
|
||||
/// let allocator = model.allocator();
|
||||
@@ -108,14 +108,14 @@ impl Adapter {
|
||||
/// ```
|
||||
/// # use ort::{
|
||||
/// # adapter::Adapter,
|
||||
/// # execution_providers::CUDAExecutionProvider,
|
||||
/// # ep,
|
||||
/// # memory::DeviceType,
|
||||
/// # session::{run_options::RunOptions, Session},
|
||||
/// # value::Tensor
|
||||
/// # };
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mut model = Session::builder()?
|
||||
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
/// .with_execution_providers([ep::CUDA::default().build()])?
|
||||
/// .commit_from_file("tests/data/lora_model.onnx")?;
|
||||
///
|
||||
/// let bytes = std::fs::read("tests/data/adapter.orl").unwrap();
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
//!
|
||||
//! Environments can be configured via [`ort::init`](init):
|
||||
//! ```
|
||||
//! # use ort::execution_providers::CUDAExecutionProvider;
|
||||
//! # use ort::ep;
|
||||
//! # fn main() -> ort::Result<()> {
|
||||
//! ort::init().with_execution_providers([CUDAExecutionProvider::default().build()]).commit();
|
||||
//! ort::init().with_execution_providers([ep::CUDA::default().build()]).commit();
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
@@ -25,8 +25,8 @@ use smallvec::SmallVec;
|
||||
|
||||
use crate::{
|
||||
AsPointer,
|
||||
ep::ExecutionProviderDispatch,
|
||||
error::Result,
|
||||
execution_providers::ExecutionProviderDispatch,
|
||||
logging::{LogLevel, LoggerFunction},
|
||||
ortsys,
|
||||
util::{Mutex, OnceLock, STACK_EXECUTION_PROVIDERS, with_cstr}
|
||||
@@ -430,9 +430,9 @@ impl EnvironmentBuilder {
|
||||
/// Creates an ONNX Runtime environment.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::execution_providers::CUDAExecutionProvider;
|
||||
/// # use ort::ep;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// ort::init().with_execution_providers([CUDAExecutionProvider::default().build()]).commit();
|
||||
/// ort::init().with_execution_providers([ep::CUDA::default().build()]).commit();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -456,11 +456,11 @@ pub fn init() -> EnvironmentBuilder {
|
||||
/// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use ort::execution_providers::CUDAExecutionProvider;
|
||||
/// # use ort::ep;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib");
|
||||
/// ort::init_from(lib_path.join("onnxruntime.dll"))?
|
||||
/// .with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
/// .with_execution_providers([ep::CUDA::default().build()])
|
||||
/// .commit();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
|
||||
@@ -4,20 +4,20 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// [Arm Compute Library execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/ACL-ExecutionProvider.html)
|
||||
/// for ARM platforms.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct ACLExecutionProvider {
|
||||
pub struct ACL {
|
||||
fast_math: bool
|
||||
}
|
||||
|
||||
super::impl_ep!(ACLExecutionProvider);
|
||||
super::impl_ep!(ACL);
|
||||
|
||||
impl ACLExecutionProvider {
|
||||
impl ACL {
|
||||
/// Enable/disable ACL's fast math mode. Enabling can improve performance at the cost of some accuracy for
|
||||
/// `MatMul`/`Conv` nodes.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::acl::ACLExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = ACLExecutionProvider::default().with_fast_math(true).build();
|
||||
/// let ep = ep::ACL::default().with_fast_math(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -28,7 +28,7 @@ impl ACLExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for ACLExecutionProvider {
|
||||
impl ExecutionProvider for ACL {
|
||||
fn name(&self) -> &'static str {
|
||||
"ACLExecutionProvider"
|
||||
}
|
||||
@@ -4,19 +4,19 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// [Arm NN execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/ArmNN-ExecutionProvider.html)
|
||||
/// for ARM platforms.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct ArmNNExecutionProvider {
|
||||
pub struct ArmNN {
|
||||
use_arena: bool
|
||||
}
|
||||
|
||||
super::impl_ep!(ArmNNExecutionProvider);
|
||||
super::impl_ep!(ArmNN);
|
||||
|
||||
impl ArmNNExecutionProvider {
|
||||
impl ArmNN {
|
||||
/// Enable/disable the usage of the arena allocator.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::armnn::ArmNNExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = ArmNNExecutionProvider::default().with_arena_allocator(true).build();
|
||||
/// let ep = ep::ArmNN::default().with_arena_allocator(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -27,7 +27,7 @@ impl ArmNNExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for ArmNNExecutionProvider {
|
||||
impl ExecutionProvider for ArmNN {
|
||||
fn name(&self) -> &'static str {
|
||||
"ArmNNExecutionProvider"
|
||||
}
|
||||
@@ -35,12 +35,12 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
///
|
||||
/// To use this model in `ort`:
|
||||
/// ```no_run
|
||||
/// # use ort::{execution_providers::azure::AzureExecutionProvider, session::Session, value::Tensor};
|
||||
/// # use ort::{ep, session::Session, value::Tensor};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mut session = Session::builder()?
|
||||
/// // note: session must be initialized with `onnxruntime-extensions`
|
||||
/// .with_extensions()?
|
||||
/// .with_execution_providers([AzureExecutionProvider::default().build()])?
|
||||
/// .with_execution_providers([ep::Azure::default().build()])?
|
||||
/// .commit_from_file("azure_chat.onnx")?;
|
||||
///
|
||||
/// let auth_token = Tensor::from_string_array(([1], &*vec!["..."]))?;
|
||||
@@ -65,13 +65,13 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct AzureExecutionProvider {
|
||||
pub struct Azure {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; AzureExecutionProvider);
|
||||
super::impl_ep!(arbitrary; Azure);
|
||||
|
||||
impl ExecutionProvider for AzureExecutionProvider {
|
||||
impl ExecutionProvider for Azure {
|
||||
fn name(&self) -> &'static str {
|
||||
"AzureExecutionProvider"
|
||||
}
|
||||
@@ -5,7 +5,7 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum CANNPrecisionMode {
|
||||
pub enum PrecisionMode {
|
||||
/// Convert to float32 first according to operator implementation
|
||||
ForceFP32,
|
||||
/// Convert to float16 when float16 and float32 are both supported
|
||||
@@ -20,7 +20,7 @@ pub enum CANNPrecisionMode {
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum CANNImplementationMode {
|
||||
pub enum ImplementationMode {
|
||||
/// Prefer high precision, potentially at the cost of some performance.
|
||||
HighPrecision,
|
||||
/// Prefer high performance, potentially with lower accuracy.
|
||||
@@ -30,19 +30,19 @@ pub enum CANNImplementationMode {
|
||||
/// [CANN execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/CANN-ExecutionProvider.html)
|
||||
/// for hardware acceleration using Huawei Ascend AI devices.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct CANNExecutionProvider {
|
||||
pub struct CANN {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; CANNExecutionProvider);
|
||||
super::impl_ep!(arbitrary; CANN);
|
||||
|
||||
impl CANNExecutionProvider {
|
||||
impl CANN {
|
||||
/// Configures which device the EP should use.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CANNExecutionProvider::default().with_device_id(0).build();
|
||||
/// let ep = ep::CANN::default().with_device_id(0).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -56,9 +56,9 @@ impl CANNExecutionProvider {
|
||||
/// provider’s arena; the total device memory usage may be higher.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CANNExecutionProvider::default().with_memory_limit(2 * 1024 * 1024 * 1024).build();
|
||||
/// let ep = ep::CANN::default().with_memory_limit(2 * 1024 * 1024 * 1024).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -71,9 +71,9 @@ impl CANNExecutionProvider {
|
||||
/// Configure the strategy for extending the device's memory arena.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::{cann::CANNExecutionProvider, ArenaExtendStrategy}, session::Session};
|
||||
/// # use ort::{ep::{self, ArenaExtendStrategy}, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CANNExecutionProvider::default()
|
||||
/// let ep = ep::CANN::default()
|
||||
/// .with_arena_extend_strategy(ArenaExtendStrategy::SameAsRequested)
|
||||
/// .build();
|
||||
/// # Ok(())
|
||||
@@ -95,9 +95,9 @@ impl CANNExecutionProvider {
|
||||
/// is `true`. If `false`, it will fall back to the single-operator inference engine.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CANNExecutionProvider::default().with_cann_graph(true).build();
|
||||
/// let ep = ep::CANN::default().with_cann_graph(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -110,9 +110,9 @@ impl CANNExecutionProvider {
|
||||
/// Configure whether to dump the subgraph into ONNX format for analysis of subgraph segmentation.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CANNExecutionProvider::default().with_dump_graphs(true).build();
|
||||
/// let ep = ep::CANN::default().with_dump_graphs(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -125,9 +125,9 @@ impl CANNExecutionProvider {
|
||||
/// Configure whether to dump the offline model to an `.om` file.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CANNExecutionProvider::default().with_dump_om_model(true).build();
|
||||
/// let ep = ep::CANN::default().with_dump_om_model(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -137,25 +137,25 @@ impl CANNExecutionProvider {
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure the precision mode; see [`CANNPrecisionMode`].
|
||||
/// Configure the precision mode; see [`PrecisionMode`].
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cann::{CANNExecutionProvider, CANNPrecisionMode}, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CANNExecutionProvider::default().with_precision_mode(CANNPrecisionMode::ForceFP16).build();
|
||||
/// let ep = ep::CANN::default().with_precision_mode(ep::cann::PrecisionMode::ForceFP16).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn with_precision_mode(mut self, mode: CANNPrecisionMode) -> Self {
|
||||
pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self {
|
||||
self.options.set(
|
||||
"precision_mode",
|
||||
match mode {
|
||||
CANNPrecisionMode::ForceFP32 => "force_fp32",
|
||||
CANNPrecisionMode::ForceFP16 => "force_fp16",
|
||||
CANNPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
|
||||
CANNPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
|
||||
CANNPrecisionMode::AllowMixedPrecision => "allow_mix_precision"
|
||||
PrecisionMode::ForceFP32 => "force_fp32",
|
||||
PrecisionMode::ForceFP16 => "force_fp16",
|
||||
PrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
|
||||
PrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
|
||||
PrecisionMode::AllowMixedPrecision => "allow_mix_precision"
|
||||
}
|
||||
);
|
||||
self
|
||||
@@ -165,33 +165,33 @@ impl CANNExecutionProvider {
|
||||
/// high-performance implementations.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cann::{CANNExecutionProvider, CANNImplementationMode}, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CANNExecutionProvider::default()
|
||||
/// .with_implementation_mode(CANNImplementationMode::HighPerformance)
|
||||
/// let ep = ep::CANN::default()
|
||||
/// .with_implementation_mode(ep::cann::ImplementationMode::HighPerformance)
|
||||
/// .build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn with_implementation_mode(mut self, mode: CANNImplementationMode) -> Self {
|
||||
pub fn with_implementation_mode(mut self, mode: ImplementationMode) -> Self {
|
||||
self.options.set(
|
||||
"op_select_impl_mode",
|
||||
match mode {
|
||||
CANNImplementationMode::HighPrecision => "high_precision",
|
||||
CANNImplementationMode::HighPerformance => "high_performance"
|
||||
ImplementationMode::HighPrecision => "high_precision",
|
||||
ImplementationMode::HighPerformance => "high_performance"
|
||||
}
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure the list of operators which use the mode specified by
|
||||
/// [`CANNExecutionProvider::with_implementation_mode`].
|
||||
/// [`CANN::with_implementation_mode`].
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CANNExecutionProvider::default().with_implementation_mode_oplist("LayerNorm,Gelu").build();
|
||||
/// let ep = ep::CANN::default().with_implementation_mode_oplist("LayerNorm,Gelu").build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -202,7 +202,7 @@ impl CANNExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for CANNExecutionProvider {
|
||||
impl ExecutionProvider for CANN {
|
||||
fn name(&self) -> &'static str {
|
||||
"CANNExecutionProvider"
|
||||
}
|
||||
@@ -4,7 +4,7 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CoreMLSpecializationStrategy {
|
||||
pub enum SpecializationStrategy {
|
||||
/// The strategy that should work well for most applications.
|
||||
Default,
|
||||
/// Prefer the prediction latency at the potential cost of specialization time, memory footprint, and the disk space
|
||||
@@ -12,7 +12,7 @@ pub enum CoreMLSpecializationStrategy {
|
||||
FastPrediction
|
||||
}
|
||||
|
||||
impl CoreMLSpecializationStrategy {
|
||||
impl SpecializationStrategy {
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Default => "Default",
|
||||
@@ -22,7 +22,7 @@ impl CoreMLSpecializationStrategy {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CoreMLComputeUnits {
|
||||
pub enum ComputeUnits {
|
||||
/// Enable CoreML EP for all compatible Apple devices.
|
||||
All,
|
||||
/// Enable CoreML EP for Apple devices with a compatible Neural Engine (ANE).
|
||||
@@ -33,7 +33,7 @@ pub enum CoreMLComputeUnits {
|
||||
CPUOnly
|
||||
}
|
||||
|
||||
impl CoreMLComputeUnits {
|
||||
impl ComputeUnits {
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::All => "ALL",
|
||||
@@ -45,14 +45,14 @@ impl CoreMLComputeUnits {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CoreMLModelFormat {
|
||||
pub enum ModelFormat {
|
||||
/// Requires Core ML 5 or later (iOS 15+ or macOS 12+).
|
||||
MLProgram,
|
||||
/// Default; requires Core ML 3 or later (iOS 13+ or macOS 10.15+).
|
||||
NeuralNetwork
|
||||
}
|
||||
|
||||
impl CoreMLModelFormat {
|
||||
impl ModelFormat {
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::MLProgram => "MLProgram",
|
||||
@@ -64,20 +64,20 @@ impl CoreMLModelFormat {
|
||||
/// [CoreML execution provider](https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html) for hardware
|
||||
/// acceleration on Apple devices.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct CoreMLExecutionProvider {
|
||||
pub struct CoreML {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; CoreMLExecutionProvider);
|
||||
super::impl_ep!(arbitrary; CoreML);
|
||||
|
||||
impl CoreMLExecutionProvider {
|
||||
impl CoreML {
|
||||
/// Enable CoreML EP to run on a subgraph in the body of a control flow operator (i.e. a `Loop`, `Scan` or `If`
|
||||
/// operator).
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CoreMLExecutionProvider::default().with_subgraphs(true).build();
|
||||
/// let ep = ep::CoreML::default().with_subgraphs(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -91,9 +91,9 @@ impl CoreMLExecutionProvider {
|
||||
/// allow inputs with dynamic shapes, however performance may be negatively impacted by inputs with dynamic shapes.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CoreMLExecutionProvider::default().with_static_input_shapes(true).build();
|
||||
/// let ep = ep::CoreML::default().with_static_input_shapes(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -105,19 +105,19 @@ impl CoreMLExecutionProvider {
|
||||
|
||||
/// Configures the format of the CoreML model created by the EP.
|
||||
///
|
||||
/// The default format, [NeuralNetwork](`CoreMLModelFormat::NeuralNetwork`), has better compatibility with older
|
||||
/// versions of macOS/iOS. The newer [MLProgram](`CoreMLModelFormat::MLProgram`) format supports more operators,
|
||||
/// The default format, [NeuralNetwork](`ModelFormat::NeuralNetwork`), has better compatibility with older
|
||||
/// versions of macOS/iOS. The newer [MLProgram](`ModelFormat::MLProgram`) format supports more operators,
|
||||
/// and may be more performant.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::coreml::{CoreMLExecutionProvider, CoreMLModelFormat}, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CoreMLExecutionProvider::default().with_model_format(CoreMLModelFormat::MLProgram).build();
|
||||
/// let ep = ep::CoreML::default().with_model_format(ep::coreml::ModelFormat::MLProgram).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn with_model_format(mut self, model_format: CoreMLModelFormat) -> Self {
|
||||
pub fn with_model_format(mut self, model_format: ModelFormat) -> Self {
|
||||
self.options.set("ModelFormat", model_format.as_str());
|
||||
self
|
||||
}
|
||||
@@ -129,14 +129,16 @@ impl CoreMLExecutionProvider {
|
||||
/// model for faster prediction, at the potential cost of session load time and memory footprint.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::coreml::{CoreMLExecutionProvider, CoreMLSpecializationStrategy}, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CoreMLExecutionProvider::default().with_specialization_strategy(CoreMLSpecializationStrategy::FastPrediction).build();
|
||||
/// let ep = ep::CoreML::default()
|
||||
/// .with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction)
|
||||
/// .build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn with_specialization_strategy(mut self, strategy: CoreMLSpecializationStrategy) -> Self {
|
||||
pub fn with_specialization_strategy(mut self, strategy: SpecializationStrategy) -> Self {
|
||||
self.options.set("SpecializationStrategy", strategy.as_str());
|
||||
self
|
||||
}
|
||||
@@ -144,16 +146,16 @@ impl CoreMLExecutionProvider {
|
||||
/// Configures what hardware can be used by CoreML for acceleration.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::coreml::{CoreMLExecutionProvider, CoreMLComputeUnits}, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CoreMLExecutionProvider::default()
|
||||
/// .with_compute_units(CoreMLComputeUnits::CPUAndNeuralEngine)
|
||||
/// let ep = ep::CoreML::default()
|
||||
/// .with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
|
||||
/// .build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn with_compute_units(mut self, units: CoreMLComputeUnits) -> Self {
|
||||
pub fn with_compute_units(mut self, units: ComputeUnits) -> Self {
|
||||
self.options.set("MLComputeUnits", units.as_str());
|
||||
self
|
||||
}
|
||||
@@ -162,9 +164,9 @@ impl CoreMLExecutionProvider {
|
||||
/// for debugging unexpected performance with CoreML.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CoreMLExecutionProvider::default().with_profile_compute_plan(true).build();
|
||||
/// let ep = ep::CoreML::default().with_profile_compute_plan(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -177,9 +179,9 @@ impl CoreMLExecutionProvider {
|
||||
/// Configures whether to allow low-precision (fp16) accumulation on GPU.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CoreMLExecutionProvider::default().with_low_precision_accumulation_on_gpu(true).build();
|
||||
/// let ep = ep::CoreML::default().with_low_precision_accumulation_on_gpu(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -195,9 +197,9 @@ impl CoreMLExecutionProvider {
|
||||
/// session. Setting this option allows the compiled model to be reused across session loads.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CoreMLExecutionProvider::default().with_model_cache_dir("/path/to/cache").build();
|
||||
/// let ep = ep::CoreML::default().with_model_cache_dir("/path/to/cache").build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -235,7 +237,7 @@ impl CoreMLExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for CoreMLExecutionProvider {
|
||||
impl ExecutionProvider for CoreML {
|
||||
fn name(&self) -> &'static str {
|
||||
"CoreMLExecutionProvider"
|
||||
}
|
||||
@@ -3,19 +3,19 @@ use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder};
|
||||
|
||||
/// The default CPU execution provider, powered by MLAS.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct CPUExecutionProvider {
|
||||
pub struct CPU {
|
||||
use_arena: bool
|
||||
}
|
||||
|
||||
super::impl_ep!(CPUExecutionProvider);
|
||||
super::impl_ep!(CPU);
|
||||
|
||||
impl CPUExecutionProvider {
|
||||
impl CPU {
|
||||
/// Enable/disable the usage of the arena allocator.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cpu::CPUExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CPUExecutionProvider::default().with_arena_allocator(true).build();
|
||||
/// let ep = ep::CPU::default().with_arena_allocator(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -26,7 +26,7 @@ impl CPUExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for CPUExecutionProvider {
|
||||
impl ExecutionProvider for CPU {
|
||||
fn name(&self) -> &'static str {
|
||||
"CPUExecutionProvider"
|
||||
}
|
||||
@@ -7,9 +7,9 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
// https://github.com/microsoft/onnxruntime/blob/ffceed9d44f2f3efb9dd69fa75fea51163c91d91/onnxruntime/contrib_ops/cpu/bert/attention_common.h#L160-L171
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[repr(transparent)]
|
||||
pub struct CUDAAttentionBackend(u32);
|
||||
pub struct AttentionBackend(u32);
|
||||
|
||||
impl CUDAAttentionBackend {
|
||||
impl AttentionBackend {
|
||||
pub const FLASH_ATTENTION: Self = Self(1 << 0);
|
||||
pub const EFFICIENT_ATTENTION: Self = Self(1 << 1);
|
||||
pub const TRT_FUSED_ATTENTION: Self = Self(1 << 2);
|
||||
@@ -23,7 +23,7 @@ impl CUDAAttentionBackend {
|
||||
pub const LEAN_ATTENTION: Self = Self(1 << 8);
|
||||
|
||||
pub fn none() -> Self {
|
||||
CUDAAttentionBackend(0)
|
||||
AttentionBackend(0)
|
||||
}
|
||||
|
||||
pub fn all() -> Self {
|
||||
@@ -38,7 +38,7 @@ impl CUDAAttentionBackend {
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOr for CUDAAttentionBackend {
|
||||
impl BitOr for AttentionBackend {
|
||||
type Output = Self;
|
||||
fn bitor(self, rhs: Self) -> Self::Output {
|
||||
Self(rhs.0 | self.0)
|
||||
@@ -47,7 +47,7 @@ impl BitOr for CUDAAttentionBackend {
|
||||
|
||||
/// The type of search done for cuDNN convolution algorithms.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub enum CuDNNConvAlgorithmSearch {
|
||||
pub enum ConvAlgorithmSearch {
|
||||
/// Expensive exhaustive benchmarking using [`cudnnFindConvolutionForwardAlgorithmEx`][exhaustive].
|
||||
/// This function will attempt all possible algorithms for `cudnnConvolutionForward` to find the fastest algorithm.
|
||||
/// Exhaustive search trades off between memory usage and speed. The first execution of a graph will be slow while
|
||||
@@ -71,27 +71,27 @@ pub enum CuDNNConvAlgorithmSearch {
|
||||
/// > search algorithm is actually [`Exhaustive`].
|
||||
///
|
||||
/// [fwdalgo]: https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionFwdAlgo_t
|
||||
/// [`Exhaustive`]: CuDNNConvAlgorithmSearch::Exhaustive
|
||||
/// [`Heuristic`]: CuDNNConvAlgorithmSearch::Heuristic
|
||||
/// [`Exhaustive`]: ConvAlgorithmSearch::Exhaustive
|
||||
/// [`Heuristic`]: ConvAlgorithmSearch::Heuristic
|
||||
Default
|
||||
}
|
||||
|
||||
/// [CUDA execution provider](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) for NVIDIA
|
||||
/// CUDA-enabled GPUs.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct CUDAExecutionProvider {
|
||||
pub struct CUDA {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; CUDAExecutionProvider);
|
||||
super::impl_ep!(arbitrary; CUDA);
|
||||
|
||||
impl CUDAExecutionProvider {
|
||||
impl CUDA {
|
||||
/// Configures which device the EP should use.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default().with_device_id(0).build();
|
||||
/// let ep = ep::CUDA::default().with_device_id(0).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -104,12 +104,12 @@ impl CUDAExecutionProvider {
|
||||
/// Configure the size limit of the device memory arena in bytes.
|
||||
///
|
||||
/// This only controls how much memory can be allocated to the *arena* - actual memory usage may be higher due to
|
||||
/// internal CUDA allocations, like those required for different [`CuDNNConvAlgorithmSearch`] options.
|
||||
/// internal CUDA allocations, like those required for different [`ConvAlgorithmSearch`] options.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default().with_memory_limit(2 * 1024 * 1024 * 1024).build();
|
||||
/// let ep = ep::CUDA::default().with_memory_limit(2 * 1024 * 1024 * 1024).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -122,9 +122,9 @@ impl CUDAExecutionProvider {
|
||||
/// Configure the strategy for extending the device's memory arena.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::{cuda::CUDAExecutionProvider, ArenaExtendStrategy}, session::Session};
|
||||
/// # use ort::{ep::{self, ArenaExtendStrategy}, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default()
|
||||
/// let ep = ep::CUDA::default()
|
||||
/// .with_arena_extend_strategy(ArenaExtendStrategy::SameAsRequested)
|
||||
/// .build();
|
||||
/// # Ok(())
|
||||
@@ -152,7 +152,7 @@ impl CUDAExecutionProvider {
|
||||
/// The default search algorithm, [`Exhaustive`][exh], will benchmark all available implementations and use the most
|
||||
/// performant one. This option is very resource intensive (both computationally on first run and peak-memory-wise),
|
||||
/// but ensures best performance. It is roughly equivalent to setting `torch.backends.cudnn.benchmark = True` with
|
||||
/// PyTorch. See also [`CUDAExecutionProvider::with_conv_max_workspace`] to configure how much memory the exhaustive
|
||||
/// PyTorch. See also [`CUDA::with_conv_max_workspace`] to configure how much memory the exhaustive
|
||||
/// search can use (the default is unlimited).
|
||||
///
|
||||
/// A less resource-intensive option is [`Heuristic`][heu]. Rather than benchmarking every implementation,
|
||||
@@ -164,41 +164,41 @@ impl CUDAExecutionProvider {
|
||||
/// is not the *default behavior* (that would be [`Exhaustive`][exh]).
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::{CUDAExecutionProvider, CuDNNConvAlgorithmSearch}, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default()
|
||||
/// .with_conv_algorithm_search(CuDNNConvAlgorithmSearch::Heuristic)
|
||||
/// let ep = ep::CUDA::default()
|
||||
/// .with_conv_algorithm_search(ep::cuda::ConvAlgorithmSearch::Heuristic)
|
||||
/// .build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// [exh]: CuDNNConvAlgorithmSearch::Exhaustive
|
||||
/// [heu]: CuDNNConvAlgorithmSearch::Heuristic
|
||||
/// [def]: CuDNNConvAlgorithmSearch::Default
|
||||
/// [exh]: ConvAlgorithmSearch::Exhaustive
|
||||
/// [heu]: ConvAlgorithmSearch::Heuristic
|
||||
/// [def]: ConvAlgorithmSearch::Default
|
||||
#[must_use]
|
||||
pub fn with_conv_algorithm_search(mut self, search: CuDNNConvAlgorithmSearch) -> Self {
|
||||
pub fn with_conv_algorithm_search(mut self, search: ConvAlgorithmSearch) -> Self {
|
||||
self.options.set(
|
||||
"cudnn_conv_algo_search",
|
||||
match search {
|
||||
CuDNNConvAlgorithmSearch::Exhaustive => "EXHAUSTIVE",
|
||||
CuDNNConvAlgorithmSearch::Heuristic => "HEURISTIC",
|
||||
CuDNNConvAlgorithmSearch::Default => "DEFAULT"
|
||||
ConvAlgorithmSearch::Exhaustive => "EXHAUSTIVE",
|
||||
ConvAlgorithmSearch::Heuristic => "HEURISTIC",
|
||||
ConvAlgorithmSearch::Default => "DEFAULT"
|
||||
}
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure whether the [`Exhaustive`][CuDNNConvAlgorithmSearch::Exhaustive] search can use as much memory as it
|
||||
/// Configure whether the [`Exhaustive`][ConvAlgorithmSearch::Exhaustive] search can use as much memory as it
|
||||
/// needs.
|
||||
///
|
||||
/// The default is `true`. When `false`, the memory used for the search is limited to 32 MB, which will impact its
|
||||
/// ability to find an optimal convolution algorithm.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default().with_conv_max_workspace(false).build();
|
||||
/// let ep = ep::CUDA::default().with_conv_max_workspace(false).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -218,9 +218,9 @@ impl CUDAExecutionProvider {
|
||||
/// convolution operations that do not use 3-dimensional input shapes, or the *result* of such operations.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default().with_conv1d_pad_to_nc1d(true).build();
|
||||
/// let ep = ep::CUDA::default().with_conv1d_pad_to_nc1d(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -245,9 +245,9 @@ impl CUDAExecutionProvider {
|
||||
/// Consult the [ONNX Runtime documentation on CUDA graphs](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#using-cuda-graphs-preview) for more information.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default().with_cuda_graph(true).build();
|
||||
/// let ep = ep::CUDA::default().with_cuda_graph(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -262,9 +262,9 @@ impl CUDAExecutionProvider {
|
||||
/// `SkipLayerNorm`'s strict mode trades performance for accuracy. The default is `false` (strict mode disabled).
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default().with_skip_layer_norm_strict_mode(true).build();
|
||||
/// let ep = ep::CUDA::default().with_skip_layer_norm_strict_mode(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -285,9 +285,9 @@ impl CUDAExecutionProvider {
|
||||
/// `torch.backends.cuda.matmul.allow_tf32 = True` or `torch.set_float32_matmul_precision("medium")` in PyTorch.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default().with_tf32(true).build();
|
||||
/// let ep = ep::CUDA::default().with_tf32(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -303,9 +303,9 @@ impl CUDAExecutionProvider {
|
||||
/// convolution-heavy models on Tensor core-enabled GPUs may provide a significant performance improvement.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default().with_prefer_nhwc(true).build();
|
||||
/// let ep = ep::CUDA::default().with_prefer_nhwc(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -328,16 +328,18 @@ impl CUDAExecutionProvider {
|
||||
/// Configures the available backends used for `Attention` nodes.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::cuda::{CUDAExecutionProvider, CUDAAttentionBackend}, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = CUDAExecutionProvider::default()
|
||||
/// .with_attention_backend(CUDAAttentionBackend::FLASH_ATTENTION | CUDAAttentionBackend::TRT_FUSED_ATTENTION)
|
||||
/// let ep = ep::CUDA::default()
|
||||
/// .with_attention_backend(
|
||||
/// ep::cuda::AttentionBackend::FLASH_ATTENTION | ep::cuda::AttentionBackend::TRT_FUSED_ATTENTION
|
||||
/// )
|
||||
/// .build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn with_attention_backend(mut self, flags: CUDAAttentionBackend) -> Self {
|
||||
pub fn with_attention_backend(mut self, flags: AttentionBackend) -> Self {
|
||||
self.options.set("sdpa_kernel", flags.0.to_string());
|
||||
self
|
||||
}
|
||||
@@ -352,7 +354,7 @@ impl CUDAExecutionProvider {
|
||||
// https://github.com/microsoft/onnxruntime/blob/fe8a10caa40f64a8fbd144e7049cf5b14c65542d/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc#L17
|
||||
}
|
||||
|
||||
impl ExecutionProvider for CUDAExecutionProvider {
|
||||
impl ExecutionProvider for CUDA {
|
||||
fn name(&self) -> &'static str {
|
||||
"CUDAExecutionProvider"
|
||||
}
|
||||
@@ -435,16 +437,16 @@ pub const CUDNN_DYLIBS: &[&str] = &[
|
||||
///
|
||||
/// ```
|
||||
/// # use std::path::Path;
|
||||
/// use ort::execution_providers::cuda;
|
||||
/// use ort::ep;
|
||||
///
|
||||
/// let cuda_root = Path::new(r#"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin"#);
|
||||
/// let cudnn_root = Path::new(r#"D:\cudnn_9.8.0"#);
|
||||
///
|
||||
/// // Load CUDA & cuDNN
|
||||
/// let _ = cuda::preload_dylibs(Some(cuda_root), Some(cudnn_root));
|
||||
/// let _ = ep::cuda::preload_dylibs(Some(cuda_root), Some(cudnn_root));
|
||||
///
|
||||
/// // Only preload cuDNN
|
||||
/// let _ = cuda::preload_dylibs(None, Some(cudnn_root));
|
||||
/// let _ = ep::cuda::preload_dylibs(None, Some(cudnn_root));
|
||||
/// ```
|
||||
#[cfg_attr(docsrs, doc(cfg(any(feature = "preload-dylibs", feature = "load-dynamic"))))]
|
||||
#[cfg(feature = "preload-dylibs")]
|
||||
@@ -9,10 +9,10 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// with dynamically sized inputs, you can override individual dimensions by constructing the session with
|
||||
/// [`SessionBuilder::with_dimension_override`]:
|
||||
/// ```no_run
|
||||
/// # use ort::{execution_providers::directml::DirectMLExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let session = Session::builder()?
|
||||
/// .with_execution_providers([DirectMLExecutionProvider::default().build()])?
|
||||
/// .with_execution_providers([ep::DirectML::default().build()])?
|
||||
/// .with_dimension_override("batch", 1)?
|
||||
/// .with_dimension_override("seq_len", 512)?
|
||||
/// .commit_from_file("gpt2.onnx")?;
|
||||
@@ -20,19 +20,19 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct DirectMLExecutionProvider {
|
||||
pub struct DirectML {
|
||||
device_id: i32
|
||||
}
|
||||
|
||||
super::impl_ep!(DirectMLExecutionProvider);
|
||||
super::impl_ep!(DirectML);
|
||||
|
||||
impl DirectMLExecutionProvider {
|
||||
impl DirectML {
|
||||
/// Configures which device the EP should use.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::directml::DirectMLExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = DirectMLExecutionProvider::default().with_device_id(1).build();
|
||||
/// let ep = ep::DirectML::default().with_device_id(1).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -43,7 +43,7 @@ impl DirectMLExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for DirectMLExecutionProvider {
|
||||
impl ExecutionProvider for DirectML {
|
||||
fn name(&self) -> &'static str {
|
||||
"DmlExecutionProvider"
|
||||
}
|
||||
@@ -6,7 +6,7 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// [MIGraphX execution provider](https://onnxruntime.ai/docs/execution-providers/MIGraphX-ExecutionProvider.html) for
|
||||
/// hardware acceleration with AMD GPUs.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct MIGraphXExecutionProvider {
|
||||
pub struct MIGraphX {
|
||||
device_id: i32,
|
||||
enable_fp16: bool,
|
||||
enable_int8: bool,
|
||||
@@ -17,15 +17,15 @@ pub struct MIGraphXExecutionProvider {
|
||||
exhaustive_tune: bool
|
||||
}
|
||||
|
||||
super::impl_ep!(MIGraphXExecutionProvider);
|
||||
super::impl_ep!(MIGraphX);
|
||||
|
||||
impl MIGraphXExecutionProvider {
|
||||
impl MIGraphX {
|
||||
/// Configures which device the EP should use.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = MIGraphXExecutionProvider::default().with_device_id(0).build();
|
||||
/// let ep = ep::MIGraphX::default().with_device_id(0).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -38,9 +38,9 @@ impl MIGraphXExecutionProvider {
|
||||
/// Enable FP16 quantization for the model.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = MIGraphXExecutionProvider::default().with_fp16(true).build();
|
||||
/// let ep = ep::MIGraphX::default().with_fp16(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -51,15 +51,12 @@ impl MIGraphXExecutionProvider {
|
||||
}
|
||||
|
||||
/// Enable 8-bit integer quantization for the model. Requires
|
||||
/// [`MIGraphXExecutionProvider::with_int8_calibration_table`] to be set.
|
||||
/// [`MIGraphX::with_int8_calibration_table`] to be set.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = MIGraphXExecutionProvider::default()
|
||||
/// .with_int8(true)
|
||||
/// .with_int8_calibration_table("...", false)
|
||||
/// .build();
|
||||
/// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -75,12 +72,9 @@ impl MIGraphXExecutionProvider {
|
||||
/// for the JSON dump format.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = MIGraphXExecutionProvider::default()
|
||||
/// .with_int8(true)
|
||||
/// .with_int8_calibration_table("...", false)
|
||||
/// .build();
|
||||
/// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -93,12 +87,12 @@ impl MIGraphXExecutionProvider {
|
||||
|
||||
/// Save the compiled MIGraphX model to the given path.
|
||||
///
|
||||
/// The compiled model can then be loaded in subsequent runs with [`MIGraphXExecutionProvider::with_load_model`].
|
||||
/// The compiled model can then be loaded in subsequent runs with [`MIGraphX::with_load_model`].
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = MIGraphXExecutionProvider::default().with_save_model("./compiled_model.mxr").build();
|
||||
/// let ep = ep::MIGraphX::default().with_save_model("./compiled_model.mxr").build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -108,13 +102,13 @@ impl MIGraphXExecutionProvider {
|
||||
self
|
||||
}
|
||||
|
||||
/// Load the compiled MIGraphX model (previously generated by [`MIGraphXExecutionProvider::with_save_model`]) from
|
||||
/// Load the compiled MIGraphX model (previously generated by [`MIGraphX::with_save_model`]) from
|
||||
/// the given path.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = MIGraphXExecutionProvider::default().with_load_model("./compiled_model.mxr").build();
|
||||
/// let ep = ep::MIGraphX::default().with_load_model("./compiled_model.mxr").build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -127,9 +121,9 @@ impl MIGraphXExecutionProvider {
|
||||
/// Enable exhaustive tuning; trades loading time for inference performance.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = MIGraphXExecutionProvider::default().with_exhaustive_tune(true).build();
|
||||
/// let ep = ep::MIGraphX::default().with_exhaustive_tune(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -140,7 +134,7 @@ impl MIGraphXExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for MIGraphXExecutionProvider {
|
||||
impl ExecutionProvider for MIGraphX {
|
||||
fn name(&self) -> &'static str {
|
||||
"MIGraphXExecutionProvider"
|
||||
}
|
||||
@@ -3,11 +3,11 @@
|
||||
//! Sessions can be configured with execution providers via [`SessionBuilder::with_execution_providers`]:
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use ort::{execution_providers::CUDAExecutionProvider, session::Session};
|
||||
//! use ort::{ep, session::Session};
|
||||
//!
|
||||
//! fn main() -> ort::Result<()> {
|
||||
//! let session = Session::builder()?
|
||||
//! .with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
//! .with_execution_providers([ep::CUDA::default().build()])?
|
||||
//! .commit_from_file("model.onnx")?;
|
||||
//!
|
||||
//! Ok(())
|
||||
@@ -24,47 +24,47 @@ use core::{
|
||||
use crate::{char_p_to_string, error::Result, ortsys, session::builder::SessionBuilder, util::MiniMap};
|
||||
|
||||
pub mod cpu;
|
||||
pub use self::cpu::CPUExecutionProvider;
|
||||
pub use self::cpu::CPU;
|
||||
pub mod cuda;
|
||||
pub use self::cuda::CUDAExecutionProvider;
|
||||
pub use self::cuda::CUDA;
|
||||
pub mod tensorrt;
|
||||
pub use self::tensorrt::TensorRTExecutionProvider;
|
||||
pub use self::tensorrt::TensorRT;
|
||||
pub mod onednn;
|
||||
pub use self::onednn::OneDNNExecutionProvider;
|
||||
pub use self::onednn::OneDNN;
|
||||
pub mod acl;
|
||||
pub use self::acl::ACLExecutionProvider;
|
||||
pub use self::acl::ACL;
|
||||
pub mod openvino;
|
||||
pub use self::openvino::OpenVINOExecutionProvider;
|
||||
pub use self::openvino::OpenVINO;
|
||||
pub mod coreml;
|
||||
pub use self::coreml::CoreMLExecutionProvider;
|
||||
pub use self::coreml::CoreML;
|
||||
pub mod rocm;
|
||||
pub use self::rocm::ROCmExecutionProvider;
|
||||
pub use self::rocm::ROCm;
|
||||
pub mod cann;
|
||||
pub use self::cann::CANNExecutionProvider;
|
||||
pub use self::cann::CANN;
|
||||
pub mod directml;
|
||||
pub use self::directml::DirectMLExecutionProvider;
|
||||
pub use self::directml::DirectML;
|
||||
pub mod tvm;
|
||||
pub use self::tvm::TVMExecutionProvider;
|
||||
pub use self::tvm::TVM;
|
||||
pub mod nnapi;
|
||||
pub use self::nnapi::NNAPIExecutionProvider;
|
||||
pub use self::nnapi::NNAPI;
|
||||
pub mod qnn;
|
||||
pub use self::qnn::QNNExecutionProvider;
|
||||
pub use self::qnn::QNN;
|
||||
pub mod xnnpack;
|
||||
pub use self::xnnpack::XNNPACKExecutionProvider;
|
||||
pub use self::xnnpack::XNNPACK;
|
||||
pub mod armnn;
|
||||
pub use self::armnn::ArmNNExecutionProvider;
|
||||
pub use self::armnn::ArmNN;
|
||||
pub mod migraphx;
|
||||
pub use self::migraphx::MIGraphXExecutionProvider;
|
||||
pub use self::migraphx::MIGraphX;
|
||||
pub mod vitis;
|
||||
pub use self::vitis::VitisAIExecutionProvider;
|
||||
pub use self::vitis::Vitis;
|
||||
pub mod rknpu;
|
||||
pub use self::rknpu::RKNPUExecutionProvider;
|
||||
pub use self::rknpu::RKNPU;
|
||||
pub mod webgpu;
|
||||
pub use self::webgpu::WebGPUExecutionProvider;
|
||||
pub use self::webgpu::WebGPU;
|
||||
pub mod azure;
|
||||
pub use self::azure::AzureExecutionProvider;
|
||||
pub use self::azure::Azure;
|
||||
pub mod nvrtx;
|
||||
pub use self::nvrtx::NVRTXExecutionProvider;
|
||||
pub use self::nvrtx::NVRTX;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub mod wasm;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
@@ -82,14 +82,14 @@ pub trait ExecutionProvider: Send + Sync {
|
||||
/// Returns the identifier of this execution provider used internally by ONNX Runtime.
|
||||
///
|
||||
/// This is the same as what's used in ONNX Runtime's Python API to register this execution provider, i.e.
|
||||
/// [`TVMExecutionProvider`]'s identifier is `TvmExecutionProvider`.
|
||||
/// [`TVM`]'s identifier is `TvmExecutionProvider`.
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Returns whether this execution provider is supported on this platform.
|
||||
///
|
||||
/// For example, the CoreML execution provider implements this as:
|
||||
/// ```ignore
|
||||
/// impl ExecutionProvider for CoreMLExecutionProvider {
|
||||
/// impl ExecutionProvider for CoreML {
|
||||
/// fn supported_by_platform() -> bool {
|
||||
/// cfg!(target_vendor = "apple")
|
||||
/// }
|
||||
@@ -321,9 +321,9 @@ pub(crate) use define_ep_register;
|
||||
|
||||
macro_rules! impl_ep {
|
||||
(arbitrary; $symbol:ident) => {
|
||||
$crate::execution_providers::impl_ep!($symbol);
|
||||
$crate::ep::impl_ep!($symbol);
|
||||
|
||||
impl $crate::execution_providers::ArbitrarilyConfigurableExecutionProvider for $symbol {
|
||||
impl $crate::ep::ArbitrarilyConfigurableExecutionProvider for $symbol {
|
||||
fn with_arbitrary_config(mut self, key: impl ::alloc::string::ToString, value: impl ::alloc::string::ToString) -> Self {
|
||||
self.options.set(key.to_string(), value.to_string());
|
||||
self
|
||||
@@ -333,14 +333,14 @@ macro_rules! impl_ep {
|
||||
($symbol:ident) => {
|
||||
impl $symbol {
|
||||
#[must_use]
|
||||
pub fn build(self) -> $crate::execution_providers::ExecutionProviderDispatch {
|
||||
pub fn build(self) -> $crate::ep::ExecutionProviderDispatch {
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<$symbol> for $crate::execution_providers::ExecutionProviderDispatch {
|
||||
impl From<$symbol> for $crate::ep::ExecutionProviderDispatch {
|
||||
fn from(value: $symbol) -> Self {
|
||||
$crate::execution_providers::ExecutionProviderDispatch::new(value)
|
||||
$crate::ep::ExecutionProviderDispatch::new(value)
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -381,3 +381,75 @@ pub(crate) fn apply_execution_providers(session_builder: &mut SessionBuilder, ep
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[deprecated = "import `ort::ep::ACL` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::acl::ACL as ACLExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::ArmNN` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::armnn::ArmNN as ArmNNExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::Azure` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::azure::Azure as AzureExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::CANN` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::cann::CANN as CANNExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::CoreML` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::coreml::CoreML as CoreMLExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::CPU` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::cpu::CPU as CPUExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::CUDA` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::cuda::CUDA as CUDAExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::DirectML` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::directml::DirectML as DirectMLExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::MIGraphX` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::migraphx::MIGraphX as MIGraphXExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::NNAPI` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::nnapi::NNAPI as NNAPIExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::NVRTX` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::nvrtx::NVRTX as NVRTXExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::OneDNN` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::onednn::OneDNN as OneDNNExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::OpenVINO` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::openvino::OpenVINO as OpenVINOExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::QNN` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::qnn::QNN as QNNExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::RKNPU` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::rknpu::RKNPU as RKNPUExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::ROCm` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::rocm::ROCm as ROCmExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::TensorRT` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::tensorrt::TensorRT as TensorRTExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::TVM` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::tvm::TVM as TVMExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::Vitis` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::vitis::Vitis as VitisAIExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::WASM` instead"]
|
||||
#[doc(hidden)]
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub use self::wasm::WASM as WASMExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::WebGPU` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::webgpu::WebGPU as WebGPUExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::WebNN` instead"]
|
||||
#[doc(hidden)]
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub use self::webnn::WebNN as WebNNExecutionProvider;
|
||||
#[deprecated = "import `ort::ep::XNNPACK` instead"]
|
||||
#[doc(hidden)]
|
||||
pub use self::xnnpack::XNNPACK as XNNPACKExecutionProvider;
|
||||
@@ -2,16 +2,16 @@ use super::{ExecutionProvider, RegisterError};
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct NNAPIExecutionProvider {
|
||||
pub struct NNAPI {
|
||||
use_fp16: bool,
|
||||
use_nchw: bool,
|
||||
disable_cpu: bool,
|
||||
cpu_only: bool
|
||||
}
|
||||
|
||||
super::impl_ep!(NNAPIExecutionProvider);
|
||||
super::impl_ep!(NNAPI);
|
||||
|
||||
impl NNAPIExecutionProvider {
|
||||
impl NNAPI {
|
||||
/// Use fp16 relaxation in NNAPI EP. This may improve performance but can also reduce accuracy due to the lower
|
||||
/// precision.
|
||||
#[must_use]
|
||||
@@ -49,7 +49,7 @@ impl NNAPIExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for NNAPIExecutionProvider {
|
||||
impl ExecutionProvider for NNAPI {
|
||||
fn name(&self) -> &'static str {
|
||||
"NnapiExecutionProvider"
|
||||
}
|
||||
@@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct NVRTXExecutionProvider {
|
||||
pub struct NVRTX {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; NVRTXExecutionProvider);
|
||||
super::impl_ep!(arbitrary; NVRTX);
|
||||
|
||||
impl NVRTXExecutionProvider {
|
||||
impl NVRTX {
|
||||
pub fn with_device_id(mut self, device_id: u32) -> Self {
|
||||
self.options.set("ep.nvtensorrtrtxexecutionprovider.device_id", device_id.to_string());
|
||||
self
|
||||
@@ -23,7 +23,7 @@ impl NVRTXExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for NVRTXExecutionProvider {
|
||||
impl ExecutionProvider for NVRTX {
|
||||
fn name(&self) -> &'static str {
|
||||
"NvTensorRTRTXExecutionProvider"
|
||||
}
|
||||
@@ -4,20 +4,20 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// [oneDNN/DNNL execution provider](https://onnxruntime.ai/docs/execution-providers/oneDNN-ExecutionProvider.html) for
|
||||
/// Intel CPUs & iGPUs.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
#[doc(alias = "DNNLExecutionProvider")]
|
||||
pub struct OneDNNExecutionProvider {
|
||||
#[doc(alias = "DNNL")]
|
||||
pub struct OneDNN {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; OneDNNExecutionProvider);
|
||||
super::impl_ep!(arbitrary; OneDNN);
|
||||
|
||||
impl OneDNNExecutionProvider {
|
||||
impl OneDNN {
|
||||
/// Enable/disable the usage of the arena allocator.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::onednn::OneDNNExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = OneDNNExecutionProvider::default().with_arena_allocator(true).build();
|
||||
/// let ep = ep::OneDNN::default().with_arena_allocator(true).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -28,7 +28,7 @@ impl OneDNNExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for OneDNNExecutionProvider {
|
||||
impl ExecutionProvider for OneDNN {
|
||||
fn name(&self) -> &'static str {
|
||||
"DnnlExecutionProvider"
|
||||
}
|
||||
@@ -6,22 +6,22 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// [OpenVINO execution provider](https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html) for
|
||||
/// Intel CPUs/GPUs/NPUs.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct OpenVINOExecutionProvider {
|
||||
pub struct OpenVINO {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; OpenVINOExecutionProvider);
|
||||
super::impl_ep!(arbitrary; OpenVINO);
|
||||
|
||||
impl OpenVINOExecutionProvider {
|
||||
impl OpenVINO {
|
||||
/// Overrides the accelerator hardware type and precision.
|
||||
///
|
||||
/// `device_type` should be in the format `CPU`, `NPU`, `GPU`, `GPU.0`, `GPU.1`, etc. Heterogenous combinations are
|
||||
/// also supported in the format `HETERO:NPU,GPU`.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{execution_providers::openvino::OpenVINOExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = OpenVINOExecutionProvider::default().with_device_type("GPU.0").build();
|
||||
/// let ep = ep::OpenVINO::default().with_device_type("GPU.0").build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -81,7 +81,7 @@ impl OpenVINOExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for OpenVINOExecutionProvider {
|
||||
impl ExecutionProvider for OpenVINO {
|
||||
fn name(&self) -> &'static str {
|
||||
"OpenVINOExecutionProvider"
|
||||
}
|
||||
@@ -4,7 +4,7 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum QNNPerformanceMode {
|
||||
pub enum PerformanceMode {
|
||||
Default,
|
||||
Burst,
|
||||
Balanced,
|
||||
@@ -17,43 +17,43 @@ pub enum QNNPerformanceMode {
|
||||
SustainedHighPerformance
|
||||
}
|
||||
|
||||
impl QNNPerformanceMode {
|
||||
impl PerformanceMode {
|
||||
#[must_use]
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
QNNPerformanceMode::Default => "default",
|
||||
QNNPerformanceMode::Burst => "burst",
|
||||
QNNPerformanceMode::Balanced => "balanced",
|
||||
QNNPerformanceMode::HighPerformance => "high_performance",
|
||||
QNNPerformanceMode::HighPowerSaver => "high_power_saver",
|
||||
QNNPerformanceMode::LowPowerSaver => "low_power_saver",
|
||||
QNNPerformanceMode::LowBalanced => "low_balanced",
|
||||
QNNPerformanceMode::PowerSaver => "power_saver",
|
||||
QNNPerformanceMode::ExtremePowerSaver => "extreme_power_saver",
|
||||
QNNPerformanceMode::SustainedHighPerformance => "sustained_high_performance"
|
||||
PerformanceMode::Default => "default",
|
||||
PerformanceMode::Burst => "burst",
|
||||
PerformanceMode::Balanced => "balanced",
|
||||
PerformanceMode::HighPerformance => "high_performance",
|
||||
PerformanceMode::HighPowerSaver => "high_power_saver",
|
||||
PerformanceMode::LowPowerSaver => "low_power_saver",
|
||||
PerformanceMode::LowBalanced => "low_balanced",
|
||||
PerformanceMode::PowerSaver => "power_saver",
|
||||
PerformanceMode::ExtremePowerSaver => "extreme_power_saver",
|
||||
PerformanceMode::SustainedHighPerformance => "sustained_high_performance"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum QNNProfilingLevel {
|
||||
pub enum ProfilingLevel {
|
||||
Off,
|
||||
Basic,
|
||||
Detailed
|
||||
}
|
||||
|
||||
impl QNNProfilingLevel {
|
||||
impl ProfilingLevel {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
QNNProfilingLevel::Off => "off",
|
||||
QNNProfilingLevel::Basic => "basic",
|
||||
QNNProfilingLevel::Detailed => "detailed"
|
||||
ProfilingLevel::Off => "off",
|
||||
ProfilingLevel::Basic => "basic",
|
||||
ProfilingLevel::Detailed => "detailed"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum QNNContextPriority {
|
||||
pub enum ContextPriority {
|
||||
Low,
|
||||
#[default]
|
||||
Normal,
|
||||
@@ -61,25 +61,25 @@ pub enum QNNContextPriority {
|
||||
High
|
||||
}
|
||||
|
||||
impl QNNContextPriority {
|
||||
impl ContextPriority {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
QNNContextPriority::Low => "low",
|
||||
QNNContextPriority::Normal => "normal",
|
||||
QNNContextPriority::NormalHigh => "normal_high",
|
||||
QNNContextPriority::High => "high"
|
||||
ContextPriority::Low => "low",
|
||||
ContextPriority::Normal => "normal",
|
||||
ContextPriority::NormalHigh => "normal_high",
|
||||
ContextPriority::High => "high"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct QNNExecutionProvider {
|
||||
pub struct QNN {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; QNNExecutionProvider);
|
||||
super::impl_ep!(arbitrary; QNN);
|
||||
|
||||
impl QNNExecutionProvider {
|
||||
impl QNN {
|
||||
/// The file path to QNN backend library. On Linux/Android, this is `libQnnCpu.so` to use the CPU backend,
|
||||
/// or `libQnnHtp.so` to use the accelerated backend.
|
||||
#[must_use]
|
||||
@@ -89,7 +89,7 @@ impl QNNExecutionProvider {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_profiling(mut self, level: QNNProfilingLevel) -> Self {
|
||||
pub fn with_profiling(mut self, level: ProfilingLevel) -> Self {
|
||||
self.options.set("profiling_level", level.as_str());
|
||||
self
|
||||
}
|
||||
@@ -114,7 +114,7 @@ impl QNNExecutionProvider {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_performance_mode(mut self, mode: QNNPerformanceMode) -> Self {
|
||||
pub fn with_performance_mode(mut self, mode: PerformanceMode) -> Self {
|
||||
self.options.set("htp_performance_mode", mode.as_str());
|
||||
self
|
||||
}
|
||||
@@ -126,7 +126,7 @@ impl QNNExecutionProvider {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_context_priority(mut self, priority: QNNContextPriority) -> Self {
|
||||
pub fn with_context_priority(mut self, priority: ContextPriority) -> Self {
|
||||
self.options.set("qnn_context_priority", priority.as_str());
|
||||
self
|
||||
}
|
||||
@@ -174,7 +174,7 @@ impl QNNExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for QNNExecutionProvider {
|
||||
impl ExecutionProvider for QNN {
|
||||
fn name(&self) -> &'static str {
|
||||
"QNNExecutionProvider"
|
||||
}
|
||||
@@ -2,11 +2,11 @@ use super::{ExecutionProvider, RegisterError};
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct RKNPUExecutionProvider {}
|
||||
pub struct RKNPU {}
|
||||
|
||||
super::impl_ep!(RKNPUExecutionProvider);
|
||||
super::impl_ep!(RKNPU);
|
||||
|
||||
impl ExecutionProvider for RKNPUExecutionProvider {
|
||||
impl ExecutionProvider for RKNPU {
|
||||
fn name(&self) -> &'static str {
|
||||
"RknpuExecutionProvider"
|
||||
}
|
||||
@@ -5,13 +5,13 @@ use super::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderOptions, Re
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct ROCmExecutionProvider {
|
||||
pub struct ROCm {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; ROCmExecutionProvider);
|
||||
super::impl_ep!(arbitrary; ROCm);
|
||||
|
||||
impl ROCmExecutionProvider {
|
||||
impl ROCm {
|
||||
#[must_use]
|
||||
pub fn with_device_id(mut self, device_id: i32) -> Self {
|
||||
self.options.set("device_id", device_id.to_string());
|
||||
@@ -86,7 +86,7 @@ impl ROCmExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for ROCmExecutionProvider {
|
||||
impl ExecutionProvider for ROCm {
|
||||
fn name(&self) -> &'static str {
|
||||
"ROCMExecutionProvider"
|
||||
}
|
||||
@@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct TensorRTExecutionProvider {
|
||||
pub struct TensorRT {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; TensorRTExecutionProvider);
|
||||
super::impl_ep!(arbitrary; TensorRT);
|
||||
|
||||
impl TensorRTExecutionProvider {
|
||||
impl TensorRT {
|
||||
#[must_use]
|
||||
pub fn with_device_id(mut self, device_id: i32) -> Self {
|
||||
self.options.set("device_id", device_id.to_string());
|
||||
@@ -254,7 +254,7 @@ impl TensorRTExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for TensorRTExecutionProvider {
|
||||
impl ExecutionProvider for TensorRT {
|
||||
fn name(&self) -> &'static str {
|
||||
"TensorrtExecutionProvider"
|
||||
}
|
||||
@@ -4,22 +4,22 @@ use super::{ExecutionProvider, RegisterError};
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TVMExecutorType {
|
||||
pub enum ExecutorType {
|
||||
GraphExecutor,
|
||||
VirtualMachine
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TVMTuningType {
|
||||
pub enum TuningType {
|
||||
AutoTVM,
|
||||
Ansor
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct TVMExecutionProvider {
|
||||
pub struct TVM {
|
||||
/// Executor type used by TVM. There is a choice between two types, `GraphExecutor` and `VirtualMachine`. Default is
|
||||
/// [`TVMExecutorType::VirtualMachine`].
|
||||
pub executor: Option<TVMExecutorType>,
|
||||
/// [`ExecutorType::VirtualMachine`].
|
||||
pub executor: Option<ExecutorType>,
|
||||
/// Path to folder with set of files (`.ro-`, `.so`/`.dll`-files and weights) obtained after model tuning.
|
||||
pub so_folder: Option<String>,
|
||||
/// Whether or not to perform a hash check on the model obtained in the `so_folder`.
|
||||
@@ -34,7 +34,7 @@ pub struct TVMExecutionProvider {
|
||||
/// `true` is recommended for best performance and is the default.
|
||||
pub freeze_weights: Option<bool>,
|
||||
pub to_nhwc: Option<bool>,
|
||||
pub tuning_type: Option<TVMTuningType>,
|
||||
pub tuning_type: Option<TuningType>,
|
||||
/// Path to AutoTVM or Ansor tuning file which gives specifications for given model and target for the best
|
||||
/// performance.
|
||||
pub tuning_file_path: Option<String>,
|
||||
@@ -42,9 +42,9 @@ pub struct TVMExecutionProvider {
|
||||
pub input_shapes: Option<String>
|
||||
}
|
||||
|
||||
super::impl_ep!(TVMExecutionProvider);
|
||||
super::impl_ep!(TVM);
|
||||
|
||||
impl ExecutionProvider for TVMExecutionProvider {
|
||||
impl ExecutionProvider for TVM {
|
||||
fn name(&self) -> &'static str {
|
||||
"TvmExecutionProvider"
|
||||
}
|
||||
@@ -66,8 +66,8 @@ impl ExecutionProvider for TVMExecutionProvider {
|
||||
option_string.push(format!(
|
||||
"executor:{}",
|
||||
match executor {
|
||||
TVMExecutorType::GraphExecutor => "graph",
|
||||
TVMExecutorType::VirtualMachine => "vm"
|
||||
ExecutorType::GraphExecutor => "graph",
|
||||
ExecutorType::VirtualMachine => "vm"
|
||||
}
|
||||
));
|
||||
}
|
||||
@@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct VitisAIExecutionProvider {
|
||||
pub struct Vitis {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; VitisAIExecutionProvider);
|
||||
super::impl_ep!(arbitrary; Vitis);
|
||||
|
||||
impl VitisAIExecutionProvider {
|
||||
impl Vitis {
|
||||
pub fn with_config_file(mut self, config_file: impl ToString) -> Self {
|
||||
self.options.set("config_file", config_file.to_string());
|
||||
self
|
||||
@@ -27,7 +27,7 @@ impl VitisAIExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for VitisAIExecutionProvider {
|
||||
impl ExecutionProvider for Vitis {
|
||||
fn name(&self) -> &'static str {
|
||||
"VitisAIExecutionProvider"
|
||||
}
|
||||
@@ -2,13 +2,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
|
||||
use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct WASMExecutionProvider {
|
||||
pub struct WASM {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; WASMExecutionProvider);
|
||||
super::impl_ep!(arbitrary; WASM);
|
||||
|
||||
impl ExecutionProvider for WASMExecutionProvider {
|
||||
impl ExecutionProvider for WASM {
|
||||
fn name(&self) -> &'static str {
|
||||
"WASMExecutionProvider"
|
||||
}
|
||||
@@ -4,84 +4,84 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
|
||||
use crate::{error::Result, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WebGPUPreferredLayout {
|
||||
pub enum PreferredLayout {
|
||||
NCHW,
|
||||
NHWC
|
||||
}
|
||||
|
||||
impl WebGPUPreferredLayout {
|
||||
impl PreferredLayout {
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
WebGPUPreferredLayout::NCHW => "NCHW",
|
||||
WebGPUPreferredLayout::NHWC => "NHWC"
|
||||
PreferredLayout::NCHW => "NCHW",
|
||||
PreferredLayout::NHWC => "NHWC"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WebGPUDawnBackendType {
|
||||
pub enum DawnBackendType {
|
||||
Vulkan,
|
||||
D3D12
|
||||
}
|
||||
|
||||
impl WebGPUDawnBackendType {
|
||||
impl DawnBackendType {
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
WebGPUDawnBackendType::Vulkan => "Vulkan",
|
||||
WebGPUDawnBackendType::D3D12 => "D3D12"
|
||||
DawnBackendType::Vulkan => "Vulkan",
|
||||
DawnBackendType::D3D12 => "D3D12"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WebGPUBufferCacheMode {
|
||||
pub enum BufferCacheMode {
|
||||
Disabled,
|
||||
LazyRelease,
|
||||
Simple,
|
||||
Bucket
|
||||
}
|
||||
|
||||
impl WebGPUBufferCacheMode {
|
||||
impl BufferCacheMode {
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
WebGPUBufferCacheMode::Disabled => "disabled",
|
||||
WebGPUBufferCacheMode::LazyRelease => "lazyRelease",
|
||||
WebGPUBufferCacheMode::Simple => "simple",
|
||||
WebGPUBufferCacheMode::Bucket => "bucket"
|
||||
BufferCacheMode::Disabled => "disabled",
|
||||
BufferCacheMode::LazyRelease => "lazyRelease",
|
||||
BufferCacheMode::Simple => "simple",
|
||||
BufferCacheMode::Bucket => "bucket"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WebGPUValidationMode {
|
||||
pub enum ValidationMode {
|
||||
Disabled,
|
||||
WgpuOnly,
|
||||
Basic,
|
||||
Full
|
||||
}
|
||||
|
||||
impl WebGPUValidationMode {
|
||||
impl ValidationMode {
|
||||
#[must_use]
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
WebGPUValidationMode::Disabled => "disabled",
|
||||
WebGPUValidationMode::WgpuOnly => "wgpuOnly",
|
||||
WebGPUValidationMode::Basic => "basic",
|
||||
WebGPUValidationMode::Full => "full"
|
||||
ValidationMode::Disabled => "disabled",
|
||||
ValidationMode::WgpuOnly => "wgpuOnly",
|
||||
ValidationMode::Basic => "basic",
|
||||
ValidationMode::Full => "full"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct WebGPUExecutionProvider {
|
||||
pub struct WebGPU {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; WebGPUExecutionProvider);
|
||||
super::impl_ep!(arbitrary; WebGPU);
|
||||
|
||||
impl WebGPUExecutionProvider {
|
||||
impl WebGPU {
|
||||
#[must_use]
|
||||
pub fn with_preferred_layout(mut self, layout: WebGPUPreferredLayout) -> Self {
|
||||
pub fn with_preferred_layout(mut self, layout: PreferredLayout) -> Self {
|
||||
self.options.set("ep.webgpuexecutionprovider.preferredLayout", layout.as_str());
|
||||
self
|
||||
}
|
||||
@@ -100,7 +100,7 @@ impl WebGPUExecutionProvider {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_dawn_backend_type(mut self, backend_type: WebGPUDawnBackendType) -> Self {
|
||||
pub fn with_dawn_backend_type(mut self, backend_type: DawnBackendType) -> Self {
|
||||
self.options.set("ep.webgpuexecutionprovider.dawnBackendType", backend_type.as_str());
|
||||
self
|
||||
}
|
||||
@@ -112,31 +112,31 @@ impl WebGPUExecutionProvider {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_storage_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self {
|
||||
pub fn with_storage_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
|
||||
self.options.set("ep.webgpuexecutionprovider.storageBufferCacheMode", mode.as_str());
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_uniform_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self {
|
||||
pub fn with_uniform_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
|
||||
self.options.set("ep.webgpuexecutionprovider.uniformBufferCacheMode", mode.as_str());
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_query_resolve_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self {
|
||||
pub fn with_query_resolve_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
|
||||
self.options.set("ep.webgpuexecutionprovider.queryResolveBufferCacheMode", mode.as_str());
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_default_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self {
|
||||
pub fn with_default_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
|
||||
self.options.set("ep.webgpuexecutionprovider.defaultBufferCacheMode", mode.as_str());
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_validation_mode(mut self, mode: WebGPUValidationMode) -> Self {
|
||||
pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self {
|
||||
self.options.set("ep.webgpuexecutionprovider.validationMode", mode.as_str());
|
||||
self
|
||||
}
|
||||
@@ -155,7 +155,7 @@ impl WebGPUExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for WebGPUExecutionProvider {
|
||||
impl ExecutionProvider for WebGPU {
|
||||
fn name(&self) -> &'static str {
|
||||
"WebGpuExecutionProvider"
|
||||
}
|
||||
@@ -4,56 +4,56 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
|
||||
use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder};
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WebNNPowerPreference {
|
||||
pub enum PowerPreference {
|
||||
#[default]
|
||||
Default,
|
||||
HighPerformance,
|
||||
LowPower
|
||||
}
|
||||
|
||||
impl WebNNPowerPreference {
|
||||
impl PowerPreference {
|
||||
#[must_use]
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
WebNNPowerPreference::Default => "default",
|
||||
WebNNPowerPreference::HighPerformance => "high-performance",
|
||||
WebNNPowerPreference::LowPower => "low-power"
|
||||
PowerPreference::Default => "default",
|
||||
PowerPreference::HighPerformance => "high-performance",
|
||||
PowerPreference::LowPower => "low-power"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WebNNDeviceType {
|
||||
pub enum DeviceType {
|
||||
CPU,
|
||||
GPU,
|
||||
NPU
|
||||
}
|
||||
|
||||
impl WebNNDeviceType {
|
||||
impl DeviceType {
|
||||
#[must_use]
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
WebNNDeviceType::CPU => "cpu",
|
||||
WebNNDeviceType::GPU => "gpu",
|
||||
WebNNDeviceType::NPU => "npu"
|
||||
DeviceType::CPU => "cpu",
|
||||
DeviceType::GPU => "gpu",
|
||||
DeviceType::NPU => "npu"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct WebNNExecutionProvider {
|
||||
pub struct WebNN {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
impl WebNNExecutionProvider {
|
||||
impl WebNN {
|
||||
#[must_use]
|
||||
pub fn with_device_type(mut self, device_type: WebNNDeviceType) -> Self {
|
||||
pub fn with_device_type(mut self, device_type: DeviceType) -> Self {
|
||||
self.options.set("deviceType", device_type.as_str());
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_power_preference(mut self, pref: WebNNPowerPreference) -> Self {
|
||||
pub fn with_power_preference(mut self, pref: PowerPreference) -> Self {
|
||||
self.options.set("powerPreference", pref.as_str());
|
||||
self
|
||||
}
|
||||
@@ -65,9 +65,9 @@ impl WebNNExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; WebNNExecutionProvider);
|
||||
super::impl_ep!(arbitrary; WebNN);
|
||||
|
||||
impl ExecutionProvider for WebNNExecutionProvider {
|
||||
impl ExecutionProvider for WebNN {
|
||||
fn name(&self) -> &'static str {
|
||||
"WebNNExecutionProvider"
|
||||
}
|
||||
@@ -13,12 +13,12 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// disable the session intra-op threadpool to reduce contention:
|
||||
/// ```no_run
|
||||
/// # use core::num::NonZeroUsize;
|
||||
/// # use ort::{execution_providers::xnnpack::XNNPACKExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let session = Session::builder()?
|
||||
/// .with_intra_op_spinning(false)?
|
||||
/// .with_intra_threads(1)?
|
||||
/// .with_execution_providers([XNNPACKExecutionProvider::default()
|
||||
/// .with_execution_providers([ep::XNNPACK::default()
|
||||
/// .with_intra_op_num_threads(NonZeroUsize::new(4).unwrap())
|
||||
/// .build()])?
|
||||
/// .commit_from_file("model.onnx")?;
|
||||
@@ -26,22 +26,20 @@ use crate::{error::Result, session::builder::SessionBuilder};
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct XNNPACKExecutionProvider {
|
||||
pub struct XNNPACK {
|
||||
options: ExecutionProviderOptions
|
||||
}
|
||||
|
||||
super::impl_ep!(arbitrary; XNNPACKExecutionProvider);
|
||||
super::impl_ep!(arbitrary; XNNPACK);
|
||||
|
||||
impl XNNPACKExecutionProvider {
|
||||
impl XNNPACK {
|
||||
/// Configures the number of threads to use for XNNPACK's internal intra-op threadpool.
|
||||
///
|
||||
/// ```
|
||||
/// # use core::num::NonZeroUsize;
|
||||
/// # use ort::{execution_providers::xnnpack::XNNPACKExecutionProvider, session::Session};
|
||||
/// # use ort::{ep, session::Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let ep = XNNPACKExecutionProvider::default()
|
||||
/// .with_intra_op_num_threads(NonZeroUsize::new(4).unwrap())
|
||||
/// .build();
|
||||
/// let ep = ep::XNNPACK::default().with_intra_op_num_threads(NonZeroUsize::new(4).unwrap()).build();
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -52,7 +50,7 @@ impl XNNPACKExecutionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider for XNNPACKExecutionProvider {
|
||||
impl ExecutionProvider for XNNPACK {
|
||||
fn name(&self) -> &'static str {
|
||||
"XnnpackExecutionProvider"
|
||||
}
|
||||
@@ -37,7 +37,7 @@ use crate::{
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use ort::{
|
||||
/// # execution_providers::CUDAExecutionProvider,
|
||||
/// # ep,
|
||||
/// # io_binding::IoBinding,
|
||||
/// # memory::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType},
|
||||
/// # session::Session,
|
||||
@@ -45,10 +45,10 @@ use crate::{
|
||||
/// # };
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mut text_encoder = Session::builder()?
|
||||
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
/// .with_execution_providers([ep::CUDA::default().build()])?
|
||||
/// .commit_from_file("text_encoder.onnx")?;
|
||||
/// let mut unet = Session::builder()?
|
||||
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
/// .with_execution_providers([ep::CUDA::default().build()])?
|
||||
/// .commit_from_file("unet.onnx")?;
|
||||
///
|
||||
/// let text_condition = text_encoder
|
||||
|
||||
10
src/lib.rs
10
src/lib.rs
@@ -3,6 +3,7 @@
|
||||
#![allow(clippy::tabs_in_doc_comments, clippy::arc_with_non_send_sync)]
|
||||
#![allow(clippy::macro_metavars_in_unsafe)]
|
||||
#![warn(clippy::unwrap_used)]
|
||||
#![deny(clippy::std_instead_of_alloc, clippy::std_instead_of_core)]
|
||||
#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
|
||||
|
||||
//! <div align=center>
|
||||
@@ -28,8 +29,8 @@ pub mod adapter;
|
||||
pub mod compiler;
|
||||
pub mod editor;
|
||||
pub mod environment;
|
||||
pub mod ep;
|
||||
pub mod error;
|
||||
pub mod execution_providers;
|
||||
pub mod io_binding;
|
||||
pub mod logging;
|
||||
pub mod memory;
|
||||
@@ -48,6 +49,13 @@ pub mod api {
|
||||
pub use super::{api as ort, compiler::compile_api as compile, editor::editor_api as editor};
|
||||
}
|
||||
|
||||
#[deprecated = "import execution providers from `ort::ep` instead"]
|
||||
#[doc(hidden)]
|
||||
pub mod execution_providers {
|
||||
#[deprecated = "import execution providers from `ort::ep` instead"]
|
||||
pub use super::ep::*;
|
||||
}
|
||||
|
||||
use alloc::{borrow::ToOwned, boxed::Box, string::String};
|
||||
use core::{
|
||||
ffi::{CStr, c_char},
|
||||
|
||||
@@ -36,7 +36,8 @@ pub trait Operator: Send {
|
||||
/// Returns the name of the operator.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Returns the execution provider this operator runs on, e.g. `CUDAExecutionProvider`.
|
||||
/// Returns the internal name of the execution provider this operator runs on, e.g. `"CUDAExecutionProvider"` (see
|
||||
/// [`ExecutionProvider::name`](crate::ep::ExecutionProvider::name)).
|
||||
///
|
||||
/// If the returned type is not `None`, and the execution provider used by the session matches this operator's
|
||||
/// EP type, the value will not be copied to the CPU and you may use functions like [`Tensor::data_ptr`] to
|
||||
|
||||
@@ -6,7 +6,6 @@ use core::{
|
||||
ffi::c_void,
|
||||
marker::PhantomData,
|
||||
mem::replace,
|
||||
ops::Deref,
|
||||
ptr::{self, NonNull}
|
||||
};
|
||||
#[cfg(feature = "std")]
|
||||
@@ -23,8 +22,8 @@ use crate::error::{Error, ErrorCode};
|
||||
use crate::util::OsCharArray;
|
||||
use crate::{
|
||||
AsPointer,
|
||||
ep::apply_execution_providers,
|
||||
error::Result,
|
||||
execution_providers::apply_execution_providers,
|
||||
memory::Allocator,
|
||||
ortsys,
|
||||
session::{InMemorySession, Input, Output, Session, SharedSessionInner, dangerous}
|
||||
@@ -136,7 +135,7 @@ impl SessionBuilder {
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
|
||||
fn commit_from_file_inner(mut self, model_path: &<OsCharArray as Deref>::Target) -> Result<Session> {
|
||||
fn commit_from_file_inner(mut self, model_path: &<OsCharArray as core::ops::Deref>::Target) -> Result<Session> {
|
||||
self.pre_commit()?;
|
||||
|
||||
let session_ptr = if let Some(prepacked_weights) = self.prepacked_weights.as_ref() {
|
||||
|
||||
@@ -13,8 +13,8 @@ use crate::util::path_to_os_char;
|
||||
use crate::{
|
||||
AsPointer, Error, ErrorCode,
|
||||
environment::{self, ThreadManager},
|
||||
ep::{ExecutionProviderDispatch, apply_execution_providers},
|
||||
error::Result,
|
||||
execution_providers::{ExecutionProviderDispatch, apply_execution_providers},
|
||||
logging::{LogLevel, LoggerFunction},
|
||||
memory::MemoryInfo,
|
||||
operator::OperatorDomain,
|
||||
@@ -34,8 +34,8 @@ impl SessionBuilder {
|
||||
/// ## Notes
|
||||
///
|
||||
/// - **Indiscriminate use of [`SessionBuilder::with_execution_providers`] in a library** (e.g. always enabling
|
||||
/// `CUDAExecutionProvider`) **is discouraged** unless you allow the user to configure the execution providers by
|
||||
/// providing a `Vec` of [`ExecutionProviderDispatch`]es.
|
||||
/// CUDA) **is discouraged** unless you allow the user to configure the execution providers by providing a `Vec`
|
||||
/// of [`ExecutionProviderDispatch`]es.
|
||||
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Result<Self> {
|
||||
apply_execution_providers(&mut self, execution_providers.as_ref(), "session options")?;
|
||||
Ok(self)
|
||||
|
||||
@@ -36,8 +36,8 @@ pub(crate) use self::{
|
||||
/// absolute path, then any subsequent requests to load a library called `foo` will use the `libfoo.so` we already
|
||||
/// loaded, instead of searching the system for a `foo` library.
|
||||
///
|
||||
/// See also [`crate::execution_providers::cuda::preload_dylibs`], a helper that uses `preload_dylib` to load all
|
||||
/// required dependencies of the [CUDA execution provider](crate::execution_providers::CUDAExecutionProvider).
|
||||
/// See also [`crate::ep::cuda::preload_dylibs`], a helper that uses `preload_dylib` to load all required dependencies
|
||||
/// of the [CUDA execution provider](crate::ep::CUDA).
|
||||
///
|
||||
/// ```
|
||||
/// use std::env;
|
||||
|
||||
@@ -6,7 +6,7 @@ use core::ops::{Deref, DerefMut};
|
||||
|
||||
use super::DefiniteTensorValueTypeMarker;
|
||||
use crate::{
|
||||
Error, OnceLock, Result, execution_providers as ep,
|
||||
Error, OnceLock, Result, ep,
|
||||
io_binding::IoBinding,
|
||||
memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType},
|
||||
session::{NoSelectedOutputs, RunOptions, Session, builder::GraphOptimizationLevel},
|
||||
@@ -36,20 +36,20 @@ static IDENTITY_RUN_OPTIONS: OnceLock<RunOptions<NoSelectedOutputs>> = OnceLock:
|
||||
|
||||
fn ep_for_device(device: AllocationDevice, device_id: i32) -> Result<ep::ExecutionProviderDispatch> {
|
||||
Ok(match device {
|
||||
AllocationDevice::CPU => ep::CPUExecutionProvider::default().with_arena_allocator(false).build(),
|
||||
AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDAExecutionProvider::default()
|
||||
AllocationDevice::CPU => ep::CPU::default().with_arena_allocator(false).build(),
|
||||
AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDA::default()
|
||||
.with_device_id(device_id)
|
||||
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
|
||||
.with_conv_max_workspace(false)
|
||||
.with_conv_algorithm_search(ep::cuda::CuDNNConvAlgorithmSearch::Default)
|
||||
.with_conv_algorithm_search(ep::cuda::ConvAlgorithmSearch::Default)
|
||||
.build(),
|
||||
AllocationDevice::DIRECTML => ep::DirectMLExecutionProvider::default().with_device_id(device_id).build(),
|
||||
AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANNExecutionProvider::default()
|
||||
AllocationDevice::DIRECTML => ep::DirectML::default().with_device_id(device_id).build(),
|
||||
AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANN::default()
|
||||
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
|
||||
.with_cann_graph(false)
|
||||
.with_device_id(device_id)
|
||||
.build(),
|
||||
AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => ep::OpenVINOExecutionProvider::default()
|
||||
AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => ep::OpenVINO::default()
|
||||
.with_num_threads(1)
|
||||
.with_device_type(if device == AllocationDevice::OPENVINO_CPU {
|
||||
"CPU".to_string()
|
||||
@@ -57,7 +57,7 @@ fn ep_for_device(device: AllocationDevice, device_id: i32) -> Result<ep::Executi
|
||||
format!("GPU.{device_id}")
|
||||
})
|
||||
.build(),
|
||||
AllocationDevice::HIP | AllocationDevice::HIP_PINNED => ep::ROCmExecutionProvider::default()
|
||||
AllocationDevice::HIP | AllocationDevice::HIP_PINNED => ep::ROCm::default()
|
||||
.with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested)
|
||||
.with_hip_graph(false)
|
||||
.with_exhaustive_conv_search(false)
|
||||
@@ -388,13 +388,13 @@ mod tests {
|
||||
#[cfg(feature = "cuda")]
|
||||
fn test_copy_into_cuda() -> crate::Result<()> {
|
||||
use crate::{
|
||||
execution_providers::CUDAExecutionProvider,
|
||||
ep,
|
||||
memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType},
|
||||
session::Session
|
||||
};
|
||||
|
||||
let dummy_session = Session::builder()?
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
.with_execution_providers([ep::CUDA::default().build()])?
|
||||
.commit_from_file("tests/data/upsample.ort")?;
|
||||
|
||||
let allocator = Allocator::new(&dummy_session, MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?)?;
|
||||
@@ -412,13 +412,13 @@ mod tests {
|
||||
#[cfg(feature = "cuda")]
|
||||
fn test_copy_into_async_cuda() -> crate::Result<()> {
|
||||
use crate::{
|
||||
execution_providers::CUDAExecutionProvider,
|
||||
ep,
|
||||
memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType},
|
||||
session::Session
|
||||
};
|
||||
|
||||
let dummy_session = Session::builder()?
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
.with_execution_providers([ep::CUDA::default().build()])?
|
||||
.commit_from_file("tests/data/upsample.ort")?;
|
||||
|
||||
let allocator = Allocator::new(&dummy_session, MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?)?;
|
||||
|
||||
@@ -213,7 +213,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> {
|
||||
/// ```
|
||||
///
|
||||
/// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be
|
||||
/// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout.
|
||||
/// returned. See [`ndarray::ArrayRef::as_standard_layout`] to convert an array to a contiguous layout.
|
||||
pub fn from_array_view(input: impl TensorArrayData<T> + 'a) -> Result<TensorRef<'a, T>> {
|
||||
let (shape, data, guard) = input.ref_parts()?;
|
||||
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::<T>(), T::into_tensor_element_type(), guard).map(|tensor| {
|
||||
@@ -252,7 +252,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
|
||||
/// ```
|
||||
///
|
||||
/// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be
|
||||
/// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout.
|
||||
/// returned. See [`ndarray::ArrayRef::as_standard_layout`] to convert an array to a contiguous layout.
|
||||
pub fn from_array_view_mut(mut input: impl TensorArrayDataMut<T>) -> Result<TensorRefMut<'a, T>> {
|
||||
let (shape, data, guard) = input.ref_parts_mut()?;
|
||||
tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::<T>(), T::into_tensor_element_type(), guard).map(|tensor| {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use ort::{
|
||||
adapter::Adapter,
|
||||
execution_providers::CPUExecutionProvider,
|
||||
ep,
|
||||
memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType},
|
||||
operator::{
|
||||
Operator, OperatorDomain,
|
||||
@@ -77,7 +77,7 @@ impl Operator for CustomOpTwo {
|
||||
}
|
||||
|
||||
fn main() -> ort::Result<()> {
|
||||
ort::init().with_execution_providers([CPUExecutionProvider::default().build()]).commit();
|
||||
ort::init().with_execution_providers([ep::CPU::default().build()]).commit();
|
||||
|
||||
let mut session = Session::builder()?
|
||||
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?
|
||||
|
||||
Reference in New Issue
Block a user