mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
fix(sys): Fix double activation of CUDA features in resolve_dist (#531)
Fixes #530
This commit is contained in:
@@ -30,9 +30,9 @@ pub fn resolve_dist() -> Result<Distribution, Option<String>> {
|
||||
feature_set.push("wgpu");
|
||||
}
|
||||
if cfg!(any(feature = "cuda", feature = "tensorrt")) {
|
||||
match vars::get(vars::CUDA_VERSION).as_deref() {
|
||||
Some("12") => feature_set.push("cu12"),
|
||||
Some("13") => feature_set.push("cu13"),
|
||||
let cuda_feature = match vars::get(vars::CUDA_VERSION).as_deref() {
|
||||
Some("12") => "cu12",
|
||||
Some("13") => "cu13",
|
||||
_ => {
|
||||
if let Some(cuda_home) = vars::get("CUDA_HOME")
|
||||
&& (cuda_home.contains("v13.") || cuda_home.contains("-13."))
|
||||
@@ -41,23 +41,23 @@ pub fn resolve_dist() -> Result<Distribution, Option<String>> {
|
||||
// On Windows this is usually C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1
|
||||
// On Linux this is usually /usr/local/cuda-13.1
|
||||
// so detecting v13. or -13. in the path usually works
|
||||
feature_set.push("cu13");
|
||||
"cu13"
|
||||
} else if let Some(ver) = vars::get("NV_CUDA_CUDART_VERSION")
|
||||
&& ver.starts_with("13.")
|
||||
{
|
||||
// Set by NVIDIA docker images (both devel & runtime)
|
||||
feature_set.push("cu13");
|
||||
"cu13"
|
||||
} else if let Ok(output) = Command::new("nvcc").arg("--version").output()
|
||||
&& let Ok(stdout) = str::from_utf8(&output.stdout)
|
||||
&& stdout.contains("Build cuda_13")
|
||||
{
|
||||
feature_set.push("cu13");
|
||||
"cu13"
|
||||
} else {
|
||||
feature_set.push("cu12");
|
||||
"cu12"
|
||||
}
|
||||
}
|
||||
}
|
||||
feature_set.push("cu12");
|
||||
};
|
||||
feature_set.push(cuda_feature);
|
||||
} else if cfg!(feature = "nvrtx") {
|
||||
// CUDA builds include NVRTX; only use the standalone NVRTX build if we aren't using CUDA as well
|
||||
feature_set.push("nvrtx");
|
||||
|
||||
Reference in New Issue
Block a user