diff --git a/Cargo.toml b/Cargo.toml index 4ffcf46..17ee26e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -90,7 +90,7 @@ cann = [ "ort-sys/cann" ] qnn = [ "ort-sys/qnn" ] webgpu = [ "ort-sys/webgpu" ] azure = [ "ort-sys/azure" ] -nv = [ "ort-sys/nv" ] +nvrtx = [ "ort-sys/nvrtx" ] [dependencies] ort-sys = { version = "=2.0.0-rc.10", path = "ort-sys", default-features = false } diff --git a/docs/content/perf/execution-providers.mdx b/docs/content/perf/execution-providers.mdx index e1a0493..beb460c 100644 --- a/docs/content/perf/execution-providers.mdx +++ b/docs/content/perf/execution-providers.mdx @@ -18,6 +18,7 @@ ONNX Runtime must be compiled from source with support for each execution provid |:-------- |:------- |:-------:|:------:| | [NVIDIA CUDA](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) | `cuda` | 🔷 | ✅ | | [NVIDIA TensorRT](https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html) | `tensorrt` | 🔷 | ✅ | +| [NVIDIA TensorRT RTX](https://onnxruntime.ai/docs/execution-providers/TensorRTRTX-ExecutionProvider.html) (NVRTX) | `nvrtx` | 🔷 | ✅ | | [Microsoft DirectML](https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html) | `directml` | 🔷 | ✅ | | [Apple CoreML](https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html) | `coreml` | 🔷 | ✅ | | [AMD ROCm](https://onnxruntime.ai/docs/execution-providers/ROCm-ExecutionProvider.html) | `rocm` | 🔷 | ❌ | diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index 82252dc..cf38b6b 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -45,7 +45,7 @@ cann = [] qnn = [] webgpu = [ "dep:glob" ] azure = [] -nv = [] +nvrtx = [] tls-rustls = [ "tls-rustls-no-provider", "ureq?/rustls" ] tls-rustls-no-provider = [ "__tls", "ureq?/rustls-no-provider" ] diff --git a/ort-sys/build/download/resolve.rs b/ort-sys/build/download/resolve.rs index 0e4f689..f8bbddc 100644 --- a/ort-sys/build/download/resolve.rs +++ b/ort-sys/build/download/resolve.rs @@ -29,7 +29,7 @@ pub fn resolve_dist() -> Result> { if cfg!(feature = "webgpu") { feature_set.push("wgpu"); } - if cfg!(feature = "nv") { + if cfg!(feature = "nvrtx") { feature_set.push("nvrtx"); } if cfg!(any(feature = "cuda", feature = "tensorrt")) { diff --git a/src/execution_providers/mod.rs b/src/execution_providers/mod.rs index 602baaf..dad8784 100644 --- a/src/execution_providers/mod.rs +++ b/src/execution_providers/mod.rs @@ -63,8 +63,8 @@ pub mod webgpu; pub use self::webgpu::WebGPUExecutionProvider; pub mod azure; pub use self::azure::AzureExecutionProvider; -pub mod nv; -pub use self::nv::NVExecutionProvider; +pub mod nvrtx; +pub use self::nvrtx::NVRTXExecutionProvider; #[cfg(target_arch = "wasm32")] pub mod wasm; #[cfg(target_arch = "wasm32")] diff --git a/src/execution_providers/nv.rs b/src/execution_providers/nvrtx.rs similarity index 84% rename from src/execution_providers/nv.rs rename to src/execution_providers/nvrtx.rs index 73308a8..8c3227f 100644 --- a/src/execution_providers/nv.rs +++ b/src/execution_providers/nvrtx.rs @@ -4,13 +4,13 @@ use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError}; use crate::{error::Result, session::builder::SessionBuilder}; #[derive(Debug, Default, Clone)] -pub struct NVExecutionProvider { +pub struct NVRTXExecutionProvider { options: ExecutionProviderOptions } -super::impl_ep!(arbitrary; NVExecutionProvider); +super::impl_ep!(arbitrary; NVRTXExecutionProvider); -impl NVExecutionProvider { +impl NVRTXExecutionProvider { pub fn with_device_id(mut self, device_id: u32) -> Self { self.options.set("ep.nvtensorrtrtxexecutionprovider.device_id", device_id.to_string()); self @@ -23,7 +23,7 @@ impl NVExecutionProvider { } } -impl ExecutionProvider for NVExecutionProvider { +impl ExecutionProvider for NVRTXExecutionProvider { fn name(&self) -> &'static str { "NvTensorRTRTXExecutionProvider" } @@ -34,7 +34,7 @@ impl ExecutionProvider for NVExecutionProvider { #[allow(unused, unreachable_code)] fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> { - #[cfg(any(feature = "load-dynamic", feature = "nv"))] + #[cfg(any(feature = "load-dynamic", feature = "nvrtx"))] { use crate::{AsPointer, ortsys};