diff --git a/docs/content/migrating/v2.mdx b/docs/content/migrating/v2.mdx index 39fbcdc..817ebee 100644 --- a/docs/content/migrating/v2.mdx +++ b/docs/content/migrating/v2.mdx @@ -12,7 +12,7 @@ To configure an `Environment`, you instead use the `ort::init` function, which r ```rust ort::init() - .with_execution_providers([CUDAExecutionProvider::default().build()]) + .with_execution_providers([ep::CUDA::default().build()]) .commit(); ``` @@ -156,7 +156,7 @@ let l = outputs["latents"].try_extract_array::()?; ``` ## Execution providers -Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/execution_providers/index.html#reexports) and the [execution providers reference](/perf/execution-providers) for more information. +Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/ep/index.html#reexports) and the [execution providers reference](/perf/execution-providers) for more information. ```diff -// v1.x @@ -165,7 +165,7 @@ Execution provider structs with public fields have been replaced with builder pa -}))?; +// v2 +builder = builder.with_execution_providers([ -+ DirectMLExecutionProvider::default() ++ ep::DirectML::default() + .with_device_id(1) + .build() +])?; diff --git a/docs/content/perf/execution-providers.mdx b/docs/content/perf/execution-providers.mdx index beb460c..ff0c7a7 100644 --- a/docs/content/perf/execution-providers.mdx +++ b/docs/content/perf/execution-providers.mdx @@ -52,11 +52,11 @@ ONNX Runtime must be compiled from source with support for each execution provid In order to configure sessions to use certain execution providers, you must **register** them when creating an environment or session. You can do this via the `SessionBuilder::with_execution_providers` method. 
For example, to register the CUDA execution provider for a session: ```rust -use ort::{execution_providers::CUDAExecutionProvider, session::Session}; +use ort::{ep::CUDA, session::Session}; fn main() -> anyhow::Result<()> { let session = Session::builder()? - .with_execution_providers([CUDAExecutionProvider::default().build()])? + .with_execution_providers([CUDA::default().build()])? .commit_from_file("model.onnx")?; Ok(()) @@ -66,21 +66,18 @@ fn main() -> anyhow::Result<()> { You can, of course, specify multiple execution providers. `ort` will register all EPs specified, in order. If an EP does not support a certain operator in a graph, it will fall back to the next successfully registered EP, or to the CPU if all else fails. ```rust -use ort::{ - execution_providers::{CoreMLExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider, TensorRTExecutionProvider}, - session::Session -}; +use ort::{ep, session::Session}; fn main() -> anyhow::Result<()> { let session = Session::builder()? .with_execution_providers([ // Prefer TensorRT over CUDA. - TensorRTExecutionProvider::default().build(), - CUDAExecutionProvider::default().build(), + ep::TensorRT::default().build(), + ep::CUDA::default().build(), // Use DirectML on Windows if NVIDIA EPs are not available - DirectMLExecutionProvider::default().build(), + ep::DirectML::default().build(), // Or use ANE on Apple platforms - CoreMLExecutionProvider::default().build() + ep::CoreML::default().build() ])? .commit_from_file("model.onnx")?; @@ -89,18 +86,18 @@ fn main() -> anyhow::Result<()> { ``` ## Configuring EPs -EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/execution_providers/index.html#reexports) for the EP structs for more information on which options are supported and what they do. 
+EPs have configuration options to control behavior or increase performance. Each execution provider struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.10/ort/ep/index.html#reexports) for the EP structs for more information on which options are supported and what they do. ```rust -use ort::{execution_providers::{coreml::CoreMLComputeUnits, CoreMLExecutionProvider}, session::Session}; +use ort::{ep, session::Session}; fn main() -> anyhow::Result<()> { let session = Session::builder()? .with_execution_providers([ - CoreMLExecutionProvider::default() + ep::CoreML::default() // this model uses control flow operators, so enable CoreML on subgraphs too .with_subgraphs() - .with_compute_units(CoreMLComputeUnits::CPUAndNeuralEngine) + .with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine) .build() ])? .commit_from_file("model.onnx")?; @@ -114,12 +111,12 @@ fn main() -> anyhow::Result<()> { You can configure an EP to return an error on failure by adding `.error_on_failure()` after you `.build()` it. In this example, if CUDA doesn't register successfully, the program will exit with an error at `with_execution_providers`: ```rust -use ort::{execution_providers::CoreMLExecutionProvider, session::Session}; +use ort::{ep, session::Session}; fn main() -> anyhow::Result<()> { let session = Session::builder()? .with_execution_providers([ - CUDAExecutionProvider::default().build().error_on_failure() + ep::CUDA::default().build().error_on_failure() ])? 
.commit_from_file("model.onnx")?; @@ -131,14 +128,14 @@ If you require more complex error handling, you can also manually register execu ```rust use ort::{ - execution_providers::{CUDAExecutionProvider, ExecutionProvider}, + ep::{self, ExecutionProvider}, session::Session }; fn main() -> anyhow::Result<()> { let builder = Session::builder()?; - let cuda = CUDAExecutionProvider::default(); + let cuda = ep::CUDA::default(); if cuda.register(&builder).is_err() { eprintln!("Failed to register CUDA!"); std::process::exit(1); @@ -154,14 +151,14 @@ You can also check whether ONNX Runtime is even compiled with support for the ex ```rust use ort::{ - execution_providers::{CoreMLExecutionProvider, ExecutionProvider}, + ep::{self, ExecutionProvider}, session::Session }; fn main() -> anyhow::Result<()> { let builder = Session::builder()?; - let coreml = CoreMLExecutionProvider::default(); + let coreml = ep::CoreML::default(); if !coreml.is_available() { eprintln!("Please compile ONNX Runtime with CoreML!"); std::process::exit(1); @@ -180,11 +177,11 @@ fn main() -> anyhow::Result<()> { You can configure `ort` to attempt to register a list of execution providers for all sessions created in an environment. ```rust -use ort::{execution_providers::CUDAExecutionProvider, session::Session}; +use ort::{ep, session::Session}; fn main() -> anyhow::Result<()> { ort::init() - .with_execution_providers([CUDAExecutionProvider::default().build()]) + .with_execution_providers([ep::CUDA::default().build()]) .commit(); let session = Session::builder()?.commit_from_file("model.onnx")?; diff --git a/docs/content/perf/io-binding.mdx b/docs/content/perf/io-binding.mdx index 4221f11..a08496b 100644 --- a/docs/content/perf/io-binding.mdx +++ b/docs/content/perf/io-binding.mdx @@ -80,10 +80,10 @@ Here is a more complete example of the I/O binding API in a scenario where I/O p ```rs let mut text_encoder = Session::builder()? - .with_execution_providers([CUDAExecutionProvider::default().build()])? 
+ .with_execution_providers([ep::CUDA::default().build()])? .commit_from_file("text_encoder.onnx")?; let mut unet = Session::builder()? - .with_execution_providers([CUDAExecutionProvider::default().build()])? + .with_execution_providers([ep::CUDA::default().build()])? .commit_from_file("unet.onnx")?; let text_condition = { diff --git a/examples/common/mod.rs b/examples/common/mod.rs index a1bb6e9..d5a0d87 100644 --- a/examples/common/mod.rs +++ b/examples/common/mod.rs @@ -1,5 +1,5 @@ #[allow(unused)] -use ort::execution_providers::*; +use ort::ep::*; pub fn init() -> ort::Result<()> { #[cfg(feature = "backend-candle")] @@ -11,41 +11,41 @@ pub fn init() -> ort::Result<()> { ort::init() .with_execution_providers([ #[cfg(feature = "tensorrt")] - TensorRTExecutionProvider::default().build(), + TensorRT::default().build(), #[cfg(feature = "cuda")] - CUDAExecutionProvider::default().build(), + CUDA::default().build(), #[cfg(feature = "onednn")] - OneDNNExecutionProvider::default().build(), + OneDNN::default().build(), #[cfg(feature = "acl")] - ACLExecutionProvider::default().build(), + ACL::default().build(), #[cfg(feature = "openvino")] - OpenVINOExecutionProvider::default().build(), + OpenVINO::default().build(), #[cfg(feature = "coreml")] - CoreMLExecutionProvider::default().build(), + CoreML::default().build(), #[cfg(feature = "rocm")] - ROCmExecutionProvider::default().build(), + ROCm::default().build(), #[cfg(feature = "cann")] - CANNExecutionProvider::default().build(), + CANN::default().build(), #[cfg(feature = "directml")] - DirectMLExecutionProvider::default().build(), + DirectML::default().build(), #[cfg(feature = "tvm")] - TVMExecutionProvider::default().build(), + TVM::default().build(), #[cfg(feature = "nnapi")] - NNAPIExecutionProvider::default().build(), + NNAPI::default().build(), #[cfg(feature = "qnn")] - QNNExecutionProvider::default().build(), + QNN::default().build(), #[cfg(feature = "xnnpack")] - XNNPACKExecutionProvider::default().build(), + 
XNNPACK::default().build(), #[cfg(feature = "armnn")] - ArmNNExecutionProvider::default().build(), + ArmNN::default().build(), #[cfg(feature = "migraphx")] - MIGraphXExecutionProvider::default().build(), + MIGraphX::default().build(), #[cfg(feature = "vitis")] - VitisAIExecutionProvider::default().build(), + VitisAI::default().build(), #[cfg(feature = "rknpu")] - RKNPUExecutionProvider::default().build(), + RKNPU::default().build(), #[cfg(feature = "webgpu")] - WebGPUExecutionProvider::default().build() + WebGPU::default().build() ]) .commit(); diff --git a/examples/cudarc/cudarc.rs b/examples/cudarc/cudarc.rs index c1cce73..6657057 100644 --- a/examples/cudarc/cudarc.rs +++ b/examples/cudarc/cudarc.rs @@ -4,7 +4,7 @@ use cudarc::driver::{CudaDevice, DevicePtr}; use image::{GenericImageView, ImageBuffer, Rgba, imageops::FilterType}; use ndarray::Array; use ort::{ - execution_providers::CUDAExecutionProvider, + ep, memory::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType}, session::Session, tensor::Shape, @@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> { #[rustfmt::skip] ort::init() .with_execution_providers([ - CUDAExecutionProvider::default() + ep::CUDA::default() .build() // exit the program with an error if the CUDA EP fails to register .error_on_failure() diff --git a/examples/training/train-clm-simple.rs b/examples/training/train-clm-simple.rs index de3b59c..6bee64a 100644 --- a/examples/training/train-clm-simple.rs +++ b/examples/training/train-clm-simple.rs @@ -6,7 +6,7 @@ use std::{ use kdam::BarExt; use ort::{ - execution_providers::CUDAExecutionProvider, + ep, memory::Allocator, session::{Session, builder::SessionBuilder}, training::{CheckpointStrategy, Trainer, TrainerCallbacks, TrainerControl, TrainerState, TrainingArguments}, @@ -56,7 +56,7 @@ fn main() -> ort::Result<()> { common::init()?; let trainer = Trainer::new_from_artifacts( - SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, + 
SessionBuilder::new()?.with_execution_providers([ep::CUDA::default().build()])?, Allocator::default(), "tools/train-data/mini-clm", None diff --git a/examples/training/train-clm.rs b/examples/training/train-clm.rs index b838902..97c5a7d 100644 --- a/examples/training/train-clm.rs +++ b/examples/training/train-clm.rs @@ -6,7 +6,7 @@ use std::{ use kdam::BarExt; use ort::{ - execution_providers::CUDAExecutionProvider, + ep, memory::Allocator, session::{Session, builder::SessionBuilder}, training::{Checkpoint, Trainer}, @@ -38,7 +38,7 @@ fn main() -> ort::Result<()> { let _ = kdam::term::hide_cursor(); let trainer = Trainer::new( - SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, + SessionBuilder::new()?.with_execution_providers([ep::CUDA::default().build()])?, Allocator::default(), Checkpoint::load("tools/train-data/mini-clm/checkpoint")?, "tools/train-data/mini-clm/training_model.onnx", diff --git a/examples/wasm-emscripten/src/main.rs b/examples/wasm-emscripten/src/main.rs index 2c16bf4..58c7818 100644 --- a/examples/wasm-emscripten/src/main.rs +++ b/examples/wasm-emscripten/src/main.rs @@ -57,8 +57,8 @@ pub extern "C" fn detect_objects(ptr: *const u8, width: u32, height: u32) { let use_webgpu = true; // TODO: Make `use_webgpu` a parameter of `detect_objects`? Or say in README to change it here. 
if use_webgpu { - use ort::execution_providers::ExecutionProvider; - let ep = ort::execution_providers::WebGPUExecutionProvider::default(); + use ort::ep::ExecutionProvider; + let ep = ort::ep::WebGPU::default(); if ep.is_available().expect("Cannot check for availability of WebGPU ep.") { ep.register(&mut builder).expect("Cannot register WebGPU ep."); } else { diff --git a/src/adapter.rs b/src/adapter.rs index 72fe7ce..c06f19a 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -63,14 +63,14 @@ impl Adapter { /// ``` /// # use ort::{ /// # adapter::Adapter, - /// # execution_providers::CUDAExecutionProvider, + /// # ep, /// # memory::DeviceType, /// # session::{run_options::RunOptions, Session}, /// # value::Tensor /// # }; /// # fn main() -> ort::Result<()> { /// let mut model = Session::builder()? - /// .with_execution_providers([CUDAExecutionProvider::default().build()])? + /// .with_execution_providers([ep::CUDA::default().build()])? /// .commit_from_file("tests/data/lora_model.onnx")?; /// /// let allocator = model.allocator(); @@ -108,14 +108,14 @@ impl Adapter { /// ``` /// # use ort::{ /// # adapter::Adapter, - /// # execution_providers::CUDAExecutionProvider, + /// # ep, /// # memory::DeviceType, /// # session::{run_options::RunOptions, Session}, /// # value::Tensor /// # }; /// # fn main() -> ort::Result<()> { /// let mut model = Session::builder()? - /// .with_execution_providers([CUDAExecutionProvider::default().build()])? + /// .with_execution_providers([ep::CUDA::default().build()])? /// .commit_from_file("tests/data/lora_model.onnx")?; /// /// let bytes = std::fs::read("tests/data/adapter.orl").unwrap(); diff --git a/src/environment.rs b/src/environment.rs index 59d90aa..9dcaefc 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -2,9 +2,9 @@ //! //! Environments can be configured via [`ort::init`](init): //! ``` -//! # use ort::execution_providers::CUDAExecutionProvider; +//! # use ort::ep; //! 
# fn main() -> ort::Result<()> { -//! ort::init().with_execution_providers([CUDAExecutionProvider::default().build()]).commit(); +//! ort::init().with_execution_providers([ep::CUDA::default().build()]).commit(); //! # Ok(()) //! # } //! ``` @@ -25,8 +25,8 @@ use smallvec::SmallVec; use crate::{ AsPointer, + ep::ExecutionProviderDispatch, error::Result, - execution_providers::ExecutionProviderDispatch, logging::{LogLevel, LoggerFunction}, ortsys, util::{Mutex, OnceLock, STACK_EXECUTION_PROVIDERS, with_cstr} @@ -430,9 +430,9 @@ impl EnvironmentBuilder { /// Creates an ONNX Runtime environment. /// /// ``` -/// # use ort::execution_providers::CUDAExecutionProvider; +/// # use ort::ep; /// # fn main() -> ort::Result<()> { -/// ort::init().with_execution_providers([CUDAExecutionProvider::default().build()]).commit(); +/// ort::init().with_execution_providers([ep::CUDA::default().build()]).commit(); /// # Ok(()) /// # } /// ``` @@ -456,11 +456,11 @@ pub fn init() -> EnvironmentBuilder { /// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded. /// /// ```no_run -/// # use ort::execution_providers::CUDAExecutionProvider; +/// # use ort::ep; /// # fn main() -> ort::Result<()> { /// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib"); /// ort::init_from(lib_path.join("onnxruntime.dll"))? 
-/// .with_execution_providers([CUDAExecutionProvider::default().build()]) +/// .with_execution_providers([ep::CUDA::default().build()]) /// .commit(); /// # Ok(()) /// # } diff --git a/src/execution_providers/acl.rs b/src/ep/acl.rs similarity index 81% rename from src/execution_providers/acl.rs rename to src/ep/acl.rs index 5b4880e..9089456 100644 --- a/src/execution_providers/acl.rs +++ b/src/ep/acl.rs @@ -4,20 +4,20 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// [Arm Compute Library execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/ACL-ExecutionProvider.html) /// for ARM platforms. #[derive(Debug, Default, Clone)] -pub struct ACLExecutionProvider { +pub struct ACL { fast_math: bool } -super::impl_ep!(ACLExecutionProvider); +super::impl_ep!(ACL); -impl ACLExecutionProvider { +impl ACL { /// Enable/disable ACL's fast math mode. Enabling can improve performance at the cost of some accuracy for /// `MatMul`/`Conv` nodes. /// /// ``` - /// # use ort::{execution_providers::acl::ACLExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = ACLExecutionProvider::default().with_fast_math(true).build(); + /// let ep = ep::ACL::default().with_fast_math(true).build(); /// # Ok(()) /// # } /// ``` @@ -28,7 +28,7 @@ impl ACLExecutionProvider { } } -impl ExecutionProvider for ACLExecutionProvider { +impl ExecutionProvider for ACL { fn name(&self) -> &'static str { "ACLExecutionProvider" } diff --git a/src/execution_providers/armnn.rs b/src/ep/armnn.rs similarity index 80% rename from src/execution_providers/armnn.rs rename to src/ep/armnn.rs index 2cd99c2..a18e37b 100644 --- a/src/execution_providers/armnn.rs +++ b/src/ep/armnn.rs @@ -4,19 +4,19 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// [Arm NN execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/ArmNN-ExecutionProvider.html) /// for 
ARM platforms. #[derive(Debug, Default, Clone)] -pub struct ArmNNExecutionProvider { +pub struct ArmNN { use_arena: bool } -super::impl_ep!(ArmNNExecutionProvider); +super::impl_ep!(ArmNN); -impl ArmNNExecutionProvider { +impl ArmNN { /// Enable/disable the usage of the arena allocator. /// /// ``` - /// # use ort::{execution_providers::armnn::ArmNNExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = ArmNNExecutionProvider::default().with_arena_allocator(true).build(); + /// let ep = ep::ArmNN::default().with_arena_allocator(true).build(); /// # Ok(()) /// # } /// ``` @@ -27,7 +27,7 @@ impl ArmNNExecutionProvider { } } -impl ExecutionProvider for ArmNNExecutionProvider { +impl ExecutionProvider for ArmNN { fn name(&self) -> &'static str { "ArmNNExecutionProvider" } diff --git a/src/execution_providers/azure.rs b/src/ep/azure.rs similarity index 91% rename from src/execution_providers/azure.rs rename to src/ep/azure.rs index 3454555..95df447 100644 --- a/src/execution_providers/azure.rs +++ b/src/ep/azure.rs @@ -35,12 +35,12 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// /// To use this model in `ort`: /// ```no_run -/// # use ort::{execution_providers::azure::AzureExecutionProvider, session::Session, value::Tensor}; +/// # use ort::{ep, session::Session, value::Tensor}; /// # fn main() -> ort::Result<()> { /// let mut session = Session::builder()? /// // note: session must be initialized with `onnxruntime-extensions` /// .with_extensions()? -/// .with_execution_providers([AzureExecutionProvider::default().build()])? +/// .with_execution_providers([ep::Azure::default().build()])? 
/// .commit_from_file("azure_chat.onnx")?; /// /// let auth_token = Tensor::from_string_array(([1], &*vec!["..."]))?; @@ -65,13 +65,13 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// # } /// ``` #[derive(Debug, Default, Clone)] -pub struct AzureExecutionProvider { +pub struct Azure { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; AzureExecutionProvider); +super::impl_ep!(arbitrary; Azure); -impl ExecutionProvider for AzureExecutionProvider { +impl ExecutionProvider for Azure { fn name(&self) -> &'static str { "AzureExecutionProvider" } diff --git a/src/execution_providers/cann.rs b/src/ep/cann.rs similarity index 67% rename from src/execution_providers/cann.rs rename to src/ep/cann.rs index 7a2efcf..2bc9689 100644 --- a/src/execution_providers/cann.rs +++ b/src/ep/cann.rs @@ -5,7 +5,7 @@ use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Clone, PartialEq, Eq)] #[non_exhaustive] -pub enum CANNPrecisionMode { +pub enum PrecisionMode { /// Convert to float32 first according to operator implementation ForceFP32, /// Convert to float16 when float16 and float32 are both supported @@ -20,7 +20,7 @@ pub enum CANNPrecisionMode { #[derive(Debug, Clone, PartialEq, Eq)] #[non_exhaustive] -pub enum CANNImplementationMode { +pub enum ImplementationMode { /// Prefer high precision, potentially at the cost of some performance. HighPrecision, /// Prefer high performance, potentially with lower accuracy. @@ -30,19 +30,19 @@ pub enum CANNImplementationMode { /// [CANN execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/CANN-ExecutionProvider.html) /// for hardware acceleration using Huawei Ascend AI devices. 
#[derive(Default, Debug, Clone)] -pub struct CANNExecutionProvider { +pub struct CANN { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; CANNExecutionProvider); +super::impl_ep!(arbitrary; CANN); -impl CANNExecutionProvider { +impl CANN { /// Configures which device the EP should use. /// /// ``` - /// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CANNExecutionProvider::default().with_device_id(0).build(); + /// let ep = ep::CANN::default().with_device_id(0).build(); /// # Ok(()) /// # } /// ``` @@ -56,9 +56,9 @@ impl CANNExecutionProvider { /// provider’s arena; the total device memory usage may be higher. /// /// ``` - /// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CANNExecutionProvider::default().with_memory_limit(2 * 1024 * 1024 * 1024).build(); + /// let ep = ep::CANN::default().with_memory_limit(2 * 1024 * 1024 * 1024).build(); /// # Ok(()) /// # } /// ``` @@ -71,9 +71,9 @@ impl CANNExecutionProvider { /// Configure the strategy for extending the device's memory arena. /// /// ``` - /// # use ort::{execution_providers::{cann::CANNExecutionProvider, ArenaExtendStrategy}, session::Session}; + /// # use ort::{ep::{self, ArenaExtendStrategy}, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CANNExecutionProvider::default() + /// let ep = ep::CANN::default() /// .with_arena_extend_strategy(ArenaExtendStrategy::SameAsRequested) /// .build(); /// # Ok(()) @@ -95,9 +95,9 @@ impl CANNExecutionProvider { /// is `true`. If `false`, it will fall back to the single-operator inference engine. 
/// /// ``` - /// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CANNExecutionProvider::default().with_cann_graph(true).build(); + /// let ep = ep::CANN::default().with_cann_graph(true).build(); /// # Ok(()) /// # } /// ``` @@ -110,9 +110,9 @@ impl CANNExecutionProvider { /// Configure whether to dump the subgraph into ONNX format for analysis of subgraph segmentation. /// /// ``` - /// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CANNExecutionProvider::default().with_dump_graphs(true).build(); + /// let ep = ep::CANN::default().with_dump_graphs(true).build(); /// # Ok(()) /// # } /// ``` @@ -125,9 +125,9 @@ impl CANNExecutionProvider { /// Configure whether to dump the offline model to an `.om` file. /// /// ``` - /// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CANNExecutionProvider::default().with_dump_om_model(true).build(); + /// let ep = ep::CANN::default().with_dump_om_model(true).build(); /// # Ok(()) /// # } /// ``` @@ -137,25 +137,25 @@ impl CANNExecutionProvider { self } - /// Configure the precision mode; see [`CANNPrecisionMode`]. + /// Configure the precision mode; see [`PrecisionMode`]. 
/// /// ``` - /// # use ort::{execution_providers::cann::{CANNExecutionProvider, CANNPrecisionMode}, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CANNExecutionProvider::default().with_precision_mode(CANNPrecisionMode::ForceFP16).build(); + /// let ep = ep::CANN::default().with_precision_mode(ep::cann::PrecisionMode::ForceFP16).build(); /// # Ok(()) /// # } /// ``` #[must_use] - pub fn with_precision_mode(mut self, mode: CANNPrecisionMode) -> Self { + pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self { self.options.set( "precision_mode", match mode { - CANNPrecisionMode::ForceFP32 => "force_fp32", - CANNPrecisionMode::ForceFP16 => "force_fp16", - CANNPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16", - CANNPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype", - CANNPrecisionMode::AllowMixedPrecision => "allow_mix_precision" + PrecisionMode::ForceFP32 => "force_fp32", + PrecisionMode::ForceFP16 => "force_fp16", + PrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16", + PrecisionMode::MustKeepOrigin => "must_keep_origin_dtype", + PrecisionMode::AllowMixedPrecision => "allow_mix_precision" } ); self @@ -165,33 +165,33 @@ impl CANNExecutionProvider { /// high-performance implementations. 
/// /// ``` - /// # use ort::{execution_providers::cann::{CANNExecutionProvider, CANNImplementationMode}, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CANNExecutionProvider::default() - /// .with_implementation_mode(CANNImplementationMode::HighPerformance) + /// let ep = ep::CANN::default() + /// .with_implementation_mode(ep::cann::ImplementationMode::HighPerformance) /// .build(); /// # Ok(()) /// # } /// ``` #[must_use] - pub fn with_implementation_mode(mut self, mode: CANNImplementationMode) -> Self { + pub fn with_implementation_mode(mut self, mode: ImplementationMode) -> Self { self.options.set( "op_select_impl_mode", match mode { - CANNImplementationMode::HighPrecision => "high_precision", - CANNImplementationMode::HighPerformance => "high_performance" + ImplementationMode::HighPrecision => "high_precision", + ImplementationMode::HighPerformance => "high_performance" } ); self } /// Configure the list of operators which use the mode specified by - /// [`CANNExecutionProvider::with_implementation_mode`]. + /// [`CANN::with_implementation_mode`]. 
/// /// ``` - /// # use ort::{execution_providers::cann::CANNExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CANNExecutionProvider::default().with_implementation_mode_oplist("LayerNorm,Gelu").build(); + /// let ep = ep::CANN::default().with_implementation_mode_oplist("LayerNorm,Gelu").build(); /// # Ok(()) /// # } /// ``` @@ -202,7 +202,7 @@ impl CANNExecutionProvider { } } -impl ExecutionProvider for CANNExecutionProvider { +impl ExecutionProvider for CANN { fn name(&self) -> &'static str { "CANNExecutionProvider" } diff --git a/src/execution_providers/coreml.rs b/src/ep/coreml.rs similarity index 73% rename from src/execution_providers/coreml.rs rename to src/ep/coreml.rs index e53fea8..8b0a539 100644 --- a/src/execution_providers/coreml.rs +++ b/src/ep/coreml.rs @@ -4,7 +4,7 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CoreMLSpecializationStrategy { +pub enum SpecializationStrategy { /// The strategy that should work well for most applications. Default, /// Prefer the prediction latency at the potential cost of specialization time, memory footprint, and the disk space @@ -12,7 +12,7 @@ pub enum CoreMLSpecializationStrategy { FastPrediction } -impl CoreMLSpecializationStrategy { +impl SpecializationStrategy { pub(crate) fn as_str(&self) -> &'static str { match self { Self::Default => "Default", @@ -22,7 +22,7 @@ impl CoreMLSpecializationStrategy { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CoreMLComputeUnits { +pub enum ComputeUnits { /// Enable CoreML EP for all compatible Apple devices. All, /// Enable CoreML EP for Apple devices with a compatible Neural Engine (ANE). 
@@ -33,7 +33,7 @@ pub enum CoreMLComputeUnits { CPUOnly } -impl CoreMLComputeUnits { +impl ComputeUnits { pub(crate) fn as_str(&self) -> &'static str { match self { Self::All => "ALL", @@ -45,14 +45,14 @@ impl CoreMLComputeUnits { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CoreMLModelFormat { +pub enum ModelFormat { /// Requires Core ML 5 or later (iOS 15+ or macOS 12+). MLProgram, /// Default; requires Core ML 3 or later (iOS 13+ or macOS 10.15+). NeuralNetwork } -impl CoreMLModelFormat { +impl ModelFormat { pub(crate) fn as_str(&self) -> &'static str { match self { Self::MLProgram => "MLProgram", @@ -64,20 +64,20 @@ impl CoreMLModelFormat { /// [CoreML execution provider](https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html) for hardware /// acceleration on Apple devices. #[derive(Debug, Default, Clone)] -pub struct CoreMLExecutionProvider { +pub struct CoreML { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; CoreMLExecutionProvider); +super::impl_ep!(arbitrary; CoreML); -impl CoreMLExecutionProvider { +impl CoreML { /// Enable CoreML EP to run on a subgraph in the body of a control flow operator (i.e. a `Loop`, `Scan` or `If` /// operator). /// /// ``` - /// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CoreMLExecutionProvider::default().with_subgraphs(true).build(); + /// let ep = ep::CoreML::default().with_subgraphs(true).build(); /// # Ok(()) /// # } /// ``` @@ -91,9 +91,9 @@ impl CoreMLExecutionProvider { /// allow inputs with dynamic shapes, however performance may be negatively impacted by inputs with dynamic shapes. 
/// /// ``` - /// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CoreMLExecutionProvider::default().with_static_input_shapes(true).build(); + /// let ep = ep::CoreML::default().with_static_input_shapes(true).build(); /// # Ok(()) /// # } /// ``` @@ -105,19 +105,19 @@ impl CoreMLExecutionProvider { /// Configures the format of the CoreML model created by the EP. /// - /// The default format, [NeuralNetwork](`CoreMLModelFormat::NeuralNetwork`), has better compatibility with older - /// versions of macOS/iOS. The newer [MLProgram](`CoreMLModelFormat::MLProgram`) format supports more operators, + /// The default format, [NeuralNetwork](`ModelFormat::NeuralNetwork`), has better compatibility with older + /// versions of macOS/iOS. The newer [MLProgram](`ModelFormat::MLProgram`) format supports more operators, /// and may be more performant. /// /// ``` - /// # use ort::{execution_providers::coreml::{CoreMLExecutionProvider, CoreMLModelFormat}, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CoreMLExecutionProvider::default().with_model_format(CoreMLModelFormat::MLProgram).build(); + /// let ep = ep::CoreML::default().with_model_format(ep::coreml::ModelFormat::MLProgram).build(); /// # Ok(()) /// # } /// ``` #[must_use] - pub fn with_model_format(mut self, model_format: CoreMLModelFormat) -> Self { + pub fn with_model_format(mut self, model_format: ModelFormat) -> Self { self.options.set("ModelFormat", model_format.as_str()); self } @@ -129,14 +129,16 @@ impl CoreMLExecutionProvider { /// model for faster prediction, at the potential cost of session load time and memory footprint. 
/// /// ``` - /// # use ort::{execution_providers::coreml::{CoreMLExecutionProvider, CoreMLSpecializationStrategy}, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CoreMLExecutionProvider::default().with_specialization_strategy(CoreMLSpecializationStrategy::FastPrediction).build(); + /// let ep = ep::CoreML::default() + /// .with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction) + /// .build(); /// # Ok(()) /// # } /// ``` #[must_use] - pub fn with_specialization_strategy(mut self, strategy: CoreMLSpecializationStrategy) -> Self { + pub fn with_specialization_strategy(mut self, strategy: SpecializationStrategy) -> Self { self.options.set("SpecializationStrategy", strategy.as_str()); self } @@ -144,16 +146,16 @@ impl CoreMLExecutionProvider { /// Configures what hardware can be used by CoreML for acceleration. /// /// ``` - /// # use ort::{execution_providers::coreml::{CoreMLExecutionProvider, CoreMLComputeUnits}, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CoreMLExecutionProvider::default() - /// .with_compute_units(CoreMLComputeUnits::CPUAndNeuralEngine) + /// let ep = ep::CoreML::default() + /// .with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine) /// .build(); /// # Ok(()) /// # } /// ``` #[must_use] - pub fn with_compute_units(mut self, units: CoreMLComputeUnits) -> Self { + pub fn with_compute_units(mut self, units: ComputeUnits) -> Self { self.options.set("MLComputeUnits", units.as_str()); self } @@ -162,9 +164,9 @@ impl CoreMLExecutionProvider { /// for debugging unexpected performance with CoreML. 
/// /// ``` - /// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CoreMLExecutionProvider::default().with_profile_compute_plan(true).build(); + /// let ep = ep::CoreML::default().with_profile_compute_plan(true).build(); /// # Ok(()) /// # } /// ``` @@ -177,9 +179,9 @@ impl CoreMLExecutionProvider { /// Configures whether to allow low-precision (fp16) accumulation on GPU. /// /// ``` - /// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CoreMLExecutionProvider::default().with_low_precision_accumulation_on_gpu(true).build(); + /// let ep = ep::CoreML::default().with_low_precision_accumulation_on_gpu(true).build(); /// # Ok(()) /// # } /// ``` @@ -195,9 +197,9 @@ impl CoreMLExecutionProvider { /// session. Setting this option allows the compiled model to be reused across session loads. 
/// /// ``` - /// # use ort::{execution_providers::coreml::CoreMLExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CoreMLExecutionProvider::default().with_model_cache_dir("/path/to/cache").build(); + /// let ep = ep::CoreML::default().with_model_cache_dir("/path/to/cache").build(); /// # Ok(()) /// # } /// ``` @@ -235,7 +237,7 @@ impl CoreMLExecutionProvider { } } -impl ExecutionProvider for CoreMLExecutionProvider { +impl ExecutionProvider for CoreML { fn name(&self) -> &'static str { "CoreMLExecutionProvider" } diff --git a/src/execution_providers/cpu.rs b/src/ep/cpu.rs similarity index 75% rename from src/execution_providers/cpu.rs rename to src/ep/cpu.rs index 57310bb..2ed7e34 100644 --- a/src/execution_providers/cpu.rs +++ b/src/ep/cpu.rs @@ -3,19 +3,19 @@ use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder}; /// The default CPU execution provider, powered by MLAS. #[derive(Debug, Default, Clone)] -pub struct CPUExecutionProvider { +pub struct CPU { use_arena: bool } -super::impl_ep!(CPUExecutionProvider); +super::impl_ep!(CPU); -impl CPUExecutionProvider { +impl CPU { /// Enable/disable the usage of the arena allocator. 
/// /// ``` - /// # use ort::{execution_providers::cpu::CPUExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CPUExecutionProvider::default().with_arena_allocator(true).build(); + /// let ep = ep::CPU::default().with_arena_allocator(true).build(); /// # Ok(()) /// # } /// ``` @@ -26,7 +26,7 @@ impl CPUExecutionProvider { } } -impl ExecutionProvider for CPUExecutionProvider { +impl ExecutionProvider for CPU { fn name(&self) -> &'static str { "CPUExecutionProvider" } diff --git a/src/execution_providers/cuda.rs b/src/ep/cuda.rs similarity index 82% rename from src/execution_providers/cuda.rs rename to src/ep/cuda.rs index 95dcbb2..b92bff9 100644 --- a/src/execution_providers/cuda.rs +++ b/src/ep/cuda.rs @@ -7,9 +7,9 @@ use crate::{error::Result, session::builder::SessionBuilder}; // https://github.com/microsoft/onnxruntime/blob/ffceed9d44f2f3efb9dd69fa75fea51163c91d91/onnxruntime/contrib_ops/cpu/bert/attention_common.h#L160-L171 #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] -pub struct CUDAAttentionBackend(u32); +pub struct AttentionBackend(u32); -impl CUDAAttentionBackend { +impl AttentionBackend { pub const FLASH_ATTENTION: Self = Self(1 << 0); pub const EFFICIENT_ATTENTION: Self = Self(1 << 1); pub const TRT_FUSED_ATTENTION: Self = Self(1 << 2); @@ -23,7 +23,7 @@ impl CUDAAttentionBackend { pub const LEAN_ATTENTION: Self = Self(1 << 8); pub fn none() -> Self { - CUDAAttentionBackend(0) + AttentionBackend(0) } pub fn all() -> Self { @@ -38,7 +38,7 @@ impl CUDAAttentionBackend { } } -impl BitOr for CUDAAttentionBackend { +impl BitOr for AttentionBackend { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { Self(rhs.0 | self.0) @@ -47,7 +47,7 @@ impl BitOr for CUDAAttentionBackend { /// The type of search done for cuDNN convolution algorithms. 
#[derive(Debug, Clone, Default)] -pub enum CuDNNConvAlgorithmSearch { +pub enum ConvAlgorithmSearch { /// Expensive exhaustive benchmarking using [`cudnnFindConvolutionForwardAlgorithmEx`][exhaustive]. /// This function will attempt all possible algorithms for `cudnnConvolutionForward` to find the fastest algorithm. /// Exhaustive search trades off between memory usage and speed. The first execution of a graph will be slow while @@ -71,27 +71,27 @@ pub enum CuDNNConvAlgorithmSearch { /// > search algorithm is actually [`Exhaustive`]. /// /// [fwdalgo]: https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionFwdAlgo_t - /// [`Exhaustive`]: CuDNNConvAlgorithmSearch::Exhaustive - /// [`Heuristic`]: CuDNNConvAlgorithmSearch::Heuristic + /// [`Exhaustive`]: ConvAlgorithmSearch::Exhaustive + /// [`Heuristic`]: ConvAlgorithmSearch::Heuristic Default } /// [CUDA execution provider](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) for NVIDIA /// CUDA-enabled GPUs. #[derive(Debug, Default, Clone)] -pub struct CUDAExecutionProvider { +pub struct CUDA { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; CUDAExecutionProvider); +super::impl_ep!(arbitrary; CUDA); -impl CUDAExecutionProvider { +impl CUDA { /// Configures which device the EP should use. /// /// ``` - /// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default().with_device_id(0).build(); + /// let ep = ep::CUDA::default().with_device_id(0).build(); /// # Ok(()) /// # } /// ``` @@ -104,12 +104,12 @@ impl CUDAExecutionProvider { /// Configure the size limit of the device memory arena in bytes. /// /// This only controls how much memory can be allocated to the *arena* - actual memory usage may be higher due to - /// internal CUDA allocations, like those required for different [`CuDNNConvAlgorithmSearch`] options. 
+ /// internal CUDA allocations, like those required for different [`ConvAlgorithmSearch`] options. /// /// ``` - /// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default().with_memory_limit(2 * 1024 * 1024 * 1024).build(); + /// let ep = ep::CUDA::default().with_memory_limit(2 * 1024 * 1024 * 1024).build(); /// # Ok(()) /// # } /// ``` @@ -122,9 +122,9 @@ impl CUDAExecutionProvider { /// Configure the strategy for extending the device's memory arena. /// /// ``` - /// # use ort::{execution_providers::{cuda::CUDAExecutionProvider, ArenaExtendStrategy}, session::Session}; + /// # use ort::{ep::{self, ArenaExtendStrategy}, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default() + /// let ep = ep::CUDA::default() /// .with_arena_extend_strategy(ArenaExtendStrategy::SameAsRequested) /// .build(); /// # Ok(()) @@ -152,7 +152,7 @@ impl CUDAExecutionProvider { /// The default search algorithm, [`Exhaustive`][exh], will benchmark all available implementations and use the most /// performant one. This option is very resource intensive (both computationally on first run and peak-memory-wise), /// but ensures best performance. It is roughly equivalent to setting `torch.backends.cudnn.benchmark = True` with - /// PyTorch. See also [`CUDAExecutionProvider::with_conv_max_workspace`] to configure how much memory the exhaustive + /// PyTorch. See also [`CUDA::with_conv_max_workspace`] to configure how much memory the exhaustive /// search can use (the default is unlimited). /// /// A less resource-intensive option is [`Heuristic`][heu]. Rather than benchmarking every implementation, @@ -164,41 +164,41 @@ impl CUDAExecutionProvider { /// is not the *default behavior* (that would be [`Exhaustive`][exh]). 
/// /// ``` - /// # use ort::{execution_providers::cuda::{CUDAExecutionProvider, CuDNNConvAlgorithmSearch}, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default() - /// .with_conv_algorithm_search(CuDNNConvAlgorithmSearch::Heuristic) + /// let ep = ep::CUDA::default() + /// .with_conv_algorithm_search(ep::cuda::ConvAlgorithmSearch::Heuristic) /// .build(); /// # Ok(()) /// # } /// ``` /// - /// [exh]: CuDNNConvAlgorithmSearch::Exhaustive - /// [heu]: CuDNNConvAlgorithmSearch::Heuristic - /// [def]: CuDNNConvAlgorithmSearch::Default + /// [exh]: ConvAlgorithmSearch::Exhaustive + /// [heu]: ConvAlgorithmSearch::Heuristic + /// [def]: ConvAlgorithmSearch::Default #[must_use] - pub fn with_conv_algorithm_search(mut self, search: CuDNNConvAlgorithmSearch) -> Self { + pub fn with_conv_algorithm_search(mut self, search: ConvAlgorithmSearch) -> Self { self.options.set( "cudnn_conv_algo_search", match search { - CuDNNConvAlgorithmSearch::Exhaustive => "EXHAUSTIVE", - CuDNNConvAlgorithmSearch::Heuristic => "HEURISTIC", - CuDNNConvAlgorithmSearch::Default => "DEFAULT" + ConvAlgorithmSearch::Exhaustive => "EXHAUSTIVE", + ConvAlgorithmSearch::Heuristic => "HEURISTIC", + ConvAlgorithmSearch::Default => "DEFAULT" } ); self } - /// Configure whether the [`Exhaustive`][CuDNNConvAlgorithmSearch::Exhaustive] search can use as much memory as it + /// Configure whether the [`Exhaustive`][ConvAlgorithmSearch::Exhaustive] search can use as much memory as it /// needs. /// /// The default is `true`. When `false`, the memory used for the search is limited to 32 MB, which will impact its /// ability to find an optimal convolution algorithm. 
/// /// ``` - /// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default().with_conv_max_workspace(false).build(); + /// let ep = ep::CUDA::default().with_conv_max_workspace(false).build(); /// # Ok(()) /// # } /// ``` @@ -218,9 +218,9 @@ impl CUDAExecutionProvider { /// convolution operations that do not use 3-dimensional input shapes, or the *result* of such operations. /// /// ``` - /// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default().with_conv1d_pad_to_nc1d(true).build(); + /// let ep = ep::CUDA::default().with_conv1d_pad_to_nc1d(true).build(); /// # Ok(()) /// # } /// ``` @@ -245,9 +245,9 @@ impl CUDAExecutionProvider { /// Consult the [ONNX Runtime documentation on CUDA graphs](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#using-cuda-graphs-preview) for more information. /// /// ``` - /// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default().with_cuda_graph(true).build(); + /// let ep = ep::CUDA::default().with_cuda_graph(true).build(); /// # Ok(()) /// # } /// ``` @@ -262,9 +262,9 @@ impl CUDAExecutionProvider { /// `SkipLayerNorm`'s strict mode trades performance for accuracy. The default is `false` (strict mode disabled). 
/// /// ``` - /// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default().with_skip_layer_norm_strict_mode(true).build(); + /// let ep = ep::CUDA::default().with_skip_layer_norm_strict_mode(true).build(); /// # Ok(()) /// # } /// ``` @@ -285,9 +285,9 @@ impl CUDAExecutionProvider { /// `torch.backends.cuda.matmul.allow_tf32 = True` or `torch.set_float32_matmul_precision("medium")` in PyTorch. /// /// ``` - /// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default().with_tf32(true).build(); + /// let ep = ep::CUDA::default().with_tf32(true).build(); /// # Ok(()) /// # } /// ``` @@ -303,9 +303,9 @@ impl CUDAExecutionProvider { /// convolution-heavy models on Tensor core-enabled GPUs may provide a significant performance improvement. /// /// ``` - /// # use ort::{execution_providers::cuda::CUDAExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default().with_prefer_nhwc(true).build(); + /// let ep = ep::CUDA::default().with_prefer_nhwc(true).build(); /// # Ok(()) /// # } /// ``` @@ -328,16 +328,18 @@ impl CUDAExecutionProvider { /// Configures the available backends used for `Attention` nodes. 
/// /// ``` - /// # use ort::{execution_providers::cuda::{CUDAExecutionProvider, CUDAAttentionBackend}, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = CUDAExecutionProvider::default() - /// .with_attention_backend(CUDAAttentionBackend::FLASH_ATTENTION | CUDAAttentionBackend::TRT_FUSED_ATTENTION) + /// let ep = ep::CUDA::default() + /// .with_attention_backend( + /// ep::cuda::AttentionBackend::FLASH_ATTENTION | ep::cuda::AttentionBackend::TRT_FUSED_ATTENTION + /// ) /// .build(); /// # Ok(()) /// # } /// ``` #[must_use] - pub fn with_attention_backend(mut self, flags: CUDAAttentionBackend) -> Self { + pub fn with_attention_backend(mut self, flags: AttentionBackend) -> Self { self.options.set("sdpa_kernel", flags.0.to_string()); self } @@ -352,7 +354,7 @@ impl CUDAExecutionProvider { // https://github.com/microsoft/onnxruntime/blob/fe8a10caa40f64a8fbd144e7049cf5b14c65542d/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc#L17 } -impl ExecutionProvider for CUDAExecutionProvider { +impl ExecutionProvider for CUDA { fn name(&self) -> &'static str { "CUDAExecutionProvider" } @@ -435,16 +437,16 @@ pub const CUDNN_DYLIBS: &[&str] = &[ /// /// ``` /// # use std::path::Path; -/// use ort::execution_providers::cuda; +/// use ort::ep; /// /// let cuda_root = Path::new(r#"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin"#); /// let cudnn_root = Path::new(r#"D:\cudnn_9.8.0"#); /// /// // Load CUDA & cuDNN -/// let _ = cuda::preload_dylibs(Some(cuda_root), Some(cudnn_root)); +/// let _ = ep::cuda::preload_dylibs(Some(cuda_root), Some(cudnn_root)); /// /// // Only preload cuDNN -/// let _ = cuda::preload_dylibs(None, Some(cudnn_root)); +/// let _ = ep::cuda::preload_dylibs(None, Some(cudnn_root)); /// ``` #[cfg_attr(docsrs, doc(cfg(any(feature = "preload-dylibs", feature = "load-dynamic"))))] #[cfg(feature = "preload-dylibs")] diff --git a/src/execution_providers/directml.rs 
b/src/ep/directml.rs similarity index 78% rename from src/execution_providers/directml.rs rename to src/ep/directml.rs index d193986..abb5ed4 100644 --- a/src/execution_providers/directml.rs +++ b/src/ep/directml.rs @@ -9,10 +9,10 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// with dynamically sized inputs, you can override individual dimensions by constructing the session with /// [`SessionBuilder::with_dimension_override`]: /// ```no_run -/// # use ort::{execution_providers::directml::DirectMLExecutionProvider, session::Session}; +/// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { /// let session = Session::builder()? -/// .with_execution_providers([DirectMLExecutionProvider::default().build()])? +/// .with_execution_providers([ep::DirectML::default().build()])? /// .with_dimension_override("batch", 1)? /// .with_dimension_override("seq_len", 512)? /// .commit_from_file("gpt2.onnx")?; @@ -20,19 +20,19 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// # } /// ``` #[derive(Debug, Default, Clone)] -pub struct DirectMLExecutionProvider { +pub struct DirectML { device_id: i32 } -super::impl_ep!(DirectMLExecutionProvider); +super::impl_ep!(DirectML); -impl DirectMLExecutionProvider { +impl DirectML { /// Configures which device the EP should use. 
/// /// ``` - /// # use ort::{execution_providers::directml::DirectMLExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = DirectMLExecutionProvider::default().with_device_id(1).build(); + /// let ep = ep::DirectML::default().with_device_id(1).build(); /// # Ok(()) /// # } /// ``` @@ -43,7 +43,7 @@ impl DirectMLExecutionProvider { } } -impl ExecutionProvider for DirectMLExecutionProvider { +impl ExecutionProvider for DirectML { fn name(&self) -> &'static str { "DmlExecutionProvider" } diff --git a/src/execution_providers/migraphx.rs b/src/ep/migraphx.rs similarity index 71% rename from src/execution_providers/migraphx.rs rename to src/ep/migraphx.rs index 943579e..592aacd 100644 --- a/src/execution_providers/migraphx.rs +++ b/src/ep/migraphx.rs @@ -6,7 +6,7 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// [MIGraphX execution provider](https://onnxruntime.ai/docs/execution-providers/MIGraphX-ExecutionProvider.html) for /// hardware acceleration with AMD GPUs. #[derive(Debug, Default, Clone)] -pub struct MIGraphXExecutionProvider { +pub struct MIGraphX { device_id: i32, enable_fp16: bool, enable_int8: bool, @@ -17,15 +17,15 @@ pub struct MIGraphXExecutionProvider { exhaustive_tune: bool } -super::impl_ep!(MIGraphXExecutionProvider); +super::impl_ep!(MIGraphX); -impl MIGraphXExecutionProvider { +impl MIGraphX { /// Configures which device the EP should use. /// /// ``` - /// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = MIGraphXExecutionProvider::default().with_device_id(0).build(); + /// let ep = ep::MIGraphX::default().with_device_id(0).build(); /// # Ok(()) /// # } /// ``` @@ -38,9 +38,9 @@ impl MIGraphXExecutionProvider { /// Enable FP16 quantization for the model. 
/// /// ``` - /// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = MIGraphXExecutionProvider::default().with_fp16(true).build(); + /// let ep = ep::MIGraphX::default().with_fp16(true).build(); /// # Ok(()) /// # } /// ``` @@ -51,15 +51,12 @@ impl MIGraphXExecutionProvider { } /// Enable 8-bit integer quantization for the model. Requires - /// [`MIGraphXExecutionProvider::with_int8_calibration_table`] to be set. + /// [`MIGraphX::with_int8_calibration_table`] to be set. /// /// ``` - /// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = MIGraphXExecutionProvider::default() - /// .with_int8(true) - /// .with_int8_calibration_table("...", false) - /// .build(); + /// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build(); /// # Ok(()) /// # } /// ``` @@ -75,12 +72,9 @@ impl MIGraphXExecutionProvider { /// for the JSON dump format. /// /// ``` - /// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = MIGraphXExecutionProvider::default() - /// .with_int8(true) - /// .with_int8_calibration_table("...", false) - /// .build(); + /// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build(); /// # Ok(()) /// # } /// ``` @@ -93,12 +87,12 @@ impl MIGraphXExecutionProvider { /// Save the compiled MIGraphX model to the given path. /// - /// The compiled model can then be loaded in subsequent runs with [`MIGraphXExecutionProvider::with_load_model`]. + /// The compiled model can then be loaded in subsequent runs with [`MIGraphX::with_load_model`]. 
/// /// ``` - /// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = MIGraphXExecutionProvider::default().with_save_model("./compiled_model.mxr").build(); + /// let ep = ep::MIGraphX::default().with_save_model("./compiled_model.mxr").build(); /// # Ok(()) /// # } /// ``` @@ -108,13 +102,13 @@ impl MIGraphXExecutionProvider { self } - /// Load the compiled MIGraphX model (previously generated by [`MIGraphXExecutionProvider::with_save_model`]) from + /// Load the compiled MIGraphX model (previously generated by [`MIGraphX::with_save_model`]) from /// the given path. /// /// ``` - /// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = MIGraphXExecutionProvider::default().with_load_model("./compiled_model.mxr").build(); + /// let ep = ep::MIGraphX::default().with_load_model("./compiled_model.mxr").build(); /// # Ok(()) /// # } /// ``` @@ -127,9 +121,9 @@ impl MIGraphXExecutionProvider { /// Enable exhaustive tuning; trades loading time for inference performance. 
/// /// ``` - /// # use ort::{execution_providers::migraphx::MIGraphXExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = MIGraphXExecutionProvider::default().with_exhaustive_tune(true).build(); + /// let ep = ep::MIGraphX::default().with_exhaustive_tune(true).build(); /// # Ok(()) /// # } /// ``` @@ -140,7 +134,7 @@ impl MIGraphXExecutionProvider { } } -impl ExecutionProvider for MIGraphXExecutionProvider { +impl ExecutionProvider for MIGraphX { fn name(&self) -> &'static str { "MIGraphXExecutionProvider" } diff --git a/src/execution_providers/mod.rs b/src/ep/mod.rs similarity index 74% rename from src/execution_providers/mod.rs rename to src/ep/mod.rs index 3199290..63bbfdc 100644 --- a/src/execution_providers/mod.rs +++ b/src/ep/mod.rs @@ -3,11 +3,11 @@ //! Sessions can be configured with execution providers via [`SessionBuilder::with_execution_providers`]: //! //! ```no_run -//! use ort::{execution_providers::CUDAExecutionProvider, session::Session}; +//! use ort::{ep, session::Session}; //! //! fn main() -> ort::Result<()> { //! let session = Session::builder()? -//! .with_execution_providers([CUDAExecutionProvider::default().build()])? +//! .with_execution_providers([ep::CUDA::default().build()])? //! .commit_from_file("model.onnx")?; //! //! 
Ok(()) @@ -24,47 +24,47 @@ use core::{ use crate::{char_p_to_string, error::Result, ortsys, session::builder::SessionBuilder, util::MiniMap}; pub mod cpu; -pub use self::cpu::CPUExecutionProvider; +pub use self::cpu::CPU; pub mod cuda; -pub use self::cuda::CUDAExecutionProvider; +pub use self::cuda::CUDA; pub mod tensorrt; -pub use self::tensorrt::TensorRTExecutionProvider; +pub use self::tensorrt::TensorRT; pub mod onednn; -pub use self::onednn::OneDNNExecutionProvider; +pub use self::onednn::OneDNN; pub mod acl; -pub use self::acl::ACLExecutionProvider; +pub use self::acl::ACL; pub mod openvino; -pub use self::openvino::OpenVINOExecutionProvider; +pub use self::openvino::OpenVINO; pub mod coreml; -pub use self::coreml::CoreMLExecutionProvider; +pub use self::coreml::CoreML; pub mod rocm; -pub use self::rocm::ROCmExecutionProvider; +pub use self::rocm::ROCm; pub mod cann; -pub use self::cann::CANNExecutionProvider; +pub use self::cann::CANN; pub mod directml; -pub use self::directml::DirectMLExecutionProvider; +pub use self::directml::DirectML; pub mod tvm; -pub use self::tvm::TVMExecutionProvider; +pub use self::tvm::TVM; pub mod nnapi; -pub use self::nnapi::NNAPIExecutionProvider; +pub use self::nnapi::NNAPI; pub mod qnn; -pub use self::qnn::QNNExecutionProvider; +pub use self::qnn::QNN; pub mod xnnpack; -pub use self::xnnpack::XNNPACKExecutionProvider; +pub use self::xnnpack::XNNPACK; pub mod armnn; -pub use self::armnn::ArmNNExecutionProvider; +pub use self::armnn::ArmNN; pub mod migraphx; -pub use self::migraphx::MIGraphXExecutionProvider; +pub use self::migraphx::MIGraphX; pub mod vitis; -pub use self::vitis::VitisAIExecutionProvider; +pub use self::vitis::Vitis; pub mod rknpu; -pub use self::rknpu::RKNPUExecutionProvider; +pub use self::rknpu::RKNPU; pub mod webgpu; -pub use self::webgpu::WebGPUExecutionProvider; +pub use self::webgpu::WebGPU; pub mod azure; -pub use self::azure::AzureExecutionProvider; +pub use self::azure::Azure; pub mod nvrtx; -pub use 
self::nvrtx::NVRTXExecutionProvider; +pub use self::nvrtx::NVRTX; #[cfg(target_arch = "wasm32")] pub mod wasm; #[cfg(target_arch = "wasm32")] @@ -82,14 +82,14 @@ pub trait ExecutionProvider: Send + Sync { /// Returns the identifier of this execution provider used internally by ONNX Runtime. /// /// This is the same as what's used in ONNX Runtime's Python API to register this execution provider, i.e. - /// [`TVMExecutionProvider`]'s identifier is `TvmExecutionProvider`. + /// [`TVM`]'s identifier is `TvmExecutionProvider`. fn name(&self) -> &'static str; /// Returns whether this execution provider is supported on this platform. /// /// For example, the CoreML execution provider implements this as: /// ```ignore - /// impl ExecutionProvider for CoreMLExecutionProvider { + /// impl ExecutionProvider for CoreML { /// fn supported_by_platform() -> bool { /// cfg!(target_vendor = "apple") /// } @@ -321,9 +321,9 @@ pub(crate) use define_ep_register; macro_rules! impl_ep { (arbitrary; $symbol:ident) => { - $crate::execution_providers::impl_ep!($symbol); + $crate::ep::impl_ep!($symbol); - impl $crate::execution_providers::ArbitrarilyConfigurableExecutionProvider for $symbol { + impl $crate::ep::ArbitrarilyConfigurableExecutionProvider for $symbol { fn with_arbitrary_config(mut self, key: impl ::alloc::string::ToString, value: impl ::alloc::string::ToString) -> Self { self.options.set(key.to_string(), value.to_string()); self @@ -333,14 +333,14 @@ macro_rules! 
impl_ep { ($symbol:ident) => { impl $symbol { #[must_use] - pub fn build(self) -> $crate::execution_providers::ExecutionProviderDispatch { + pub fn build(self) -> $crate::ep::ExecutionProviderDispatch { self.into() } } - impl From<$symbol> for $crate::execution_providers::ExecutionProviderDispatch { + impl From<$symbol> for $crate::ep::ExecutionProviderDispatch { fn from(value: $symbol) -> Self { - $crate::execution_providers::ExecutionProviderDispatch::new(value) + $crate::ep::ExecutionProviderDispatch::new(value) } } }; @@ -381,3 +381,75 @@ pub(crate) fn apply_execution_providers(session_builder: &mut SessionBuilder, ep } Ok(()) } + +#[deprecated = "import `ort::ep::ACL` instead"] +#[doc(hidden)] +pub use self::acl::ACL as ACLExecutionProvider; +#[deprecated = "import `ort::ep::ArmNN` instead"] +#[doc(hidden)] +pub use self::armnn::ArmNN as ArmNNExecutionProvider; +#[deprecated = "import `ort::ep::Azure` instead"] +#[doc(hidden)] +pub use self::azure::Azure as AzureExecutionProvider; +#[deprecated = "import `ort::ep::CANN` instead"] +#[doc(hidden)] +pub use self::cann::CANN as CANNExecutionProvider; +#[deprecated = "import `ort::ep::CoreML` instead"] +#[doc(hidden)] +pub use self::coreml::CoreML as CoreMLExecutionProvider; +#[deprecated = "import `ort::ep::CPU` instead"] +#[doc(hidden)] +pub use self::cpu::CPU as CPUExecutionProvider; +#[deprecated = "import `ort::ep::CUDA` instead"] +#[doc(hidden)] +pub use self::cuda::CUDA as CUDAExecutionProvider; +#[deprecated = "import `ort::ep::DirectML` instead"] +#[doc(hidden)] +pub use self::directml::DirectML as DirectMLExecutionProvider; +#[deprecated = "import `ort::ep::MIGraphX` instead"] +#[doc(hidden)] +pub use self::migraphx::MIGraphX as MIGraphXExecutionProvider; +#[deprecated = "import `ort::ep::NNAPI` instead"] +#[doc(hidden)] +pub use self::nnapi::NNAPI as NNAPIExecutionProvider; +#[deprecated = "import `ort::ep::NVRTX` instead"] +#[doc(hidden)] +pub use self::nvrtx::NVRTX as NVRTXExecutionProvider; 
+#[deprecated = "import `ort::ep::OneDNN` instead"] +#[doc(hidden)] +pub use self::onednn::OneDNN as OneDNNExecutionProvider; +#[deprecated = "import `ort::ep::OpenVINO` instead"] +#[doc(hidden)] +pub use self::openvino::OpenVINO as OpenVINOExecutionProvider; +#[deprecated = "import `ort::ep::QNN` instead"] +#[doc(hidden)] +pub use self::qnn::QNN as QNNExecutionProvider; +#[deprecated = "import `ort::ep::RKNPU` instead"] +#[doc(hidden)] +pub use self::rknpu::RKNPU as RKNPUExecutionProvider; +#[deprecated = "import `ort::ep::ROCm` instead"] +#[doc(hidden)] +pub use self::rocm::ROCm as ROCmExecutionProvider; +#[deprecated = "import `ort::ep::TensorRT` instead"] +#[doc(hidden)] +pub use self::tensorrt::TensorRT as TensorRTExecutionProvider; +#[deprecated = "import `ort::ep::TVM` instead"] +#[doc(hidden)] +pub use self::tvm::TVM as TVMExecutionProvider; +#[deprecated = "import `ort::ep::Vitis` instead"] +#[doc(hidden)] +pub use self::vitis::Vitis as VitisAIExecutionProvider; +#[deprecated = "import `ort::ep::WASM` instead"] +#[doc(hidden)] +#[cfg(target_arch = "wasm32")] +pub use self::wasm::WASM as WASMExecutionProvider; +#[deprecated = "import `ort::ep::WebGPU` instead"] +#[doc(hidden)] +pub use self::webgpu::WebGPU as WebGPUExecutionProvider; +#[deprecated = "import `ort::ep::WebNN` instead"] +#[doc(hidden)] +#[cfg(target_arch = "wasm32")] +pub use self::webnn::WebNN as WebNNExecutionProvider; +#[deprecated = "import `ort::ep::XNNPACK` instead"] +#[doc(hidden)] +pub use self::xnnpack::XNNPACK as XNNPACKExecutionProvider; diff --git a/src/execution_providers/nnapi.rs b/src/ep/nnapi.rs similarity index 94% rename from src/execution_providers/nnapi.rs rename to src/ep/nnapi.rs index 882218f..10b9081 100644 --- a/src/execution_providers/nnapi.rs +++ b/src/ep/nnapi.rs @@ -2,16 +2,16 @@ use super::{ExecutionProvider, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Default, Clone)] -pub struct NNAPIExecutionProvider { +pub 
struct NNAPI { use_fp16: bool, use_nchw: bool, disable_cpu: bool, cpu_only: bool } -super::impl_ep!(NNAPIExecutionProvider); +super::impl_ep!(NNAPI); -impl NNAPIExecutionProvider { +impl NNAPI { /// Use fp16 relaxation in NNAPI EP. This may improve performance but can also reduce accuracy due to the lower /// precision. #[must_use] @@ -49,7 +49,7 @@ impl NNAPIExecutionProvider { } } -impl ExecutionProvider for NNAPIExecutionProvider { +impl ExecutionProvider for NNAPI { fn name(&self) -> &'static str { "NnapiExecutionProvider" } diff --git a/src/execution_providers/nvrtx.rs b/src/ep/nvrtx.rs similarity index 88% rename from src/execution_providers/nvrtx.rs rename to src/ep/nvrtx.rs index 8c3227f..e9b81e2 100644 --- a/src/execution_providers/nvrtx.rs +++ b/src/ep/nvrtx.rs @@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Default, Clone)] -pub struct NVRTXExecutionProvider { +pub struct NVRTX { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; NVRTXExecutionProvider); +super::impl_ep!(arbitrary; NVRTX); -impl NVRTXExecutionProvider { +impl NVRTX { pub fn with_device_id(mut self, device_id: u32) -> Self { self.options.set("ep.nvtensorrtrtxexecutionprovider.device_id", device_id.to_string()); self @@ -23,7 +23,7 @@ impl NVRTXExecutionProvider { } } -impl ExecutionProvider for NVRTXExecutionProvider { +impl ExecutionProvider for NVRTX { fn name(&self) -> &'static str { "NvTensorRTRTXExecutionProvider" } diff --git a/src/execution_providers/onednn.rs b/src/ep/onednn.rs similarity index 81% rename from src/execution_providers/onednn.rs rename to src/ep/onednn.rs index aa487f0..e8143df 100644 --- a/src/execution_providers/onednn.rs +++ b/src/ep/onednn.rs @@ -4,20 +4,20 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// [oneDNN/DNNL execution 
provider](https://onnxruntime.ai/docs/execution-providers/oneDNN-ExecutionProvider.html) for /// Intel CPUs & iGPUs. #[derive(Debug, Default, Clone)] -#[doc(alias = "DNNLExecutionProvider")] -pub struct OneDNNExecutionProvider { +#[doc(alias = "DNNL")] +pub struct OneDNN { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; OneDNNExecutionProvider); +super::impl_ep!(arbitrary; OneDNN); -impl OneDNNExecutionProvider { +impl OneDNN { /// Enable/disable the usage of the arena allocator. /// /// ``` - /// # use ort::{execution_providers::onednn::OneDNNExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = OneDNNExecutionProvider::default().with_arena_allocator(true).build(); + /// let ep = ep::OneDNN::default().with_arena_allocator(true).build(); /// # Ok(()) /// # } /// ``` @@ -28,7 +28,7 @@ impl OneDNNExecutionProvider { } } -impl ExecutionProvider for OneDNNExecutionProvider { +impl ExecutionProvider for OneDNN { fn name(&self) -> &'static str { "DnnlExecutionProvider" } diff --git a/src/execution_providers/openvino.rs b/src/ep/openvino.rs similarity index 90% rename from src/execution_providers/openvino.rs rename to src/ep/openvino.rs index 31d9e96..6c7a781 100644 --- a/src/execution_providers/openvino.rs +++ b/src/ep/openvino.rs @@ -6,22 +6,22 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// [OpenVINO execution provider](https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html) for /// Intel CPUs/GPUs/NPUs. #[derive(Default, Debug, Clone)] -pub struct OpenVINOExecutionProvider { +pub struct OpenVINO { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; OpenVINOExecutionProvider); +super::impl_ep!(arbitrary; OpenVINO); -impl OpenVINOExecutionProvider { +impl OpenVINO { /// Overrides the accelerator hardware type and precision. /// /// `device_type` should be in the format `CPU`, `NPU`, `GPU`, `GPU.0`, `GPU.1`, etc. 
Heterogenous combinations are /// also supported in the format `HETERO:NPU,GPU`. /// /// ``` - /// # use ort::{execution_providers::openvino::OpenVINOExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = OpenVINOExecutionProvider::default().with_device_type("GPU.0").build(); + /// let ep = ep::OpenVINO::default().with_device_type("GPU.0").build(); /// # Ok(()) /// # } /// ``` @@ -81,7 +81,7 @@ impl OpenVINOExecutionProvider { } } -impl ExecutionProvider for OpenVINOExecutionProvider { +impl ExecutionProvider for OpenVINO { fn name(&self) -> &'static str { "OpenVINOExecutionProvider" } diff --git a/src/execution_providers/qnn.rs b/src/ep/qnn.rs similarity index 73% rename from src/execution_providers/qnn.rs rename to src/ep/qnn.rs index 30c940e..eb2319a 100644 --- a/src/execution_providers/qnn.rs +++ b/src/ep/qnn.rs @@ -4,7 +4,7 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum QNNPerformanceMode { +pub enum PerformanceMode { Default, Burst, Balanced, @@ -17,43 +17,43 @@ pub enum QNNPerformanceMode { SustainedHighPerformance } -impl QNNPerformanceMode { +impl PerformanceMode { #[must_use] pub fn as_str(&self) -> &'static str { match self { - QNNPerformanceMode::Default => "default", - QNNPerformanceMode::Burst => "burst", - QNNPerformanceMode::Balanced => "balanced", - QNNPerformanceMode::HighPerformance => "high_performance", - QNNPerformanceMode::HighPowerSaver => "high_power_saver", - QNNPerformanceMode::LowPowerSaver => "low_power_saver", - QNNPerformanceMode::LowBalanced => "low_balanced", - QNNPerformanceMode::PowerSaver => "power_saver", - QNNPerformanceMode::ExtremePowerSaver => "extreme_power_saver", - QNNPerformanceMode::SustainedHighPerformance => "sustained_high_performance" + PerformanceMode::Default => "default", + 
PerformanceMode::Burst => "burst", + PerformanceMode::Balanced => "balanced", + PerformanceMode::HighPerformance => "high_performance", + PerformanceMode::HighPowerSaver => "high_power_saver", + PerformanceMode::LowPowerSaver => "low_power_saver", + PerformanceMode::LowBalanced => "low_balanced", + PerformanceMode::PowerSaver => "power_saver", + PerformanceMode::ExtremePowerSaver => "extreme_power_saver", + PerformanceMode::SustainedHighPerformance => "sustained_high_performance" } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum QNNProfilingLevel { +pub enum ProfilingLevel { Off, Basic, Detailed } -impl QNNProfilingLevel { +impl ProfilingLevel { pub fn as_str(&self) -> &'static str { match self { - QNNProfilingLevel::Off => "off", - QNNProfilingLevel::Basic => "basic", - QNNProfilingLevel::Detailed => "detailed" + ProfilingLevel::Off => "off", + ProfilingLevel::Basic => "basic", + ProfilingLevel::Detailed => "detailed" } } } #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum QNNContextPriority { +pub enum ContextPriority { Low, #[default] Normal, @@ -61,25 +61,25 @@ pub enum QNNContextPriority { High } -impl QNNContextPriority { +impl ContextPriority { pub fn as_str(&self) -> &'static str { match self { - QNNContextPriority::Low => "low", - QNNContextPriority::Normal => "normal", - QNNContextPriority::NormalHigh => "normal_high", - QNNContextPriority::High => "high" + ContextPriority::Low => "low", + ContextPriority::Normal => "normal", + ContextPriority::NormalHigh => "normal_high", + ContextPriority::High => "high" } } } #[derive(Debug, Default, Clone)] -pub struct QNNExecutionProvider { +pub struct QNN { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; QNNExecutionProvider); +super::impl_ep!(arbitrary; QNN); -impl QNNExecutionProvider { +impl QNN { /// The file path to QNN backend library. On Linux/Android, this is `libQnnCpu.so` to use the CPU backend, /// or `libQnnHtp.so` to use the accelerated backend. 
#[must_use] @@ -89,7 +89,7 @@ impl QNNExecutionProvider { } #[must_use] - pub fn with_profiling(mut self, level: QNNProfilingLevel) -> Self { + pub fn with_profiling(mut self, level: ProfilingLevel) -> Self { self.options.set("profiling_level", level.as_str()); self } @@ -114,7 +114,7 @@ impl QNNExecutionProvider { } #[must_use] - pub fn with_performance_mode(mut self, mode: QNNPerformanceMode) -> Self { + pub fn with_performance_mode(mut self, mode: PerformanceMode) -> Self { self.options.set("htp_performance_mode", mode.as_str()); self } @@ -126,7 +126,7 @@ impl QNNExecutionProvider { } #[must_use] - pub fn with_context_priority(mut self, priority: QNNContextPriority) -> Self { + pub fn with_context_priority(mut self, priority: ContextPriority) -> Self { self.options.set("qnn_context_priority", priority.as_str()); self } @@ -174,7 +174,7 @@ impl QNNExecutionProvider { } } -impl ExecutionProvider for QNNExecutionProvider { +impl ExecutionProvider for QNN { fn name(&self) -> &'static str { "QNNExecutionProvider" } diff --git a/src/execution_providers/rknpu.rs b/src/ep/rknpu.rs similarity index 86% rename from src/execution_providers/rknpu.rs rename to src/ep/rknpu.rs index 3e6db2d..1d1b846 100644 --- a/src/execution_providers/rknpu.rs +++ b/src/ep/rknpu.rs @@ -2,11 +2,11 @@ use super::{ExecutionProvider, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Default, Clone)] -pub struct RKNPUExecutionProvider {} +pub struct RKNPU {} -super::impl_ep!(RKNPUExecutionProvider); +super::impl_ep!(RKNPU); -impl ExecutionProvider for RKNPUExecutionProvider { +impl ExecutionProvider for RKNPU { fn name(&self) -> &'static str { "RknpuExecutionProvider" } diff --git a/src/execution_providers/rocm.rs b/src/ep/rocm.rs similarity index 95% rename from src/execution_providers/rocm.rs rename to src/ep/rocm.rs index 3be4897..ce33c14 100644 --- a/src/execution_providers/rocm.rs +++ b/src/ep/rocm.rs @@ -5,13 +5,13 @@ use 
super::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderOptions, Re use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Default, Clone)] -pub struct ROCmExecutionProvider { +pub struct ROCm { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; ROCmExecutionProvider); +super::impl_ep!(arbitrary; ROCm); -impl ROCmExecutionProvider { +impl ROCm { #[must_use] pub fn with_device_id(mut self, device_id: i32) -> Self { self.options.set("device_id", device_id.to_string()); @@ -86,7 +86,7 @@ impl ROCmExecutionProvider { } } -impl ExecutionProvider for ROCmExecutionProvider { +impl ExecutionProvider for ROCm { fn name(&self) -> &'static str { "ROCMExecutionProvider" } diff --git a/src/execution_providers/tensorrt.rs b/src/ep/tensorrt.rs similarity index 97% rename from src/execution_providers/tensorrt.rs rename to src/ep/tensorrt.rs index b592d08..d148f65 100644 --- a/src/execution_providers/tensorrt.rs +++ b/src/ep/tensorrt.rs @@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Default, Clone)] -pub struct TensorRTExecutionProvider { +pub struct TensorRT { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; TensorRTExecutionProvider); +super::impl_ep!(arbitrary; TensorRT); -impl TensorRTExecutionProvider { +impl TensorRT { #[must_use] pub fn with_device_id(mut self, device_id: i32) -> Self { self.options.set("device_id", device_id.to_string()); @@ -254,7 +254,7 @@ impl TensorRTExecutionProvider { } } -impl ExecutionProvider for TensorRTExecutionProvider { +impl ExecutionProvider for TensorRT { fn name(&self) -> &'static str { "TensorrtExecutionProvider" } diff --git a/src/execution_providers/tvm.rs b/src/ep/tvm.rs similarity index 90% rename from src/execution_providers/tvm.rs rename to src/ep/tvm.rs index 8f7a096..ad12d84 100644 --- a/src/execution_providers/tvm.rs +++ b/src/ep/tvm.rs @@ -4,22 +4,22 
@@ use super::{ExecutionProvider, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TVMExecutorType { +pub enum ExecutorType { GraphExecutor, VirtualMachine } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TVMTuningType { +pub enum TuningType { AutoTVM, Ansor } #[derive(Debug, Default, Clone)] -pub struct TVMExecutionProvider { +pub struct TVM { /// Executor type used by TVM. There is a choice between two types, `GraphExecutor` and `VirtualMachine`. Default is - /// [`TVMExecutorType::VirtualMachine`]. - pub executor: Option, + /// [`ExecutorType::VirtualMachine`]. + pub executor: Option, /// Path to folder with set of files (`.ro-`, `.so`/`.dll`-files and weights) obtained after model tuning. pub so_folder: Option, /// Whether or not to perform a hash check on the model obtained in the `so_folder`. @@ -34,7 +34,7 @@ pub struct TVMExecutionProvider { /// `true` is recommended for best performance and is the default. pub freeze_weights: Option, pub to_nhwc: Option, - pub tuning_type: Option, + pub tuning_type: Option, /// Path to AutoTVM or Ansor tuning file which gives specifications for given model and target for the best /// performance. 
pub tuning_file_path: Option, @@ -42,9 +42,9 @@ pub struct TVMExecutionProvider { pub input_shapes: Option } -super::impl_ep!(TVMExecutionProvider); +super::impl_ep!(TVM); -impl ExecutionProvider for TVMExecutionProvider { +impl ExecutionProvider for TVM { fn name(&self) -> &'static str { "TvmExecutionProvider" } @@ -66,8 +66,8 @@ impl ExecutionProvider for TVMExecutionProvider { option_string.push(format!( "executor:{}", match executor { - TVMExecutorType::GraphExecutor => "graph", - TVMExecutorType::VirtualMachine => "vm" + ExecutorType::GraphExecutor => "graph", + ExecutorType::VirtualMachine => "vm" } )); } diff --git a/src/execution_providers/vitis.rs b/src/ep/vitis.rs similarity index 88% rename from src/execution_providers/vitis.rs rename to src/ep/vitis.rs index be4c872..f610503 100644 --- a/src/execution_providers/vitis.rs +++ b/src/ep/vitis.rs @@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Default, Clone)] -pub struct VitisAIExecutionProvider { +pub struct Vitis { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; VitisAIExecutionProvider); +super::impl_ep!(arbitrary; Vitis); -impl VitisAIExecutionProvider { +impl Vitis { pub fn with_config_file(mut self, config_file: impl ToString) -> Self { self.options.set("config_file", config_file.to_string()); self @@ -27,7 +27,7 @@ impl VitisAIExecutionProvider { } } -impl ExecutionProvider for VitisAIExecutionProvider { +impl ExecutionProvider for Vitis { fn name(&self) -> &'static str { "VitisAIExecutionProvider" } diff --git a/src/execution_providers/wasm.rs b/src/ep/wasm.rs similarity index 84% rename from src/execution_providers/wasm.rs rename to src/ep/wasm.rs index b4552a2..8fe808d 100644 --- a/src/execution_providers/wasm.rs +++ b/src/ep/wasm.rs @@ -2,13 +2,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError}; use crate::{AsPointer, error::Result, 
ortsys, session::builder::SessionBuilder}; #[derive(Debug, Default, Clone)] -pub struct WASMExecutionProvider { +pub struct WASM { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; WASMExecutionProvider); +super::impl_ep!(arbitrary; WASM); -impl ExecutionProvider for WASMExecutionProvider { +impl ExecutionProvider for WASM { fn name(&self) -> &'static str { "WASMExecutionProvider" } diff --git a/src/execution_providers/webgpu.rs b/src/ep/webgpu.rs similarity index 68% rename from src/execution_providers/webgpu.rs rename to src/ep/webgpu.rs index 15386bf..82d0204 100644 --- a/src/execution_providers/webgpu.rs +++ b/src/ep/webgpu.rs @@ -4,84 +4,84 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebGPUPreferredLayout { +pub enum PreferredLayout { NCHW, NHWC } -impl WebGPUPreferredLayout { +impl PreferredLayout { pub(crate) fn as_str(&self) -> &'static str { match self { - WebGPUPreferredLayout::NCHW => "NCHW", - WebGPUPreferredLayout::NHWC => "NHWC" + PreferredLayout::NCHW => "NCHW", + PreferredLayout::NHWC => "NHWC" } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebGPUDawnBackendType { +pub enum DawnBackendType { Vulkan, D3D12 } -impl WebGPUDawnBackendType { +impl DawnBackendType { pub(crate) fn as_str(&self) -> &'static str { match self { - WebGPUDawnBackendType::Vulkan => "Vulkan", - WebGPUDawnBackendType::D3D12 => "D3D12" + DawnBackendType::Vulkan => "Vulkan", + DawnBackendType::D3D12 => "D3D12" } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebGPUBufferCacheMode { +pub enum BufferCacheMode { Disabled, LazyRelease, Simple, Bucket } -impl WebGPUBufferCacheMode { +impl BufferCacheMode { pub(crate) fn as_str(&self) -> &'static str { match self { - WebGPUBufferCacheMode::Disabled => "disabled", - WebGPUBufferCacheMode::LazyRelease => "lazyRelease", - WebGPUBufferCacheMode::Simple => 
"simple", - WebGPUBufferCacheMode::Bucket => "bucket" + BufferCacheMode::Disabled => "disabled", + BufferCacheMode::LazyRelease => "lazyRelease", + BufferCacheMode::Simple => "simple", + BufferCacheMode::Bucket => "bucket" } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebGPUValidationMode { +pub enum ValidationMode { Disabled, WgpuOnly, Basic, Full } -impl WebGPUValidationMode { +impl ValidationMode { #[must_use] pub fn as_str(&self) -> &'static str { match self { - WebGPUValidationMode::Disabled => "disabled", - WebGPUValidationMode::WgpuOnly => "wgpuOnly", - WebGPUValidationMode::Basic => "basic", - WebGPUValidationMode::Full => "full" + ValidationMode::Disabled => "disabled", + ValidationMode::WgpuOnly => "wgpuOnly", + ValidationMode::Basic => "basic", + ValidationMode::Full => "full" } } } #[derive(Debug, Default, Clone)] -pub struct WebGPUExecutionProvider { +pub struct WebGPU { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; WebGPUExecutionProvider); +super::impl_ep!(arbitrary; WebGPU); -impl WebGPUExecutionProvider { +impl WebGPU { #[must_use] - pub fn with_preferred_layout(mut self, layout: WebGPUPreferredLayout) -> Self { + pub fn with_preferred_layout(mut self, layout: PreferredLayout) -> Self { self.options.set("ep.webgpuexecutionprovider.preferredLayout", layout.as_str()); self } @@ -100,7 +100,7 @@ impl WebGPUExecutionProvider { } #[must_use] - pub fn with_dawn_backend_type(mut self, backend_type: WebGPUDawnBackendType) -> Self { + pub fn with_dawn_backend_type(mut self, backend_type: DawnBackendType) -> Self { self.options.set("ep.webgpuexecutionprovider.dawnBackendType", backend_type.as_str()); self } @@ -112,31 +112,31 @@ impl WebGPUExecutionProvider { } #[must_use] - pub fn with_storage_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self { + pub fn with_storage_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self { self.options.set("ep.webgpuexecutionprovider.storageBufferCacheMode", mode.as_str()); 
self } #[must_use] - pub fn with_uniform_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self { + pub fn with_uniform_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self { self.options.set("ep.webgpuexecutionprovider.uniformBufferCacheMode", mode.as_str()); self } #[must_use] - pub fn with_query_resolve_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self { + pub fn with_query_resolve_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self { self.options.set("ep.webgpuexecutionprovider.queryResolveBufferCacheMode", mode.as_str()); self } #[must_use] - pub fn with_default_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self { + pub fn with_default_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self { self.options.set("ep.webgpuexecutionprovider.defaultBufferCacheMode", mode.as_str()); self } #[must_use] - pub fn with_validation_mode(mut self, mode: WebGPUValidationMode) -> Self { + pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self { self.options.set("ep.webgpuexecutionprovider.validationMode", mode.as_str()); self } @@ -155,7 +155,7 @@ impl WebGPUExecutionProvider { } } -impl ExecutionProvider for WebGPUExecutionProvider { +impl ExecutionProvider for WebGPU { fn name(&self) -> &'static str { "WebGpuExecutionProvider" } diff --git a/src/execution_providers/webnn.rs b/src/ep/webnn.rs similarity index 66% rename from src/execution_providers/webnn.rs rename to src/ep/webnn.rs index 6b65ef2..a03252b 100644 --- a/src/execution_providers/webnn.rs +++ b/src/ep/webnn.rs @@ -4,56 +4,56 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError}; use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder}; #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebNNPowerPreference { +pub enum PowerPreference { #[default] Default, HighPerformance, LowPower } -impl WebNNPowerPreference { +impl PowerPreference { #[must_use] pub fn as_str(&self) -> &'static str { 
match self { - WebNNPowerPreference::Default => "default", - WebNNPowerPreference::HighPerformance => "high-performance", - WebNNPowerPreference::LowPower => "low-power" + PowerPreference::Default => "default", + PowerPreference::HighPerformance => "high-performance", + PowerPreference::LowPower => "low-power" } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebNNDeviceType { +pub enum DeviceType { CPU, GPU, NPU } -impl WebNNDeviceType { +impl DeviceType { #[must_use] pub fn as_str(&self) -> &'static str { match self { - WebNNDeviceType::CPU => "cpu", - WebNNDeviceType::GPU => "gpu", - WebNNDeviceType::NPU => "npu" + DeviceType::CPU => "cpu", + DeviceType::GPU => "gpu", + DeviceType::NPU => "npu" } } } #[derive(Debug, Default, Clone)] -pub struct WebNNExecutionProvider { +pub struct WebNN { options: ExecutionProviderOptions } -impl WebNNExecutionProvider { +impl WebNN { #[must_use] - pub fn with_device_type(mut self, device_type: WebNNDeviceType) -> Self { + pub fn with_device_type(mut self, device_type: DeviceType) -> Self { self.options.set("deviceType", device_type.as_str()); self } #[must_use] - pub fn with_power_preference(mut self, pref: WebNNPowerPreference) -> Self { + pub fn with_power_preference(mut self, pref: PowerPreference) -> Self { self.options.set("powerPreference", pref.as_str()); self } @@ -65,9 +65,9 @@ impl WebNNExecutionProvider { } } -super::impl_ep!(arbitrary; WebNNExecutionProvider); +super::impl_ep!(arbitrary; WebNN); -impl ExecutionProvider for WebNNExecutionProvider { +impl ExecutionProvider for WebNN { fn name(&self) -> &'static str { "WebNNExecutionProvider" } diff --git a/src/execution_providers/xnnpack.rs b/src/ep/xnnpack.rs similarity index 80% rename from src/execution_providers/xnnpack.rs rename to src/ep/xnnpack.rs index 98c2617..6373c3a 100644 --- a/src/execution_providers/xnnpack.rs +++ b/src/ep/xnnpack.rs @@ -13,12 +13,12 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// disable the session 
intra-op threadpool to reduce contention: /// ```no_run /// # use core::num::NonZeroUsize; -/// # use ort::{execution_providers::xnnpack::XNNPACKExecutionProvider, session::Session}; +/// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { /// let session = Session::builder()? /// .with_intra_op_spinning(false)? /// .with_intra_threads(1)? -/// .with_execution_providers([XNNPACKExecutionProvider::default() +/// .with_execution_providers([ep::XNNPACK::default() /// .with_intra_op_num_threads(NonZeroUsize::new(4).unwrap()) /// .build()])? /// .commit_from_file("model.onnx")?; @@ -26,22 +26,20 @@ use crate::{error::Result, session::builder::SessionBuilder}; /// # } /// ``` #[derive(Debug, Default, Clone)] -pub struct XNNPACKExecutionProvider { +pub struct XNNPACK { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; XNNPACKExecutionProvider); +super::impl_ep!(arbitrary; XNNPACK); -impl XNNPACKExecutionProvider { +impl XNNPACK { /// Configures the number of threads to use for XNNPACK's internal intra-op threadpool. 
/// /// ``` /// # use core::num::NonZeroUsize; - /// # use ort::{execution_providers::xnnpack::XNNPACKExecutionProvider, session::Session}; + /// # use ort::{ep, session::Session}; /// # fn main() -> ort::Result<()> { - /// let ep = XNNPACKExecutionProvider::default() - /// .with_intra_op_num_threads(NonZeroUsize::new(4).unwrap()) - /// .build(); + /// let ep = ep::XNNPACK::default().with_intra_op_num_threads(NonZeroUsize::new(4).unwrap()).build(); /// # Ok(()) /// # } /// ``` @@ -52,7 +50,7 @@ impl XNNPACKExecutionProvider { } } -impl ExecutionProvider for XNNPACKExecutionProvider { +impl ExecutionProvider for XNNPACK { fn name(&self) -> &'static str { "XnnpackExecutionProvider" } diff --git a/src/io_binding.rs b/src/io_binding.rs index 3d8b638..52d5184 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -37,7 +37,7 @@ use crate::{ /// /// ```no_run /// # use ort::{ -/// # execution_providers::CUDAExecutionProvider, +/// # ep, /// # io_binding::IoBinding, /// # memory::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType}, /// # session::Session, @@ -45,10 +45,10 @@ use crate::{ /// # }; /// # fn main() -> ort::Result<()> { /// let mut text_encoder = Session::builder()? -/// .with_execution_providers([CUDAExecutionProvider::default().build()])? +/// .with_execution_providers([ep::CUDA::default().build()])? /// .commit_from_file("text_encoder.onnx")?; /// let mut unet = Session::builder()? -/// .with_execution_providers([CUDAExecutionProvider::default().build()])? +/// .with_execution_providers([ep::CUDA::default().build()])? 
/// .commit_from_file("unet.onnx")?; /// /// let text_condition = text_encoder diff --git a/src/lib.rs b/src/lib.rs index ebb00f7..0fc9b6d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ #![allow(clippy::tabs_in_doc_comments, clippy::arc_with_non_send_sync)] #![allow(clippy::macro_metavars_in_unsafe)] #![warn(clippy::unwrap_used)] +#![deny(clippy::std_instead_of_alloc, clippy::std_instead_of_core)] #![cfg_attr(all(not(test), not(feature = "std")), no_std)] //!
@@ -28,8 +29,8 @@ pub mod adapter; pub mod compiler; pub mod editor; pub mod environment; +pub mod ep; pub mod error; -pub mod execution_providers; pub mod io_binding; pub mod logging; pub mod memory; @@ -48,6 +49,13 @@ pub mod api { pub use super::{api as ort, compiler::compile_api as compile, editor::editor_api as editor}; } +#[deprecated = "import execution providers from `ort::ep` instead"] +#[doc(hidden)] +pub mod execution_providers { + #[deprecated = "import execution providers from `ort::ep` instead"] + pub use super::ep::*; +} + use alloc::{borrow::ToOwned, boxed::Box, string::String}; use core::{ ffi::{CStr, c_char}, diff --git a/src/operator/mod.rs b/src/operator/mod.rs index 36ab2cb..3612b1e 100644 --- a/src/operator/mod.rs +++ b/src/operator/mod.rs @@ -36,7 +36,8 @@ pub trait Operator: Send { /// Returns the name of the operator. fn name(&self) -> &str; - /// Returns the execution provider this operator runs on, e.g. `CUDAExecutionProvider`. + /// Returns the internal name of the execution provider this operator runs on, e.g. `"CUDAExecutionProvider"` (see + /// [`ExecutionProvider::name`](crate::ep::ExecutionProvider::name)). 
/// /// If the returned type is not `None`, and the execution provider used by the session matches this operator's /// EP type, the value will not be copied to the CPU and you may use functions like [`Tensor::data_ptr`] to diff --git a/src/session/builder/impl_commit.rs b/src/session/builder/impl_commit.rs index 1adbf77..a3cd93f 100644 --- a/src/session/builder/impl_commit.rs +++ b/src/session/builder/impl_commit.rs @@ -6,7 +6,6 @@ use core::{ ffi::c_void, marker::PhantomData, mem::replace, - ops::Deref, ptr::{self, NonNull} }; #[cfg(feature = "std")] @@ -23,8 +22,8 @@ use crate::error::{Error, ErrorCode}; use crate::util::OsCharArray; use crate::{ AsPointer, + ep::apply_execution_providers, error::Result, - execution_providers::apply_execution_providers, memory::Allocator, ortsys, session::{InMemorySession, Input, Output, Session, SharedSessionInner, dangerous} @@ -136,7 +135,7 @@ impl SessionBuilder { } #[cfg(all(feature = "std", not(target_arch = "wasm32")))] - fn commit_from_file_inner(mut self, model_path: &::Target) -> Result { + fn commit_from_file_inner(mut self, model_path: &::Target) -> Result { self.pre_commit()?; let session_ptr = if let Some(prepacked_weights) = self.prepacked_weights.as_ref() { diff --git a/src/session/builder/impl_options.rs b/src/session/builder/impl_options.rs index a22d30d..bb11d1f 100644 --- a/src/session/builder/impl_options.rs +++ b/src/session/builder/impl_options.rs @@ -13,8 +13,8 @@ use crate::util::path_to_os_char; use crate::{ AsPointer, Error, ErrorCode, environment::{self, ThreadManager}, + ep::{ExecutionProviderDispatch, apply_execution_providers}, error::Result, - execution_providers::{ExecutionProviderDispatch, apply_execution_providers}, logging::{LogLevel, LoggerFunction}, memory::MemoryInfo, operator::OperatorDomain, @@ -34,8 +34,8 @@ impl SessionBuilder { /// ## Notes /// /// - **Indiscriminate use of [`SessionBuilder::with_execution_providers`] in a library** (e.g. 
always enabling - /// `CUDAExecutionProvider`) **is discouraged** unless you allow the user to configure the execution providers by - /// providing a `Vec` of [`ExecutionProviderDispatch`]es. + /// CUDA) **is discouraged** unless you allow the user to configure the execution providers by providing a `Vec` + /// of [`ExecutionProviderDispatch`]es. pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Result { apply_execution_providers(&mut self, execution_providers.as_ref(), "session options")?; Ok(self) diff --git a/src/util/mod.rs b/src/util/mod.rs index 56b5c83..d67f5d8 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -36,8 +36,8 @@ pub(crate) use self::{ /// absolute path, then any subsequent requests to load a library called `foo` will use the `libfoo.so` we already /// loaded, instead of searching the system for a `foo` library. /// -/// See also [`crate::execution_providers::cuda::preload_dylibs`], a helper that uses `preload_dylib` to load all -/// required dependencies of the [CUDA execution provider](crate::execution_providers::CUDAExecutionProvider). +/// See also [`crate::ep::cuda::preload_dylibs`], a helper that uses `preload_dylib` to load all required dependencies +/// of the [CUDA execution provider](crate::ep::CUDA). 
/// /// ``` /// use std::env; diff --git a/src/value/impl_tensor/copy.rs b/src/value/impl_tensor/copy.rs index d737194..49c1495 100644 --- a/src/value/impl_tensor/copy.rs +++ b/src/value/impl_tensor/copy.rs @@ -6,7 +6,7 @@ use core::ops::{Deref, DerefMut}; use super::DefiniteTensorValueTypeMarker; use crate::{ - Error, OnceLock, Result, execution_providers as ep, + Error, OnceLock, Result, ep, io_binding::IoBinding, memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType}, session::{NoSelectedOutputs, RunOptions, Session, builder::GraphOptimizationLevel}, @@ -36,20 +36,20 @@ static IDENTITY_RUN_OPTIONS: OnceLock> = OnceLock: fn ep_for_device(device: AllocationDevice, device_id: i32) -> Result { Ok(match device { - AllocationDevice::CPU => ep::CPUExecutionProvider::default().with_arena_allocator(false).build(), - AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDAExecutionProvider::default() + AllocationDevice::CPU => ep::CPU::default().with_arena_allocator(false).build(), + AllocationDevice::CUDA | AllocationDevice::CUDA_PINNED => ep::CUDA::default() .with_device_id(device_id) .with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested) .with_conv_max_workspace(false) - .with_conv_algorithm_search(ep::cuda::CuDNNConvAlgorithmSearch::Default) + .with_conv_algorithm_search(ep::cuda::ConvAlgorithmSearch::Default) .build(), - AllocationDevice::DIRECTML => ep::DirectMLExecutionProvider::default().with_device_id(device_id).build(), - AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANNExecutionProvider::default() + AllocationDevice::DIRECTML => ep::DirectML::default().with_device_id(device_id).build(), + AllocationDevice::CANN | AllocationDevice::CANN_PINNED => ep::CANN::default() .with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested) .with_cann_graph(false) .with_device_id(device_id) .build(), - AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => 
ep::OpenVINOExecutionProvider::default() + AllocationDevice::OPENVINO_CPU | AllocationDevice::OPENVINO_GPU => ep::OpenVINO::default() .with_num_threads(1) .with_device_type(if device == AllocationDevice::OPENVINO_CPU { "CPU".to_string() @@ -57,7 +57,7 @@ fn ep_for_device(device: AllocationDevice, device_id: i32) -> Result ep::ROCmExecutionProvider::default() + AllocationDevice::HIP | AllocationDevice::HIP_PINNED => ep::ROCm::default() .with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested) .with_hip_graph(false) .with_exhaustive_conv_search(false) @@ -388,13 +388,13 @@ mod tests { #[cfg(feature = "cuda")] fn test_copy_into_cuda() -> crate::Result<()> { use crate::{ - execution_providers::CUDAExecutionProvider, + ep, memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType}, session::Session }; let dummy_session = Session::builder()? - .with_execution_providers([CUDAExecutionProvider::default().build()])? + .with_execution_providers([ep::CUDA::default().build()])? .commit_from_file("tests/data/upsample.ort")?; let allocator = Allocator::new(&dummy_session, MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?)?; @@ -412,13 +412,13 @@ mod tests { #[cfg(feature = "cuda")] fn test_copy_into_async_cuda() -> crate::Result<()> { use crate::{ - execution_providers::CUDAExecutionProvider, + ep, memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType}, session::Session }; let dummy_session = Session::builder()? - .with_execution_providers([CUDAExecutionProvider::default().build()])? + .with_execution_providers([ep::CUDA::default().build()])? 
.commit_from_file("tests/data/upsample.ort")?; let allocator = Allocator::new(&dummy_session, MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?)?; diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index 3fb78c3..ceafd1c 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -213,7 +213,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> { /// ``` /// /// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be - /// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout. + /// returned. See [`ndarray::ArrayRef::as_standard_layout`] to convert an array to a contiguous layout. pub fn from_array_view(input: impl TensorArrayData + 'a) -> Result> { let (shape, data, guard) = input.ref_parts()?; tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::(), T::into_tensor_element_type(), guard).map(|tensor| { @@ -252,7 +252,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { /// ``` /// /// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be - /// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout. + /// returned. See [`ndarray::ArrayRef::as_standard_layout`] to convert an array to a contiguous layout. 
pub fn from_array_view_mut(mut input: impl TensorArrayDataMut) -> Result> { let (shape, data, guard) = input.ref_parts_mut()?; tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, size_of::(), T::into_tensor_element_type(), guard).map(|tensor| { diff --git a/tests/leak-check/main.rs b/tests/leak-check/main.rs index d0492d8..11e1de0 100644 --- a/tests/leak-check/main.rs +++ b/tests/leak-check/main.rs @@ -1,6 +1,6 @@ use ort::{ adapter::Adapter, - execution_providers::CPUExecutionProvider, + ep, memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType}, operator::{ Operator, OperatorDomain, @@ -77,7 +77,7 @@ impl Operator for CustomOpTwo { } fn main() -> ort::Result<()> { - ort::init().with_execution_providers([CPUExecutionProvider::default().build()]).commit(); + ort::init().with_execution_providers([ep::CPU::default().build()]).commit(); let mut session = Session::builder()? .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?