fix(sys): Fix double activation of CUDA features in resolve_dist (#531)

Fixes #530
This commit is contained in:
Denis
2026-02-12 17:04:09 +03:00
committed by GitHub
parent 51b3c94254
commit 4caf2d6a76

View File

@@ -30,9 +30,9 @@ pub fn resolve_dist() -> Result<Distribution, Option<String>> {
feature_set.push("wgpu");
}
if cfg!(any(feature = "cuda", feature = "tensorrt")) {
match vars::get(vars::CUDA_VERSION).as_deref() {
Some("12") => feature_set.push("cu12"),
Some("13") => feature_set.push("cu13"),
let cuda_feature = match vars::get(vars::CUDA_VERSION).as_deref() {
Some("12") => "cu12",
Some("13") => "cu13",
_ => {
if let Some(cuda_home) = vars::get("CUDA_HOME")
&& (cuda_home.contains("v13.") || cuda_home.contains("-13."))
@@ -41,23 +41,23 @@ pub fn resolve_dist() -> Result<Distribution, Option<String>> {
// On Windows this is usually C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1
// On Linux this is usually /usr/local/cuda-13.1
// so detecting v13. or -13. in the path usually works
feature_set.push("cu13");
"cu13"
} else if let Some(ver) = vars::get("NV_CUDA_CUDART_VERSION")
&& ver.starts_with("13.")
{
// Set by NVIDIA docker images (both devel & runtime)
feature_set.push("cu13");
"cu13"
} else if let Ok(output) = Command::new("nvcc").arg("--version").output()
&& let Ok(stdout) = str::from_utf8(&output.stdout)
&& stdout.contains("Build cuda_13")
{
feature_set.push("cu13");
"cu13"
} else {
feature_set.push("cu12");
"cu12"
}
}
}
feature_set.push("cu12");
};
feature_set.push(cuda_feature);
} else if cfg!(feature = "nvrtx") {
// CUDA builds include NVRTX; only use the standalone NVRTX build if we aren't using CUDA as well
feature_set.push("nvrtx");