// ort/build.rs — Carson M, commit 7b3acaf95b (1.13.0), 2022-11-26 15:16:30 -06:00
// (source-viewer header: 613 lines, 20 KiB, Rust)

#![allow(dead_code)]
use std::{
borrow::Cow,
env, fs,
io::{self, Read, Write},
path::{Path, PathBuf},
process::Stdio,
str::FromStr
};
/// Version of ONNX Runtime that is downloaded (release tag `v<version>`) or cloned and compiled.
const ORT_VERSION: &str = "1.12.1";
/// Base URL of ONNX Runtime's GitHub release assets; `/v<version>/<archive>` is appended to it.
const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download";
/// Environment variable selecting how ONNX Runtime is obtained: `download`, `system` or `compile`.
const ORT_ENV_STRATEGY: &str = "ORT_STRATEGY";
/// Environment variable naming an existing ONNX Runtime installation, required by the `system` strategy.
const ORT_ENV_SYSTEM_LIB_LOCATION: &str = "ORT_LIB_LOCATION";
/// Environment variable naming a CMake toolchain file for the `compile` strategy.
/// NOTE(review): the value is read but not currently forwarded to ONNX Runtime's build script.
const ORT_ENV_CMAKE_TOOLCHAIN: &str = "ORT_CMAKE_TOOLCHAIN";
/// Environment variable overriding the CMake executable used by the `compile` strategy.
const ORT_ENV_CMAKE_PROGRAM: &str = "ORT_CMAKE_PROGRAM";
/// Environment variable overriding the Python interpreter that runs ONNX Runtime's `build.py`.
const ORT_ENV_PYTHON_PROGRAM: &str = "ORT_PYTHON_PROGRAM";
/// Directory inside `OUT_DIR` that a downloaded prebuilt archive is extracted into.
const ORT_EXTRACT_DIR: &str = "onnxruntime";
/// Directory inside `OUT_DIR` that the ONNX Runtime repository is cloned into by the `compile` strategy.
const ORT_GIT_DIR: &str = "onnxruntime";
/// ONNX Runtime source repository cloned by the `compile` strategy.
const ORT_GIT_REPO: &str = "https://github.com/microsoft/onnxruntime";
/// Directory inside `OUT_DIR` that the prebuilt `protoc` archive is extracted into.
const PROTOBUF_EXTRACT_DIR: &str = "protobuf";
/// Version of the prebuilt `protoc` binary downloaded by the `compile` strategy.
const PROTOBUF_VERSION: &str = "3.11.2";
/// Base URL of protobuf's GitHub release assets; `/v<version>/<archive>` is appended to it.
const PROTOBUF_RELEASE_BASE_URL: &str = "https://github.com/protocolbuffers/protobuf/releases/download";
/// Panics if any of the listed execution-provider Cargo features is enabled.
///
/// Cargo exposes each enabled feature of the package to its build script as an environment
/// variable `CARGO_FEATURE_<NAME>`, where `<NAME>` is the feature name upper-cased with `-`
/// replaced by `_`. Each provider identifier is therefore upper-cased and looked up.
///
/// Usage: `incompatible_providers![cuda, tensorrt];`
macro_rules! incompatible_providers {
	($($provider:ident),*) => {
		$(
			// The name is upper-cased at runtime because `stringify!` does not expand nested macros.
			// `stringify!(upper!(cuda))` yields the literal text `upper!(cuda)`, never `CUDA`.
			if ::std::env::var(format!("CARGO_FEATURE_{}", stringify!($provider).to_uppercase())).is_ok() {
				panic!("Provider not available for this strategy and/or target: {}", stringify!($provider));
			}
		)*
	}
}
/// A value that has a spelling in ONNX Runtime's prebuilt release archive names
/// (for example `win`, `x64` or `gpu`).
trait OnnxPrebuiltArchive {
	/// Returns how this value is spelled inside ONNX Runtime's release archive names.
	fn as_onnx_str(&self) -> Cow<'_, str>;
}
/// CPU architectures understood by this build script, parsed from Cargo's
/// `CARGO_CFG_TARGET_ARCH` values.
#[derive(Debug)]
enum Architecture {
	X86,
	X86_64,
	Arm,
	Arm64
}
impl FromStr for Architecture {
	type Err = String;

	/// Maps a Rust `target_arch` name to an [`Architecture`].
	/// Any other name yields an `Unsupported architecture: …` error message.
	fn from_str(s: &str) -> Result<Self, Self::Err> {
		let arch = match s {
			"x86" => Self::X86,
			"x86_64" => Self::X86_64,
			"arm" => Self::Arm,
			"aarch64" => Self::Arm64,
			unknown => return Err(format!("Unsupported architecture: {}", unknown))
		};
		Ok(arch)
	}
}
impl OnnxPrebuiltArchive for Architecture {
	/// ONNX Runtime's architecture spelling: `x64` for Rust's `x86_64` and `arm64` for `aarch64`.
	/// Some archives override this (see the `Triplet` impl).
	fn as_onnx_str(&self) -> Cow<str> {
		let name = match self {
			Self::X86 => "x86",
			Self::X86_64 => "x64",
			Self::Arm => "arm",
			Self::Arm64 => "arm64"
		};
		Cow::Borrowed(name)
	}
}
/// Operating systems understood by this build script, parsed from Cargo's
/// `CARGO_CFG_TARGET_OS` values.
#[derive(Debug)]
// NOTE(review): the reason for this allow is not visible here. Possibly clippy flags `MacOS` against
// the enum name `Os`; confirm before removing.
#[allow(clippy::enum_variant_names)]
enum Os {
	Windows,
	Linux,
	MacOS
}
impl Os {
	/// Extension of the prebuilt ONNX Runtime archive for this OS.
	/// Windows archives are `zip`; Linux and macOS archives are `tgz`.
	fn archive_extension(&self) -> &'static str {
		match self {
			Os::Windows => "zip",
			Os::Linux | Os::MacOS => "tgz"
		}
	}
}
impl FromStr for Os {
	type Err = String;

	/// Maps a Rust `target_os` name to an [`Os`].
	/// Any other name yields an `Unsupported OS: …` error message.
	fn from_str(s: &str) -> Result<Self, Self::Err> {
		let os = match s {
			"windows" => Self::Windows,
			"linux" => Self::Linux,
			"macos" => Self::MacOS,
			other => return Err(format!("Unsupported OS: {}", other))
		};
		Ok(os)
	}
}
impl OnnxPrebuiltArchive for Os {
	/// ONNX Runtime's OS spelling in archive names: `win`, `linux` or `osx`.
	fn as_onnx_str(&self) -> Cow<str> {
		let name = match self {
			Self::Windows => "win",
			Self::Linux => "linux",
			Self::MacOS => "osx"
		};
		name.into()
	}
}
/// Hardware-acceleration flavour of a prebuilt ONNX Runtime archive.
#[derive(Debug)]
enum Accelerator {
	/// CPU-only build.
	None,
	/// GPU build.
	Gpu
}
impl OnnxPrebuiltArchive for Accelerator {
	/// Archive-name suffix for this flavour.
	/// CPU archives carry no suffix, so `unaccelerated` is only used in this script's panic message.
	fn as_onnx_str(&self) -> Cow<str> {
		Cow::Borrowed(match self {
			Self::None => "unaccelerated",
			Self::Gpu => "gpu"
		})
	}
}
#[derive(Debug)]
struct Triplet {
os: Os,
arch: Architecture,
accelerator: Accelerator
}
impl OnnxPrebuiltArchive for Triplet {
fn as_onnx_str(&self) -> Cow<str> {
match (&self.os, &self.arch, &self.accelerator) {
(Os::Windows, Architecture::X86, Accelerator::None)
| (Os::Windows, Architecture::X86_64, Accelerator::None)
| (Os::Windows, Architecture::Arm, Accelerator::None)
| (Os::Windows, Architecture::Arm64, Accelerator::None)
| (Os::Linux, Architecture::X86_64, Accelerator::None)
| (Os::MacOS, Architecture::Arm64, Accelerator::None) => format!("{}-{}", self.os.as_onnx_str(), self.arch.as_onnx_str()).into(),
// for some reason, arm64/Linux uses `aarch64` instead of `arm64`
(Os::Linux, Architecture::Arm64, Accelerator::None) => format!("{}-{}", self.os.as_onnx_str(), "aarch64").into(),
// for another odd reason, x64/macOS uses `x86_64` instead of `x64`
(Os::MacOS, Architecture::X86_64, Accelerator::None) => format!("{}-{}", self.os.as_onnx_str(), "x86_64").into(),
(Os::Windows, Architecture::X86_64, Accelerator::Gpu) | (Os::Linux, Architecture::X86_64, Accelerator::Gpu) => {
format!("{}-{}-{}", self.os.as_onnx_str(), self.arch.as_onnx_str(), self.accelerator.as_onnx_str()).into()
}
_ => panic!(
"Microsoft does not provide ONNX Runtime downloads for triplet: {}-{}-{}; you may have to use the `system` strategy instead",
self.os.as_onnx_str(),
self.arch.as_onnx_str(),
self.accelerator.as_onnx_str()
)
}
}
}
fn prebuilt_onnx_url() -> (PathBuf, String) {
let accelerator = if cfg!(feature = "cuda") || cfg!(feature = "tensorrt") {
Accelerator::Gpu
} else {
Accelerator::None
};
let triplet = Triplet {
os: env::var("CARGO_CFG_TARGET_OS").expect("unable to get target OS").parse().unwrap(),
arch: env::var("CARGO_CFG_TARGET_ARCH").expect("unable to get target arch").parse().unwrap(),
accelerator
};
let prebuilt_archive = format!("onnxruntime-{}-{}.{}", triplet.as_onnx_str(), ORT_VERSION, triplet.os.archive_extension());
let prebuilt_url = format!("{}/v{}/{}", ORT_RELEASE_BASE_URL, ORT_VERSION, prebuilt_archive);
(PathBuf::from(prebuilt_archive), prebuilt_url)
}
/// Returns the file name and download URL of a prebuilt `protoc` archive for the **host** machine,
/// i.e. the machine running this build script, which is where `protoc` executes.
///
/// Inside a build script `cfg!(target_*)` describes the host, not the crate's compile target,
/// which is exactly what is needed here.
///
/// # Panics
/// Panics on macOS hosts that are neither x86 nor x86_64 (no prebuilt protoc for this version).
fn prebuilt_protoc_url() -> (PathBuf, String) {
	let host_platform = if cfg!(target_os = "windows") {
		String::from("win32")
	} else if cfg!(target_os = "macos") {
		let arch = if cfg!(target_arch = "x86_64") {
			"x86_64"
		} else if cfg!(target_arch = "x86") {
			// protobuf names its 32-bit macOS asset `osx-x86_32`, not `osx-x86`.
			"x86_32"
		} else {
			panic!("protoc does not have prebuilt binaries for darwin arm64 yet")
		};
		format!("osx-{}", arch)
	} else {
		let arch = if cfg!(target_arch = "x86_64") {
			"x86_64"
		} else if cfg!(target_arch = "aarch64") {
			// 64-bit ARM Linux hosts need the `aarch_64` asset; the x86 binary cannot run there.
			"aarch_64"
		} else {
			"x86_32"
		};
		format!("linux-{}", arch)
	};
	let prebuilt_archive = format!("protoc-{}-{}.zip", PROTOBUF_VERSION, host_platform);
	let prebuilt_url = format!("{}/v{}/{}", PROTOBUF_RELEASE_BASE_URL, PROTOBUF_VERSION, prebuilt_archive);
	(PathBuf::from(prebuilt_archive), prebuilt_url)
}
fn download<P>(source_url: &str, target_file: P)
where
P: AsRef<Path>
{
let resp = ureq::get(source_url)
.timeout(std::time::Duration::from_secs(300))
.call()
.unwrap_or_else(|err| panic!("ERROR: Failed to download {}: {:?}", source_url, err));
let len = resp.header("Content-Length").and_then(|s| s.parse::<usize>().ok()).unwrap();
let mut reader = resp.into_reader();
// FIXME: Save directly to the file
let mut buffer = vec![];
let read_len = reader.read_to_end(&mut buffer).unwrap();
assert_eq!(buffer.len(), len);
assert_eq!(buffer.len(), read_len);
let f = fs::File::create(&target_file).unwrap();
let mut writer = io::BufWriter::new(f);
writer.write_all(&buffer).unwrap();
}
/// Extracts the archive `filename` into the directory `output`, choosing the format from the file
/// extension.
///
/// `.zip` is supported everywhere. `.tgz` is supported only when this build script runs on a
/// non-Windows host.
///
/// # Panics
/// Panics if the extension is unsupported on this host or extraction fails.
fn extract_archive(filename: &Path, output: &Path) {
	match filename.extension().and_then(|ext| ext.to_str()) {
		Some("zip") => extract_zip(filename, output),
		#[cfg(not(target_os = "windows"))]
		Some("tgz") => extract_tgz(filename, output),
		_ => panic!("[ort] cannot extract {}: unsupported archive format on this build host", filename.display())
	}
}
/// Unpacks the gzip-compressed tarball `filename` into the directory `output`.
///
/// # Panics
/// Panics if the file cannot be opened or unpacking fails.
#[cfg(not(target_os = "windows"))]
fn extract_tgz(filename: &Path, output: &Path) {
	let compressed = io::BufReader::new(fs::File::open(filename).unwrap());
	let decompressed = flate2::read::GzDecoder::new(compressed);
	tar::Archive::new(decompressed).unpack(output).unwrap();
}
/// Extracts every file entry of the zip archive `filename` into the directory `outpath`, creating
/// parent directories as needed.
///
/// On Unix hosts the permission bits stored in the archive are restored, so executables such as
/// the prebuilt `protoc` binary stay runnable after extraction.
///
/// # Panics
/// Panics if the archive cannot be read, an entry's path would escape `outpath` (`enclosed_name`
/// returns `None`), or a file cannot be written.
fn extract_zip(filename: &Path, outpath: &Path) {
	let file = fs::File::open(filename).unwrap();
	let buf = io::BufReader::new(file);
	let mut archive = zip::ZipArchive::new(buf).unwrap();
	for i in 0..archive.len() {
		let mut file = archive.by_index(i).unwrap();
		// NOTE(review): the reason for this allow is not visible here. Presumably it covers a `zip`
		// version that deprecates this accessor; confirm before removing.
		#[allow(deprecated)]
		let outpath = outpath.join(file.enclosed_name().unwrap());
		// Directory entries end in `/` and are skipped; parents are created for each file below.
		if file.name().ends_with('/') {
			continue;
		}
		println!("File {} extracted to \"{}\" ({} bytes)", i, outpath.as_path().display(), file.size());
		if let Some(p) = outpath.parent() {
			if !p.exists() {
				fs::create_dir_all(p).unwrap();
			}
		}
		let mut outfile = fs::File::create(&outpath).unwrap();
		io::copy(&mut file, &mut outfile).unwrap();
		#[cfg(unix)]
		{
			use std::os::unix::fs::PermissionsExt;
			if let Some(mode) = file.unix_mode() {
				// Keep only the rwx bits; setuid/setgid/sticky bits from an archive are not honoured.
				fs::set_permissions(&outpath, fs::Permissions::from_mode(mode & 0o777)).unwrap();
			}
		}
	}
}
/// Copies the shared libraries found directly inside `lib_dir` next to the final executable, so
/// they can be found at load time (Windows searches the executable's directory for DLLs).
///
/// A shared library is any regular file whose *name* contains `.dll`, `.so` or `.dylib`, which
/// also covers versioned names such as `libfoo.so.1.2.3`. Files that already exist at the
/// destination are left untouched.
///
/// `out_dir` must be Cargo's `OUT_DIR` (`target/<profile>/build/<pkg>-<hash>/out`). The
/// destination is the directory three levels above it, i.e. `target/<profile>`.
///
/// # Panics
/// Panics if `lib_dir` cannot be read, `out_dir` has fewer than three ancestors, or a copy fails.
fn copy_libraries(lib_dir: &Path, out_dir: &Path) {
	let target_dir = out_dir
		.parent()
		.and_then(Path::parent)
		.and_then(Path::parent)
		.expect("[ort] OUT_DIR is not nested inside a Cargo target directory");
	for entry in fs::read_dir(lib_dir).unwrap().filter_map(Result::ok) {
		// `DirEntry::file_type` does not follow symlinks, so only regular files are copied.
		if !entry.file_type().map(|t| t.is_file()).unwrap_or(false) {
			continue;
		}
		let file_name = entry.file_name();
		// Inspect the file name only. Matching the full path would accept every file whenever a
		// parent directory happens to contain `.so`/`.dll`/`.dylib`. Lossy conversion avoids
		// panicking on non-UTF-8 names.
		let is_shared_library = {
			let name = file_name.to_string_lossy();
			[".dll", ".so", ".dylib"].iter().any(|ext| name.contains(*ext))
		};
		if !is_shared_library {
			continue;
		}
		let destination = target_dir.join(&file_name);
		if !destination.exists() {
			fs::copy(entry.path(), &destination).unwrap();
		}
	}
}
/// Locates or produces the ONNX Runtime installation according to the `ORT_STRATEGY`
/// environment variable.
///
/// Strategies:
/// * `download`: the default unless the `prefer-compile-strategy` feature is enabled. Fetches the
///   official prebuilt archive for the compile target into `OUT_DIR` and extracts it.
/// * `system`: uses the directory given by `ORT_LIB_LOCATION`.
/// * `compile`: clones ONNX Runtime into `OUT_DIR` and builds it from source.
///
/// Before dispatching, enabled execution-provider features are checked against the target
/// architecture, the target OS and the chosen strategy.
///
/// Returns the installation directory (expected to contain `include/` and `lib/`) and whether the
/// caller still has to emit the `onnxruntime` link directives. This is `false` for `compile`,
/// which emits its own.
///
/// # Panics
/// Panics on any of the following:
/// - an unsupported target architecture;
/// - an incompatible provider feature;
/// - an unknown strategy;
/// - any failure while preparing the libraries.
fn prepare_libort_dir() -> (PathBuf, bool) {
	let strategy = env::var(ORT_ENV_STRATEGY);
	println!("[ort] strategy: {:?}", strategy.as_ref().map(String::as_str).unwrap_or_else(|_| "unknown"));
	let target = env::var("TARGET").unwrap();
	let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
	if target_arch.eq_ignore_ascii_case("aarch64") {
		incompatible_providers![cuda, openvino, vitis_ai, tensorrt, migraphx, rocm];
	} else if target_arch.eq_ignore_ascii_case("x86_64") {
		incompatible_providers![vitis_ai, acl, armnn];
	} else {
		panic!("unsupported target architecture: {}", target_arch);
	}
	if target.contains("macos") {
		incompatible_providers![cuda, openvino, tensorrt, directml, winml];
	} else if target.contains("windows") {
		incompatible_providers![coreml, vitis_ai, acl, armnn];
	} else {
		incompatible_providers![coreml, vitis_ai, directml, winml];
	}
	match strategy
		.as_ref()
		.map(String::as_str)
		.unwrap_or_else(|_| if cfg!(feature = "prefer-compile-strategy") { "compile" } else { "download" })
	{
		"download" => {
			// Reject provider features that a prebuilt download cannot satisfy for this target.
			if target.contains("macos") {
				incompatible_providers![cuda, onednn, openvino, openmp, vitis_ai, tvm, tensorrt, migraphx, directml, winml, acl, armnn, rocm];
			} else {
				incompatible_providers![onednn, coreml, openvino, openmp, vitis_ai, tvm, migraphx, directml, winml, acl, armnn, rocm];
			}
			let (prebuilt_archive, prebuilt_url) = prebuilt_onnx_url();
			let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
			let extract_dir = out_dir.join(ORT_EXTRACT_DIR);
			let downloaded_file = out_dir.join(&prebuilt_archive);
			println!("cargo:rerun-if-changed={}", downloaded_file.display());
			// Both the archive and its extracted contents are cached in OUT_DIR across builds.
			if !downloaded_file.exists() {
				fs::create_dir_all(&out_dir).unwrap();
				download(&prebuilt_url, &downloaded_file);
			}
			if !extract_dir.exists() {
				extract_archive(&downloaded_file, &extract_dir);
			}
			// Assumes the archive unpacks into a directory named after the archive minus its extension.
			let lib_dir = extract_dir.join(prebuilt_archive.file_stem().unwrap());
			#[cfg(feature = "copy-dylibs")]
			{
				copy_libraries(&lib_dir.join("lib"), &out_dir);
			}
			(lib_dir, true)
		}
		"system" => {
			let lib_dir = PathBuf::from(env::var(ORT_ENV_SYSTEM_LIB_LOCATION).expect("[ort] system strategy requires ORT_LIB_LOCATION env var to be set"));
			#[cfg(feature = "copy-dylibs")]
			{
				copy_libraries(&lib_dir.join("lib"), &PathBuf::from(env::var("OUT_DIR").unwrap()));
			}
			(lib_dir, true)
		}
		"compile" => compile_onnxruntime(&target),
		_ => panic!(
			"[ort] unknown strategy: {} (valid options are `download`, `system` or `compile`)",
			strategy.unwrap_or_else(|_| "unknown".to_string())
		)
	}
}

/// Runs `program` with `args`, discarding its output, and reports whether it started and exited
/// successfully. A missing executable yields `false` instead of a panic.
fn command_succeeds(program: &str, args: &[&str]) -> bool {
	std::process::Command::new(program)
		.args(args)
		.stdout(Stdio::null())
		.stderr(Stdio::null())
		.status()
		.map(|status| status.success())
		.unwrap_or(false)
}

/// `compile` strategy: clones ONNX Runtime into `OUT_DIR`, builds it with its own
/// `tools/ci_build/build.py`, and emits the Cargo link directives for the resulting static
/// libraries.
///
/// `target` is Cargo's `TARGET` triple. Returns `(OUT_DIR, false)`; the link directives are
/// emitted here, so the caller must not add its own.
fn compile_onnxruntime(target: &str) -> (PathBuf, bool) {
	use std::process::Command;
	// In a build script `cfg!(target_os)` describes the host. Rust spells macOS `macos`, never `darwin`.
	if target.contains("macos") && !cfg!(target_os = "macos") && env::var(ORT_ENV_CMAKE_PROGRAM).is_err() {
		panic!("[ort] cross-compiling for macOS with the `compile` strategy requires `{}` to be set", ORT_ENV_CMAKE_PROGRAM);
	}
	let cmake = env::var(ORT_ENV_CMAKE_PROGRAM).unwrap_or_else(|_| "cmake".to_string());
	// A missing `python` executable must fall back to `python3` rather than abort the build.
	let python = env::var(ORT_ENV_PYTHON_PROGRAM).unwrap_or_else(|_| {
		if command_succeeds("python", &["--version"]) {
			"python".to_string()
		} else {
			"python3".to_string()
		}
	});
	let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
	// Check the interpreter that will actually be used, not a hard-coded `python`.
	let required_cmds: &[&str] = &[&cmake, &python, "git"];
	for cmd in required_cmds {
		if Command::new(cmd).arg("--version").stdout(Stdio::null()).stderr(Stdio::null()).status().is_err() {
			panic!("[ort] compile strategy requires `{}` to be installed", cmd);
		}
	}
	println!("[ort] assuming C/C++ compilers are available");

	let ort_dir = out_dir.join(ORT_GIT_DIR);
	// Reuse an existing checkout on rebuilds; otherwise clone it and make sure the clone succeeded.
	if !ort_dir.exists() {
		let clone_status = Command::new("git")
			.args([
				"clone",
				"--depth",
				"1",
				"--single-branch",
				"--branch",
				&format!("v{}", ORT_VERSION),
				"--shallow-submodules",
				"--recursive",
				ORT_GIT_REPO,
				ORT_GIT_DIR
			])
			.current_dir(&out_dir)
			.stdout(Stdio::null())
			.stderr(Stdio::inherit())
			.status()
			.expect("failed to clone ORT repo");
		assert!(clone_status.success(), "[ort] `git clone` of {} (v{}) failed", ORT_GIT_REPO, ORT_VERSION);
	}

	// download prebuilt protoc binary
	let (protoc_archive, protoc_url) = prebuilt_protoc_url();
	let protoc_dir = out_dir.join(PROTOBUF_EXTRACT_DIR);
	let protoc_archive_file = out_dir.join(protoc_archive);
	println!("cargo:rerun-if-changed={}", protoc_archive_file.display());
	if !protoc_archive_file.exists() {
		download(&protoc_url, &protoc_archive_file);
	}
	if !protoc_dir.exists() {
		extract_archive(&protoc_archive_file, &protoc_dir);
	}
	let protoc_file = if cfg!(target_os = "windows") { "protoc.exe" } else { "protoc" };
	let protoc_file = protoc_dir.join("bin").join(protoc_file);
	// Sanity check that the downloaded protoc can be spawned on this host.
	// NOTE(review): its path is not passed on to build.py below.
	Command::new(protoc_file)
		.args(["--help"])
		.current_dir(&out_dir)
		.stdout(Stdio::null())
		.status()
		.expect("error running `protoc --help`");

	let root = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
	// NOTE(review): the selected toolchain file is computed but not currently passed to build.py.
	let _cmake_toolchain = env::var(ORT_ENV_CMAKE_TOOLCHAIN).map(PathBuf::from).unwrap_or_else(|_| {
		if cfg!(target_os = "linux") && target.contains("aarch64") && target.contains("linux") {
			root.join("toolchains").join("default-aarch64-linux-gnu.cmake")
		} else if cfg!(target_os = "linux") && target.contains("aarch64") && target.contains("windows") {
			root.join("toolchains").join("default-aarch64-w64-mingw32.cmake")
		} else if cfg!(target_os = "linux") && target.contains("x86_64") && target.contains("windows") {
			root.join("toolchains").join("default-x86_64-w64-mingw32.cmake")
		} else {
			PathBuf::new()
		}
	});
	if cfg!(target_os = "linux") && target.contains("windows") && target.contains("aarch64") {
		println!("[ort] detected cross compilation to Windows arm64, default toolchain will make bad assumptions.");
	}

	let mut command = Command::new(python);
	command.current_dir(&ort_dir).stdout(Stdio::null()).stderr(Stdio::inherit());
	// note: --parallel may interfere with cargo's own parallel build (a parallel build inside another one)
	let mut build_args = vec!["tools/ci_build/build.py", "--build", "--update", "--parallel", "--skip_tests", "--skip_submodule_sync"];
	// NOTE(review): in a build script `debug_assertions` reflects the build script's own profile,
	// which may differ from the crate being built; confirm this is the intended switch.
	let config = if cfg!(debug_assertions) {
		"Debug"
	} else if cfg!(feature = "minimal-build") {
		"MinSizeRel"
	} else {
		"Release"
	};
	build_args.push("--config");
	build_args.push(config);
	if cfg!(feature = "minimal-build") {
		build_args.push("--disable_exceptions");
	}
	build_args.push("--disable_rtti");
	if target.contains("windows") {
		build_args.push("--disable_memleak_checker");
	}
	if !cfg!(feature = "compile-static") {
		build_args.push("--build_shared_lib");
	} else {
		build_args.push("--enable_msvc_static_runtime");
	}
	// onnxruntime still builds tests when --skip_tests is enabled; this filters out most of them.
	// It "fixes" compilation on alpine: https://github.com/microsoft/onnxruntime/issues/9155
	// but causes other compilation errors: https://github.com/microsoft/onnxruntime/issues/7571
	build_args.push("--cmake_extra_defines");
	build_args.push("onnxruntime_BUILD_UNIT_TESTS=0");
	// On Windows, prefer Ninja when it is installed. It is not used when clang-cl is also present,
	// because the Ninja + clang-cl combination has caused broken builds.
	#[cfg(target_os = "windows")]
	if command_succeeds("ninja", &["--version"]) && !command_succeeds("clang-cl", &["--version"]) {
		build_args.push("--cmake_generator=Ninja");
	} else {
		// Fall back to the Visual Studio generator matching the installed toolset.
		use vswhom::VsFindResult;
		let vs_find_result = VsFindResult::search();
		match vs_find_result {
			Some(VsFindResult { vs_exe_path: Some(vs_exe_path), .. }) => {
				let vs_exe_path = vs_exe_path.to_string_lossy();
				// MSVC toolset versions map predictably onto Visual Studio releases.
				if vs_exe_path.contains("14.1") {
					build_args.push("--cmake_generator=Visual Studio 15 2017");
				} else if vs_exe_path.contains("14.2") {
					build_args.push("--cmake_generator=Visual Studio 16 2019");
				} else if vs_exe_path.contains("14.3") {
					build_args.push("--cmake_generator=Visual Studio 17 2022");
				}
			}
			Some(VsFindResult { vs_exe_path: None, .. }) | None => panic!("[ort] unable to find Visual Studio installation")
		};
	}
	build_args.push("--build_dir=build");
	command.args(build_args);
	let code = command.status().expect("failed to run build script");
	assert!(code.success(), "failed to build ONNX Runtime");

	let lib_dir = ort_dir.join("build").join(config);
	let lib_dir = if cfg!(target_os = "windows") { lib_dir.join(config) } else { lib_dir };
	for lib in &["common", "flatbuffers", "framework", "graph", "mlas", "optimizer", "providers", "session", "util"] {
		let lib_path = lib_dir.join(if cfg!(target_os = "windows") {
			format!("onnxruntime_{}.lib", lib)
		} else {
			format!("libonnxruntime_{}.a", lib)
		});
		// sanity check: make sure the library exists before we try to link to it
		if lib_path.exists() {
			println!("cargo:rustc-link-lib=static=onnxruntime_{}", lib);
		} else {
			panic!("[ort] unable to find ONNX Runtime library: {}", lib_path.display());
		}
	}
	// also need to link to onnx.lib and onnx_proto.lib
	let external_lib_dir = lib_dir.parent().unwrap().join("external").join("onnx");
	let external_lib_dir = if cfg!(target_os = "windows") { external_lib_dir.join(config) } else { external_lib_dir };
	println!("cargo:rustc-link-search=native={}", external_lib_dir.display());
	println!("cargo:rustc-link-lib=static=onnx");
	println!("cargo:rustc-link-lib=static=onnx_proto");
	println!("cargo:rustc-link-lib=onnxruntime_providers_shared");
	println!("cargo:rustc-link-search=native={}", lib_dir.display());
	(out_dir, false)
}
/// Stand-in used when the `generate-bindings` feature is disabled.
/// The committed bindings are used as-is, `_include_dir` is ignored, and only a notice is printed.
#[cfg(not(feature = "generate-bindings"))]
fn generate_bindings(_include_dir: &Path) {
	let notice = [
		"[ort] bindings not generated automatically; using committed bindings instead.",
		"[ort] enable the `generate-bindings` feature to generate fresh bindings."
	];
	for line in notice.iter() {
		println!("{}", line);
	}
}
/// Regenerates the FFI bindings from `src/wrapper.h` with bindgen and writes them to the
/// source-controlled file `src/bindings/<target_os>/<target_arch>/bindings.rs`.
///
/// `include_dir` is the ONNX Runtime `include` directory. It and its `onnxruntime/core/session`
/// subdirectory are added to the header search path.
///
/// # Panics
/// Panics if bindgen fails or the output file cannot be written.
#[cfg(feature = "generate-bindings")]
fn generate_bindings(include_dir: &Path) {
	let clang_args = &[
		format!("-I{}", include_dir.display()),
		format!("-I{}", include_dir.join("onnxruntime").join("core").join("session").display())
	];
	println!("cargo:rerun-if-changed=src/wrapper.h");
	let bindings = bindgen::Builder::default()
		.header("src/wrapper.h")
		.clang_args(clang_args)
		// Tell cargo to invalidate the built crate whenever any of the included header files changed.
		.parse_callbacks(Box::new(bindgen::CargoCallbacks))
		// Translate `size_t` to `usize` on x86/x86_64 targets, for win32 compatibility.
		.size_t_is_usize(env::var("CARGO_CFG_TARGET_ARCH").unwrap().contains("x86"))
		// Format using rustfmt
		.rustfmt_bindings(true)
		.rustified_enum("*")
		.generate()
		.expect("Unable to generate bindings");
	// Write the bindings to (source controlled) src/bindings/<os>/<arch>/bindings.rs
	let generated_file = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap())
		.join("src")
		.join("bindings")
		.join(env::var("CARGO_CFG_TARGET_OS").unwrap())
		.join(env::var("CARGO_CFG_TARGET_ARCH").unwrap())
		.join("bindings.rs");
	// Use Display formatting, not Debug: `{:?}` would wrap the path in quotes (and escape
	// backslashes on Windows), so Cargo would be told to watch a path that does not exist.
	// NOTE(review): this script rewrites the file on every run, which may itself make Cargo rerun
	// it; confirm this watch is wanted.
	println!("cargo:rerun-if-changed={}", generated_file.display());
	fs::create_dir_all(generated_file.parent().unwrap()).unwrap();
	bindings.write_to_file(&generated_file).expect("Couldn't write bindings!");
}
#[cfg(feature = "disable-build-script")]
fn main() {}
#[cfg(not(feature = "disable-build-script"))]
fn main() {
let (install_dir, needs_link) = prepare_libort_dir();
let include_dir = install_dir.join("include");
let lib_dir = install_dir.join("lib");
if needs_link {
println!("cargo:rustc-link-lib=onnxruntime");
println!("cargo:rustc-link-search=native={}", lib_dir.display());
}
println!("cargo:rerun-if-env-changed={}", ORT_ENV_STRATEGY);
println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_LOCATION);
generate_bindings(&include_dir);
}