This commit is contained in:
Carson M
2022-11-26 15:16:30 -06:00
commit 7b3acaf95b
52 changed files with 81698 additions and 0 deletions

26
.editorconfig Normal file
View File

@@ -0,0 +1,26 @@
root = true
[*]
indent_style = tab
tab_width = 4
charset = utf-8
end_of_line = lf
insert_final_newline = true
curly_bracket_next_line = false
spaces_around_operators = true
spaces_around_brackets = both
[*.{rs,cc,hh,js,ts}]
trim_trailing_whitespace = true
max_line_length = 160
[*.{html,css,js,ts}]
quote_type = single
[*.yml]
indent_style = space
tab_width = 2
[*.{md,mdx}]
indent_style = space
tab_width = 4

18
.gitattributes vendored Normal file
View File

@@ -0,0 +1,18 @@
# Properly detect languages on Github
*.h linguist-language=cpp
*.inc linguist-language=cpp
third_party/* linguist-vendored
# Normalize EOL for all files that Git considers text files
* text=auto eol=lf
# Except for bat files, which are Windows only files
*.bat eol=crlf
# The above only works properly for Git 2.10+, so for older versions
# we need to manually list the binary files we don't want modified.
*.icns binary
*.ico binary
*.jar binary
*.png binary
*.ttf binary
*.tza binary

81
.github/workflows/code-quality.yml vendored Normal file
View File

@@ -0,0 +1,81 @@
name: Code Quality
on:
push:
branches:
- 'main'
paths:
- '.github/workflows/code-quality.yml'
- 'src/**/*.rs'
- 'build.rs'
- 'Cargo.toml'
- '.cargo/**/*'
- 'tests/**/*'
pull_request:
paths:
- '.github/workflows/code-quality.yml'
- 'src/**/*.rs'
- 'build.rs'
- 'Cargo.toml'
- '.cargo/**/*'
- 'tests/**/*'
env:
RUST_BACKTRACE: 1
CARGO_INCREMENTAL: 0
CARGO_PROFILE_DEV_DEBUG: 0
jobs:
lint-and-fmt:
name: Lint & format
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- name: Checkout sources
uses: actions/checkout@v2
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly # required for some rustfmt/clippy features
override: true
components: rustfmt, clippy
- name: Check fmt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
- name: Run clippy
uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-targets
coverage:
name: Code coverage
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- name: Checkout sources
uses: actions/checkout@v2
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- name: Get Rust version
id: rust-version
run: echo "::set-output name=version::$(cargo --version | cut -d ' ' -f 2)"
shell: bash
- uses: actions/cache@v2
id: tarpaulin-cache
with:
path: ~/.cargo/bin/cargo-tarpaulin
key: ${{ runner.os }}-cargo-${{ steps.rustc-version.outputs.version }}
- name: Install tarpaulin
if: steps.tarpaulin-cache.outputs.cache-hit != 'true'
run: cargo install cargo-tarpaulin
- name: Generate code coverage
run: |
cargo tarpaulin --verbose --timeout 120 --out Xml
- name: Upload to codecov.io
uses: codecov/codecov-action@v2
with:
fail_ci_if_error: true

40
.github/workflows/test.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
name: Run cargo tests
on:
push:
branches:
- 'main'
paths:
- '.github/workflows/test.yml'
- 'src/**/*.rs'
- 'build.rs'
- 'Cargo.toml'
- '.cargo/**/*'
- 'tests/**/*'
pull_request:
paths:
- '.github/workflows/test.yml'
- 'src/**/*.rs'
- 'build.rs'
- 'Cargo.toml'
- '.cargo/**/*'
- 'tests/**/*'
env:
RUST_BACKTRACE: 1
CARGO_INCREMENTAL: 0
CARGO_PROFILE_DEV_DEBUG: 0
jobs:
test:
name: Run tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install stable Rust toolchain
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
- uses: Swatinem/rust-cache@v1
- name: Run tests
# do not run doctests until rust-lang/cargo#10469 is merged
run: |
cargo test --verbose --lib

188
.gitignore vendored Normal file
View File

@@ -0,0 +1,188 @@
# Generated by Cargo
# will have compiled files and executables
debug/
target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
# Prerequisites
*.d
# Compiled Object files
*.slo
*.lo
*.o
*.obj
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
*.so
*.dylib
*.dll
# Fortran module files
*.mod
*.smod
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Executables
*.out
*.app
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
.pnpm-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# Snowpack dependency directory (https://snowpack.dev/)
web_modules/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional stylelint cache
.stylelintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variable files
.env
.env.development.local
.env.test.local
.env.production.local
.env.local
# parcel-bundler cache (https://parceljs.org/)
.cache
.parcel-cache
# Next.js build output
.next
out
# Nuxt.js build / generate output
.nuxt
dist
# Gatsby files
.cache/
# Comment in the public line in if your project uses Gatsby and not Next.js
# https://nextjs.org/blog/next-9-1#public-directory-support
# public
# vuepress build output
.vuepress/dist
# vuepress v2.x temp and cache directory
.temp
.cache
# Docusaurus cache and generated files
.docusaurus
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
# Stores VSCode versions used for testing VSCode extensions
.vscode-test
# yarn v2
.yarn/cache
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*
# Node binaries
*.node
# Wix tools
WixTools/
# ONNX Runtime downloaded models
**/*.onnx
!tests/data/*.onnx

7
.vscode/extensions.json vendored Normal file
View File

@@ -0,0 +1,7 @@
{
"recommendations": [
"rust-lang.rust-analyzer",
"dustypomerleau.rust-syntax",
"tamasfe.even-better-toml"
]
}

11
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,11 @@
{
"rust-analyzer.checkOnSave.command": "clippy",
"rust-analyzer.rustfmt.extraArgs": [ "+nightly" ],
"[rust]": {
"editor.formatOnSave": true,
"editor.defaultFormatter": "rust-lang.rust-analyzer",
"editor.semanticHighlighting.enabled": true
},
"rust-analyzer.cachePriming.enable": true,
"rust-analyzer.diagnostics.experimental.enable": true
}

101
Cargo.toml Normal file
View File

@@ -0,0 +1,101 @@
[package]
name = "ort"
description = "Yet another ONNX Runtime wrapper"
version = "1.13.0"
edition = "2021"
license = "MIT/Apache-2.0"
repository = "https://github.com/pykeio/ort"
readme = "README.md"
keywords = [ "machine-learning", "ai", "ml" ]
categories = [ "algorithms", "mathematics", "science" ]
authors = [
"pyke.io",
"Nicolas Bigaouette <nbigaouette@gmail.com>"
]
include = [ "src/", "examples/", "tests/", "toolchains/", "build.rs", "LICENSE", "README.md" ]
[profile.release]
opt-level = 3
lto = true
strip = true
codegen-units = 1
[package.metadata.docs.rs]
features = [ "half", "fetch-models", "copy-dylibs" ]
[features]
default = [ "half", "fetch-models", "copy-dylibs" ]
use-half-intrinsics = [ "half/use-intrinsics" ]
# used to prevent issues with docs.rs
disable-build-script = []
fetch-models = [ "ureq" ]
generate-bindings = [ "bindgen" ]
copy-dylibs = []
# ONNX compile flags
prefer-compile-strategy = []
prefer-system-strategy = []
prefer-dynamic-libs = []
minimal-build = []
training = []
experimental = []
mimalloc = []
compile-static = []
cuda = []
tensorrt = []
openvino = []
onednn = []
directml = []
snpe = []
nnapi = []
coreml = []
xnnpack = []
rocm = []
acl = []
armnn = []
tvm = []
migraphx = []
rknpu = []
vitis = []
cann = []
[dependencies]
ndarray = "0.15"
num-complex = "0.4"
num-traits = "0.2"
thiserror = "1.0"
# onnx
ureq = { version = "2.1", optional = true }
lazy_static = "1.4"
tracing = "0.1"
half = { version = "2.1", optional = true }
[target.'cfg(unix)'.dependencies]
libc = "0.2"
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = [ "std", "libloaderapi" ] }
[dev-dependencies]
ureq = "2.1"
image = "0.24"
test-log = { version = "0.2", default-features = false, features = [ "trace" ] }
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
rust_tokenizers = "7.0"
rand = "0.8"
[build-dependencies]
casey = "0.3"
bindgen = { version = "0.63", optional = true }
ureq = "2.1"
zip = { version = "0.6", default-features = false, features = [ "deflate" ] }
[target.'cfg(not(target_os = "windows"))'.build-dependencies]
flate2 = "1.0"
tar = "0.4"
[target.'cfg(target_os = "windows")'.build-dependencies]
vswhom = "0.1"

202
LICENSE Normal file
View File

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

85
README.md Normal file
View File

@@ -0,0 +1,85 @@
<div align=center>
<h1><code>ort</code> - ONNX Runtime Rust bindings</h1>
<a href="https://app.codecov.io/gh/pykeio/ort" target="_blank"><img alt="Coverage Results" src="https://img.shields.io/codecov/c/gh/pykeio/ort?style=for-the-badge"></a> <a href="https://github.com/pykeio/ort/actions/workflows/test.yml"><img alt="GitHub Workflow Status" src="https://img.shields.io/github/workflow/status/pykeio/ort/Run%20cargo%20tests?style=for-the-badge"></a>
</div>
`ort` is yet another ONNX Runtime wrapper for Rust based on [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). `ort` is updated for ONNX Runtime 1.13.1 and contains many API improvements & fixes.
## Cargo features
- `fetch-models`: Enables fetching models from the ONNX Model Zoo; not recommended for production.
- `generate-bindings`: Update/generate ONNX Runtime bindings with `bindgen`. Requires [libclang](https://clang.llvm.org/doxygen/group__CINDEX.html).
- `copy-dylibs`: Copy dynamic libraries to the Cargo `target` folder. Highly recommended on Windows, where the operating system may have an older version of ONNX Runtime bundled.
- `prefer-system-strategy`: Uses the `system` compile strategy by default; requires users to provide ONNX Runtime libraries.
- `prefer-dynamic-libs`: By default, if the path pointed to by `ORT_LIB_LOCATION` contains static libraries, `ort` will link to them rather than dynamic libraries. This feature prefers linking to dynamic libraries instead.
- `prefer-compile-strategy`: Uses the `compile` strategy by default; will take a *very* long time and is currently untested, but allows for easy static linking, avoiding [the DLL hell](#shared-library-hell).
- `compile-static`: Compiles ONNX Runtime as a static library.
- `mimalloc`: Uses the (usually) faster mimalloc memory allocation library instead of the platform default.
- `experimental`: Compiles Microsoft experimental operators.
- `training`: Enables training via ONNX Runtime. Currently unavailable through high-level bindings.
- `minimal-build`: Builds ONNX Runtime without ONNX model loading. Drastically reduces size. Recommended for release builds.
- **Execution providers**: These are required for both building **and** using execution providers. Do not enable any of these features unless you are using the `compile` strategy or you are using the `system` strategy with binaries that support these execution providers, otherwise you'll run into linking errors.
- `cuda`: Enables the CUDA execution provider for Maxwell (7xx) NVIDIA GPUs and above. Requires CUDA v11.6+.
- `tensorrt`: Enables the TensorRT execution provider for GeForce 9xx series NVIDIA GPUs and above; requires CUDA v11.6+ and TensorRT v8.4+.
- `openvino`: Enables the OpenVINO execution provider for 6th+ generation Intel Core CPUs.
- `onednn`: Enables the oneDNN execution provider for x86/x64 targets.
- `directml`: Enables the DirectML execution provider for Windows x86/x64 targets with dedicated GPUs supporting DirectX 12.
- `snpe`: Enables the SNPE execution provider for Qualcomm Snapdragon CPUs & Adreno GPUs.
- `nnapi`: Enables the Android Neural Networks API (NNAPI) execution provider.
- `coreml`: Enables the CoreML execution provider for macOS/iOS targets.
- `xnnpack`: Enables the [XNNPACK](https://github.com/google/XNNPACK) backend for WebAssembly and Android.
- `rocm`: Enables the ROCm execution provider for AMD ROCm-enabled GPUs.
- `acl`: Enables the ARM Compute Library execution provider for multi-core ARM v8 processors.
- `armnn`: Enables the ArmNN execution provider for ARM v8 targets.
- `tvm`: Enables the **preview** Apache TVM execution provider.
- `migraphx`: Enables the MIGraphX execution provider for Windows x86/x64 targets with dedicated AMD GPUs.
- `rknpu`: Enables the RKNPU execution provider for Rockchip NPUs.
- `vitis`: Enables Xilinx's Vitis-AI execution provider for U200/U250 accelerators.
- `cann`: Enables the Huawei Compute Architecture for Neural Networks (CANN) execution provider.
- `half`: Builds support for `float16`/`bfloat16` ONNX tensors.
- `use-half-intrinsics`: Use intrinsics in the `half` crate for faster operations when dealing with `float16`/`bfloat16` ONNX tensors.
## Execution providers
To use other execution providers, you must explicitly enable them via their Cargo features. Using the `compile` strategy, everything should just work™. Using the `system` strategy, ensure that the binaries you are linking to have been built with the execution providers you want to use, otherwise you'll get linking errors. After that, configuring & enabling these execution providers can be done through `SessionBuilder::execution_providers()`.
Requesting an execution provider via e.g. `ExecutionProviderBuilder::cuda()` will silently fail if that EP is not available on the system or encounters an error and falls back to the next requested execution provider or to the CPU provider if no requested providers are available. If you must know why the execution provider is unavailable, use `ExecutionProviderBuilder::try_*()`, e.g. `try_cuda()`.
For prebuilt Microsoft binaries, you can enable the CUDA or TensorRT execution providers for Windows and Linux via the `cuda` and `tensorrt` Cargo features respectively. **No other execution providers are supported** in these binaries and enabling other features will fail. To use other execution providers, you must build ONNX Runtime yourself to be able to use them.
## Shared library hell
Because compiling ONNX Runtime from source takes so long (and static linking is not recommended by Microsoft), it may be easier to compile ONNX Runtime as a shared library or use prebuilt DLLs. However, this can cause some issues with library paths and load orders.
### Windows
Some versions of Windows come bundled with `onnxruntime.dll` in the System32 folder. On Windows 11 build 22598.1, `onnxruntime.dll` is on version 1.10.0, while `ort` requires 1.13.1, so execution will fail because [system DLLs take precedence](https://learn.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order). Luckily though, DLLs in the same folder as the application have higher priority; `ort` can automatically copy the DLLs to the Cargo target folder when the `copy-dylibs` feature is enabled.
Note that, when running tests/benchmarks for the first time, you'll have to manually copy the `target/debug/onnxruntime*.dll` files to `target/debug/deps/`, or `target/debug/examples/` for examples. It should Just Work™ when building/running a binary, however.
### Linux
You'll either have to copy `libonnxruntime.so` to a known lib location (e.g. `/usr/lib`) or enable rpath if you have the `copy-dylibs` feature enabled.
In `Cargo.toml`:
```toml
[profile.dev]
rpath = true
[profile.release]
rpath = true
# do this for all profiles
```
In `.cargo/config.toml`:
```toml
[target.x86_64-unknown-linux-gnu]
rustflags = [ "-Clink-args=-Wl,-rpath,\\$ORIGIN" ]
# do this for all Linux targets as well
```
### macOS
macOS follows the same procedure as Linux, except the rpath should point to `@loader_path` rather than `$ORIGIN`:
```toml
# .cargo/config.toml
[target.x86_64-apple-darwin]
rustflags = [ "-Clink-args=-Wl,-rpath,@loader_path" ]
```

612
build.rs Normal file
View File

@@ -0,0 +1,612 @@
#![allow(dead_code)]
use std::{
borrow::Cow,
env, fs,
io::{self, Read, Write},
path::{Path, PathBuf},
process::Stdio,
str::FromStr
};
const ORT_VERSION: &str = "1.12.1";
const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download";
const ORT_ENV_STRATEGY: &str = "ORT_STRATEGY";
const ORT_ENV_SYSTEM_LIB_LOCATION: &str = "ORT_LIB_LOCATION";
const ORT_ENV_CMAKE_TOOLCHAIN: &str = "ORT_CMAKE_TOOLCHAIN";
const ORT_ENV_CMAKE_PROGRAM: &str = "ORT_CMAKE_PROGRAM";
const ORT_ENV_PYTHON_PROGRAM: &str = "ORT_PYTHON_PROGRAM";
const ORT_EXTRACT_DIR: &str = "onnxruntime";
const ORT_GIT_DIR: &str = "onnxruntime";
const ORT_GIT_REPO: &str = "https://github.com/microsoft/onnxruntime";
const PROTOBUF_EXTRACT_DIR: &str = "protobuf";
const PROTOBUF_VERSION: &str = "3.11.2";
const PROTOBUF_RELEASE_BASE_URL: &str = "https://github.com/protocolbuffers/protobuf/releases/download";
macro_rules! incompatible_providers {
($($provider:ident),*) => {
#[allow(unused_imports)]
use casey::upper;
$(
if env::var(concat!("CARGO_FEATURE_", stringify!(upper!($provider)))).is_ok() {
panic!(concat!("Provider not available for this strategy and/or target: ", stringify!($provider)));
}
)*
}
}
trait OnnxPrebuiltArchive {
fn as_onnx_str(&self) -> Cow<str>;
}
#[derive(Debug)]
enum Architecture {
X86,
X86_64,
Arm,
Arm64
}
impl FromStr for Architecture {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"x86" => Ok(Architecture::X86),
"x86_64" => Ok(Architecture::X86_64),
"arm" => Ok(Architecture::Arm),
"aarch64" => Ok(Architecture::Arm64),
_ => Err(format!("Unsupported architecture: {}", s))
}
}
}
impl OnnxPrebuiltArchive for Architecture {
fn as_onnx_str(&self) -> Cow<str> {
match self {
Architecture::X86 => "x86".into(),
Architecture::X86_64 => "x64".into(),
Architecture::Arm => "arm".into(),
Architecture::Arm64 => "arm64".into()
}
}
}
#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
enum Os {
Windows,
Linux,
MacOS
}
impl Os {
fn archive_extension(&self) -> &'static str {
match self {
Os::Windows => "zip",
Os::Linux => "tgz",
Os::MacOS => "tgz"
}
}
}
impl FromStr for Os {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"windows" => Ok(Os::Windows),
"linux" => Ok(Os::Linux),
"macos" => Ok(Os::MacOS),
_ => Err(format!("Unsupported OS: {}", s))
}
}
}
impl OnnxPrebuiltArchive for Os {
fn as_onnx_str(&self) -> Cow<str> {
match self {
Os::Windows => "win".into(),
Os::Linux => "linux".into(),
Os::MacOS => "osx".into()
}
}
}
#[derive(Debug)]
enum Accelerator {
None,
Gpu
}
impl OnnxPrebuiltArchive for Accelerator {
fn as_onnx_str(&self) -> Cow<str> {
match self {
Accelerator::None => "unaccelerated".into(),
Accelerator::Gpu => "gpu".into()
}
}
}
#[derive(Debug)]
struct Triplet {
os: Os,
arch: Architecture,
accelerator: Accelerator
}
impl OnnxPrebuiltArchive for Triplet {
fn as_onnx_str(&self) -> Cow<str> {
match (&self.os, &self.arch, &self.accelerator) {
(Os::Windows, Architecture::X86, Accelerator::None)
| (Os::Windows, Architecture::X86_64, Accelerator::None)
| (Os::Windows, Architecture::Arm, Accelerator::None)
| (Os::Windows, Architecture::Arm64, Accelerator::None)
| (Os::Linux, Architecture::X86_64, Accelerator::None)
| (Os::MacOS, Architecture::Arm64, Accelerator::None) => format!("{}-{}", self.os.as_onnx_str(), self.arch.as_onnx_str()).into(),
// for some reason, arm64/Linux uses `aarch64` instead of `arm64`
(Os::Linux, Architecture::Arm64, Accelerator::None) => format!("{}-{}", self.os.as_onnx_str(), "aarch64").into(),
// for another odd reason, x64/macOS uses `x86_64` instead of `x64`
(Os::MacOS, Architecture::X86_64, Accelerator::None) => format!("{}-{}", self.os.as_onnx_str(), "x86_64").into(),
(Os::Windows, Architecture::X86_64, Accelerator::Gpu) | (Os::Linux, Architecture::X86_64, Accelerator::Gpu) => {
format!("{}-{}-{}", self.os.as_onnx_str(), self.arch.as_onnx_str(), self.accelerator.as_onnx_str()).into()
}
_ => panic!(
"Microsoft does not provide ONNX Runtime downloads for triplet: {}-{}-{}; you may have to use the `system` strategy instead",
self.os.as_onnx_str(),
self.arch.as_onnx_str(),
self.accelerator.as_onnx_str()
)
}
}
}
fn prebuilt_onnx_url() -> (PathBuf, String) {
let accelerator = if cfg!(feature = "cuda") || cfg!(feature = "tensorrt") {
Accelerator::Gpu
} else {
Accelerator::None
};
let triplet = Triplet {
os: env::var("CARGO_CFG_TARGET_OS").expect("unable to get target OS").parse().unwrap(),
arch: env::var("CARGO_CFG_TARGET_ARCH").expect("unable to get target arch").parse().unwrap(),
accelerator
};
let prebuilt_archive = format!("onnxruntime-{}-{}.{}", triplet.as_onnx_str(), ORT_VERSION, triplet.os.archive_extension());
let prebuilt_url = format!("{}/v{}/{}", ORT_RELEASE_BASE_URL, ORT_VERSION, prebuilt_archive);
(PathBuf::from(prebuilt_archive), prebuilt_url)
}
fn prebuilt_protoc_url() -> (PathBuf, String) {
let host_platform = if cfg!(target_os = "windows") {
std::string::String::from("win32")
} else if cfg!(target_os = "macos") {
format!(
"osx-{}",
if cfg!(target_arch = "x86_64") {
"x86_64"
} else if cfg!(target_arch = "x86") {
"x86"
} else {
panic!("protoc does not have prebuilt binaries for darwin arm64 yet")
}
)
} else {
format!("linux-{}", if cfg!(target_arch = "x86_64") { "x86_64" } else { "x86_32" })
};
let prebuilt_archive = format!("protoc-{}-{}.zip", PROTOBUF_VERSION, host_platform);
let prebuilt_url = format!("{}/v{}/{}", PROTOBUF_RELEASE_BASE_URL, PROTOBUF_VERSION, prebuilt_archive);
(PathBuf::from(prebuilt_archive), prebuilt_url)
}
fn download<P>(source_url: &str, target_file: P)
where
P: AsRef<Path>
{
let resp = ureq::get(source_url)
.timeout(std::time::Duration::from_secs(300))
.call()
.unwrap_or_else(|err| panic!("ERROR: Failed to download {}: {:?}", source_url, err));
let len = resp.header("Content-Length").and_then(|s| s.parse::<usize>().ok()).unwrap();
let mut reader = resp.into_reader();
// FIXME: Save directly to the file
let mut buffer = vec![];
let read_len = reader.read_to_end(&mut buffer).unwrap();
assert_eq!(buffer.len(), len);
assert_eq!(buffer.len(), read_len);
let f = fs::File::create(&target_file).unwrap();
let mut writer = io::BufWriter::new(f);
writer.write_all(&buffer).unwrap();
}
fn extract_archive(filename: &Path, output: &Path) {
match filename.extension().map(|e| e.to_str()) {
Some(Some("zip")) => extract_zip(filename, output),
#[cfg(not(target_os = "windows"))]
Some(Some("tgz")) => extract_tgz(filename, output),
_ => unimplemented!()
}
}
#[cfg(not(target_os = "windows"))]
fn extract_tgz(filename: &Path, output: &Path) {
let file = fs::File::open(&filename).unwrap();
let buf = io::BufReader::new(file);
let tar = flate2::read::GzDecoder::new(buf);
let mut archive = tar::Archive::new(tar);
archive.unpack(output).unwrap();
}
fn extract_zip(filename: &Path, outpath: &Path) {
let file = fs::File::open(filename).unwrap();
let buf = io::BufReader::new(file);
let mut archive = zip::ZipArchive::new(buf).unwrap();
for i in 0..archive.len() {
let mut file = archive.by_index(i).unwrap();
#[allow(deprecated)]
let outpath = outpath.join(file.enclosed_name().unwrap());
if !file.name().ends_with('/') {
println!("File {} extracted to \"{}\" ({} bytes)", i, outpath.as_path().display(), file.size());
if let Some(p) = outpath.parent() {
if !p.exists() {
fs::create_dir_all(p).unwrap();
}
}
let mut outfile = fs::File::create(&outpath).unwrap();
io::copy(&mut file, &mut outfile).unwrap();
}
}
}
fn copy_libraries(lib_dir: &Path, out_dir: &Path) {
// get the target directory - we need to place the dlls next to the executable so they can be properly loaded by windows
let out_dir = out_dir.parent().unwrap().parent().unwrap().parent().unwrap();
let lib_files = fs::read_dir(lib_dir).unwrap();
for lib_file in lib_files.filter(|e| {
e.as_ref()
.ok()
.map(|e| {
e.file_type().map(|e| e.is_file()).unwrap_or(false)
&& [".dll", ".so", ".dylib"]
.into_iter()
.any(|v| e.path().into_os_string().into_string().unwrap().contains(v))
})
.unwrap_or(false)
}) {
let lib_file = lib_file.unwrap();
let lib_path = lib_file.path();
let lib_name = lib_path.file_name().unwrap();
let out_path = out_dir.join(lib_name);
if !out_path.exists() {
fs::copy(&lib_path, out_path).unwrap();
}
}
}
fn prepare_libort_dir() -> (PathBuf, bool) {
let strategy = env::var(ORT_ENV_STRATEGY);
println!("[ort] strategy: {:?}", strategy.as_ref().map(String::as_str).unwrap_or_else(|_| "unknown"));
let target = env::var("TARGET").unwrap();
let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
if target_arch.eq_ignore_ascii_case("aarch64") {
incompatible_providers![cuda, openvino, vitis_ai, tensorrt, migraphx, rocm];
} else if target_arch.eq_ignore_ascii_case("x86_64") {
incompatible_providers![vitis_ai, acl, armnn];
} else {
panic!("unsupported target architecture: {}", target_arch);
}
if target.contains("macos") {
incompatible_providers![cuda, openvino, tensorrt, directml, winml];
} else if target.contains("windows") {
incompatible_providers![coreml, vitis_ai, acl, armnn];
} else {
incompatible_providers![coreml, vitis_ai, directml, winml];
}
match strategy
.as_ref()
.map(String::as_str)
.unwrap_or_else(|_| if cfg!(feature = "prefer-compile-strategy") { "compile" } else { "download" })
{
"download" => {
if target.contains("macos") {
incompatible_providers![cuda, onednn, openvino, openmp, vitis_ai, tvm, tensorrt, migraphx, directml, winml, acl, armnn, rocm];
} else {
incompatible_providers![onednn, coreml, openvino, openmp, vitis_ai, tvm, migraphx, directml, winml, acl, armnn, rocm];
}
let (prebuilt_archive, prebuilt_url) = prebuilt_onnx_url();
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let extract_dir = out_dir.join(ORT_EXTRACT_DIR);
let downloaded_file = out_dir.join(&prebuilt_archive);
println!("cargo:rerun-if-changed={}", downloaded_file.display());
if !downloaded_file.exists() {
fs::create_dir_all(&out_dir).unwrap();
download(&prebuilt_url, &downloaded_file);
}
if !extract_dir.exists() {
extract_archive(&downloaded_file, &extract_dir);
}
let lib_dir = extract_dir.join(prebuilt_archive.file_stem().unwrap());
#[cfg(feature = "copy-dylibs")]
{
copy_libraries(&lib_dir.join("lib"), &out_dir);
}
(lib_dir, true)
}
"system" => {
let lib_dir = PathBuf::from(env::var(ORT_ENV_SYSTEM_LIB_LOCATION).expect("[ort] system strategy requires ORT_LIB_LOCATION env var to be set"));
#[cfg(feature = "copy-dylibs")]
{
copy_libraries(&lib_dir.join("lib"), &PathBuf::from(env::var("OUT_DIR").unwrap()));
}
(lib_dir, true)
}
"compile" => {
use std::process::Command;
let target = env::var("TARGET").unwrap();
if target.contains("macos") && !cfg!(target_os = "darwin") && env::var(ORT_ENV_CMAKE_PROGRAM).is_err() {
panic!("[ort] cross-compiling for macOS with the `compile` strategy requires `{}` to be set", ORT_ENV_CMAKE_PROGRAM);
}
let cmake = env::var(ORT_ENV_CMAKE_PROGRAM).unwrap_or_else(|_| "cmake".to_string());
let python = env::var(ORT_ENV_PYTHON_PROGRAM).unwrap_or_else(|_| {
if Command::new("python").arg("--version").status().unwrap().success() {
"python".to_string()
} else {
"python3".to_string()
}
});
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let required_cmds: &[&str] = &[&cmake, "python", "git"];
for cmd in required_cmds {
if Command::new(cmd).output().is_err() {
panic!("[ort] compile strategy requires `{}` to be installed", cmd);
}
}
println!("[ort] assuming C/C++ compilers are available");
Command::new("git")
.args([
"clone",
"--depth",
"1",
"--single-branch",
"--branch",
&format!("v{}", ORT_VERSION),
"--shallow-submodules",
"--recursive",
ORT_GIT_REPO,
ORT_GIT_DIR
])
.current_dir(&out_dir)
.stdout(Stdio::null())
.stderr(Stdio::null())
.status()
.expect("failed to clone ORT repo");
// download prebuilt protoc binary
let (protoc_archive, protoc_url) = prebuilt_protoc_url();
let protoc_dir = out_dir.join(PROTOBUF_EXTRACT_DIR);
let protoc_archive_file = out_dir.join(protoc_archive);
println!("cargo:rerun-if-changed={}", protoc_archive_file.display());
if !protoc_archive_file.exists() {
download(&protoc_url, &protoc_archive_file);
}
if !protoc_dir.exists() {
extract_archive(&protoc_archive_file, &protoc_dir);
}
let protoc_file = if cfg!(target_os = "windows") { "protoc.exe" } else { "protoc" };
let protoc_file = protoc_dir.join("bin").join(protoc_file);
Command::new(protoc_file)
.args(["--help"])
.current_dir(&out_dir)
.stdout(Stdio::null())
.status()
.expect("error running `protoc --help`");
let root = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
let _cmake_toolchain = env::var(ORT_ENV_CMAKE_TOOLCHAIN).map(PathBuf::from).unwrap_or(
if cfg!(target_os = "linux") && target.contains("aarch64") && target.contains("linux") {
root.join("toolchains").join("default-aarch64-linux-gnu.cmake")
} else if cfg!(target_os = "linux") && target.contains("aarch64") && target.contains("windows") {
root.join("toolchains").join("default-aarch64-w64-mingw32.cmake")
} else if cfg!(target_os = "linux") && target.contains("x86_64") && target.contains("windows") {
root.join("toolchains").join("default-x86_64-w64-mingw32.cmake")
} else {
PathBuf::new()
}
);
if cfg!(target_os = "linux") && target.contains("windows") && target.contains("aarch64") {
println!("[ort] detected cross compilation to Windows arm64, default toolchain will make bad assumptions.");
}
let mut command = Command::new(python);
command
.current_dir(&out_dir.join(ORT_GIT_DIR))
.stdout(Stdio::null())
.stderr(Stdio::inherit());
// note: --parallel will probably break something... parallel build *while* doing another parallel build (cargo)?
let mut build_args = vec!["tools/ci_build/build.py", "--build", "--update", "--parallel", "--skip_tests", "--skip_submodule_sync"];
let config = if cfg!(debug_assertions) {
"Debug"
} else if cfg!(feature = "minimal-build") {
"MinSizeRel"
} else {
"Release"
};
build_args.push("--config");
build_args.push(config);
if cfg!(feature = "minimal-build") {
build_args.push("--disable_exceptions");
}
build_args.push("--disable_rtti");
if target.contains("windows") {
build_args.push("--disable_memleak_checker");
}
if !cfg!(feature = "compile-static") {
build_args.push("--build_shared_lib");
} else {
build_args.push("--enable_msvc_static_runtime");
}
// onnxruntime will still build tests when --skip_tests is enabled, this filters out most of them
// this "fixes" compilation on alpine: https://github.com/microsoft/onnxruntime/issues/9155
// but causes other compilation errors: https://github.com/microsoft/onnxruntime/issues/7571
build_args.push("--cmake_extra_defines");
build_args.push("onnxruntime_BUILD_UNIT_TESTS=0");
// if we can use ninja on windows, great! let's use it!
// note that ninja + clang on windows is a total shitstorm so it's disabled for now
#[cfg(target_os = "windows")]
if Command::new("ninja").arg("--version").status().unwrap().success() && !Command::new("clang-cl").arg("--version").status().unwrap().success() {
build_args.push("--cmake_generator=Ninja");
} else {
// fuck
use vswhom::VsFindResult;
let vs_find_result = VsFindResult::search();
match vs_find_result {
Some(VsFindResult { vs_exe_path: Some(vs_exe_path), .. }) => {
let vs_exe_path = vs_exe_path.to_string_lossy();
// the one sane thing about visual studio is that the version numbers are somewhat predictable...
if vs_exe_path.contains("14.1") {
build_args.push("--cmake_generator=Visual Studio 15 2017");
} else if vs_exe_path.contains("14.2") {
build_args.push("--cmake_generator=Visual Studio 16 2019");
} else if vs_exe_path.contains("14.3") {
build_args.push("--cmake_generator=Visual Studio 17 2022");
}
}
Some(VsFindResult { vs_exe_path: None, .. }) | None => panic!("[ort] unable to find Visual Studio installation")
};
}
build_args.push("--build_dir=build");
command.args(build_args);
let code = command.status().expect("failed to run build script");
assert!(code.success(), "failed to build ONNX Runtime");
let lib_dir = out_dir.join(ORT_GIT_DIR).join("build").join(config);
let lib_dir = if cfg!(target_os = "windows") { lib_dir.join(config) } else { lib_dir };
for lib in &["common", "flatbuffers", "framework", "graph", "mlas", "optimizer", "providers", "session", "util"] {
let lib_path = lib_dir.join(if cfg!(target_os = "windows") {
format!("onnxruntime_{}.lib", lib)
} else {
format!("libonnxruntime_{}.a", lib)
});
// sanity check, just make sure the library exists before we try to link to it
if lib_path.exists() {
println!("cargo:rustc-link-lib=static=onnxruntime_{}", lib);
} else {
panic!("[ort] unable to find ONNX Runtime library: {}", lib_path.display());
}
}
// also need to link to onnx.lib and onnx_proto.lib
let external_lib_dir = lib_dir.parent().unwrap().join("external").join("onnx");
let external_lib_dir = if cfg!(target_os = "windows") { external_lib_dir.join(config) } else { external_lib_dir };
println!("cargo:rustc-link-search=native={}", external_lib_dir.display());
println!("cargo:rustc-link-lib=static=onnx");
println!("cargo:rustc-link-lib=static=onnx_proto");
println!("cargo:rustc-link-lib=onnxruntime_providers_shared");
println!("cargo:rustc-link-search=native={}", lib_dir.display());
(out_dir, false)
}
_ => panic!("[ort] unknown strategy: {} (valid options are `download` or `system`)", strategy.unwrap_or_else(|_| "unknown".to_string()))
}
}
#[cfg(not(feature = "generate-bindings"))]
fn generate_bindings(_include_dir: &Path) {
println!("[ort] bindings not generated automatically; using committed bindings instead.");
println!("[ort] enable the `generate-bindings` feature to generate fresh bindings.");
}
#[cfg(feature = "generate-bindings")]
fn generate_bindings(include_dir: &Path) {
let clang_args = &[
format!("-I{}", include_dir.display()),
format!("-I{}", include_dir.join("onnxruntime").join("core").join("session").display())
];
println!("cargo:rerun-if-changed=src/wrapper.h");
let bindings = bindgen::Builder::default()
.header("src/wrapper.h")
.clang_args(clang_args)
// Tell cargo to invalidate the built crate whenever any of the included header files changed.
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
// Set `size_t` to be translated to `usize` for win32 compatibility.
.size_t_is_usize(env::var("CARGO_CFG_TARGET_ARCH").unwrap().contains("x86"))
// Format using rustfmt
.rustfmt_bindings(true)
.rustified_enum("*")
.generate()
.expect("Unable to generate bindings");
// Write the bindings to (source controlled) src/onnx/bindings/<os>/<arch>/bindings.rs
let generated_file = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap())
.join("src")
.join("bindings")
.join(env::var("CARGO_CFG_TARGET_OS").unwrap())
.join(env::var("CARGO_CFG_TARGET_ARCH").unwrap())
.join("bindings.rs");
println!("cargo:rerun-if-changed={:?}", generated_file);
fs::create_dir_all(generated_file.parent().unwrap()).unwrap();
bindings.write_to_file(&generated_file).expect("Couldn't write bindings!");
}
#[cfg(feature = "disable-build-script")]
fn main() {}
#[cfg(not(feature = "disable-build-script"))]
fn main() {
let (install_dir, needs_link) = prepare_libort_dir();
let include_dir = install_dir.join("include");
let lib_dir = install_dir.join("lib");
if needs_link {
println!("cargo:rustc-link-lib=onnxruntime");
println!("cargo:rustc-link-search=native={}", lib_dir.display());
}
println!("cargo:rerun-if-env-changed={}", ORT_ENV_STRATEGY);
println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_LOCATION);
generate_bindings(&include_dir);
}

67
examples/gpt.rs Normal file
View File

@@ -0,0 +1,67 @@
use std::sync::Arc;
use ndarray::{array, concatenate, s, Array1, Axis};
use ort::{
download::language::{machine_comprehension::GPT2, MachineComprehension},
tensor::{DynOrtTensor, FromArray, InputTensor, OrtOwnedTensor},
Environment, ExecutionProvider, GraphOptimizationLevel, LoggingLevel, OrtResult, SessionBuilder
};
use rand::Rng;
use rust_tokenizers::{tokenizer::Gpt2Tokenizer, tokenizer::Tokenizer};
const GEN_TOKENS: i32 = 45;
const TOP_K: usize = 5;
fn main() -> OrtResult<()> {
let mut rng = rand::thread_rng();
let environment = Arc::new(
Environment::builder()
.with_name("GPT-2")
.with_log_level(LoggingLevel::Warning)
.with_execution_providers([ExecutionProvider::cuda_if_available()])
.build()?
);
let mut session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_downloaded(MachineComprehension::GPT2(GPT2::GPT2LmHead))?;
let tokenizer = Gpt2Tokenizer::from_file("tests/data/gpt2/vocab.json", "tests/data/gpt2/merges.txt", false).unwrap();
let tokens = tokenizer
.encode(
"The corsac fox (Vulpes corsac), also known simply as a corsac, is a medium-sized fox found in",
None,
128,
&rust_tokenizers::tokenizer::TruncationStrategy::LongestFirst,
0
)
.token_ids;
let tokens = &mut Array1::from_iter(tokens.iter().cloned());
for _ in 0..GEN_TOKENS {
let n_tokens = &tokens.shape()[0];
let array = tokens.clone().insert_axis(Axis(0)).into_shape((1, 1, *n_tokens)).unwrap();
let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = session.run([InputTensor::from_array(array.into_dyn())])?;
let generated_tokens: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap();
let generated_tokens = generated_tokens.view();
let probabilities = &mut generated_tokens
.slice(s![0, 0, -1, ..])
.insert_axis(Axis(0))
.to_owned()
.iter()
.cloned()
.enumerate()
.collect::<Vec<_>>();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
let token = probabilities[rng.gen_range(0..=TOP_K)].0;
*tokens = concatenate![Axis(0), *tokens, array![token.try_into().unwrap()]];
let sentence = tokenizer.decode(&tokens.iter().copied().collect::<Vec<_>>(), true, true);
println!("{}", sentence);
}
Ok(())
}

34
rustfmt.toml Normal file
View File

@@ -0,0 +1,34 @@
edition = "2021"
version = "Two"
unstable_features = true
max_width = 160
hard_tabs = true
tab_spaces = 4
newline_style = "Unix"
wrap_comments = true
format_code_in_doc_comments = true
comment_width = 120
doc_comment_code_block_width = 120
normalize_comments = true
use_small_heuristics = "Off"
fn_call_width = 140
attr_fn_like_width = 112
struct_lit_width = 36
struct_variant_width = 60
array_width = 120
chain_width = 90
single_line_if_else_max_width = 96
reorder_imports = true
group_imports = "StdExternalCrate"
reorder_modules = true
trailing_comma = "Never"
match_block_trailing_comma = false
format_macro_bodies = true
use_try_shorthand = true
use_field_init_shorthand = true
merge_derives = true
force_explicit_abi = true

17
src/bindings/bindings.rs Normal file
View File

@@ -0,0 +1,17 @@
#[cfg(all(target_os = "linux", target_arch = "x86_64"))]
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/bindings/linux/x86_64/bindings.rs"));
#[cfg(all(target_os = "linux", target_arch = "aarch64"))]
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/bindings/linux/aarch64/bindings.rs"));
#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/bindings/macos/x86_64/bindings.rs"));
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/bindings/macos/aarch64/bindings.rs"));
#[cfg(all(target_os = "windows", target_arch = "x86_64"))]
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/bindings/windows/x86_64/bindings.rs"));
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/bindings/windows/aarch64/bindings.rs"));

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

83
src/download.rs Normal file
View File

@@ -0,0 +1,83 @@
#[cfg(feature = "fetch-models")]
use std::{
fs, io,
path::{Path, PathBuf},
time::Duration
};
#[cfg(feature = "fetch-models")]
use tracing::info;
#[cfg(feature = "fetch-models")]
use crate::error::{OrtDownloadError, OrtResult};
pub mod language;
pub mod vision;
/// Available pre-trained models to download from the [ONNX Model Zoo](https://github.com/onnx/models).
#[derive(Debug, Clone)]
pub enum OnnxModel {
/// Computer vision models
Vision(vision::Vision),
/// Language models
Language(language::Language)
}
trait ModelUrl {
fn fetch_url(&self) -> &'static str;
}
impl ModelUrl for OnnxModel {
fn fetch_url(&self) -> &'static str {
match self {
OnnxModel::Vision(model) => model.fetch_url(),
OnnxModel::Language(model) => model.fetch_url()
}
}
}
impl OnnxModel {
#[cfg(feature = "fetch-models")]
#[tracing::instrument]
pub(crate) fn download_to<P>(&self, download_dir: P) -> OrtResult<PathBuf>
where
P: AsRef<Path> + std::fmt::Debug
{
let url = self.fetch_url();
let model_filename = PathBuf::from(url.split('/').last().unwrap());
let model_filepath = download_dir.as_ref().join(model_filename);
if model_filepath.exists() {
info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
Ok(model_filepath)
} else {
info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{:?}", url).as_str(), "Downloading model");
let resp = ureq::get(url)
.timeout(Duration::from_secs(180))
.call()
.map_err(Box::new)
.map_err(OrtDownloadError::FetchError)?;
assert!(resp.has("Content-Length"));
let len = resp.header("Content-Length").and_then(|s| s.parse::<usize>().ok()).unwrap();
info!(len, "Downloading {} bytes", len);
let mut reader = resp.into_reader();
let f = fs::File::create(&model_filepath).unwrap();
let mut writer = io::BufWriter::new(f);
let bytes_io_count = io::copy(&mut reader, &mut writer).map_err(OrtDownloadError::IoError)?;
if bytes_io_count == len as u64 {
Ok(model_filepath)
} else {
Err(OrtDownloadError::CopyError {
expected: len as u64,
io: bytes_io_count
}
.into())
}
}
}
}

18
src/download/language.rs Normal file
View File

@@ -0,0 +1,18 @@
use super::ModelUrl;
pub mod machine_comprehension;
pub use machine_comprehension::MachineComprehension;
#[derive(Debug, Clone)]
pub enum Language {
MachineComprehension(MachineComprehension)
}
impl ModelUrl for Language {
fn fetch_url(&self) -> &'static str {
match self {
Language::MachineComprehension(v) => v.fetch_url()
}
}
}

View File

@@ -0,0 +1,81 @@
#![allow(clippy::upper_case_acronyms)]
use crate::download::{language::Language, ModelUrl, OnnxModel};
/// Machine comprehension models.
///
/// A subset of natural language processing models that answer questions about a given context paragraph.
#[derive(Debug, Clone)]
pub enum MachineComprehension {
/// Answers a query about a given context paragraph.
BiDAF,
/// Answers questions based on the context of the given input paragraph.
BERTSquad,
/// Large transformer-based model that predicts sentiment based on given input text.
RoBERTa(RoBERTa),
/// Generates synthetic text samples in response to the model being primed with an arbitrary input.
GPT2(GPT2)
}
/// Large transformer-based model that predicts sentiment based on given input text.
#[derive(Debug, Clone)]
pub enum RoBERTa {
RoBERTaBase,
RoBERTaSequenceClassification
}
/// Generates synthetic text samples in response to the model being primed with an arbitrary input.
#[derive(Debug, Clone)]
pub enum GPT2 {
GPT2,
GPT2LmHead
}
impl ModelUrl for MachineComprehension {
fn fetch_url(&self) -> &'static str {
match self {
MachineComprehension::BiDAF => "https://github.com/onnx/models/raw/main/text/machine_comprehension/bidirectional_attention_flow/model/bidaf-9.onnx",
MachineComprehension::BERTSquad => "https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx",
MachineComprehension::RoBERTa(variant) => variant.fetch_url(),
MachineComprehension::GPT2(variant) => variant.fetch_url()
}
}
}
impl ModelUrl for RoBERTa {
fn fetch_url(&self) -> &'static str {
match self {
RoBERTa::RoBERTaBase => "https://github.com/onnx/models/raw/main/text/machine_comprehension/roberta/model/roberta-base-11.onnx",
RoBERTa::RoBERTaSequenceClassification => {
"https://github.com/onnx/models/raw/main/text/machine_comprehension/roberta/model/roberta-sequence-classification-9.onnx"
}
}
}
}
impl ModelUrl for GPT2 {
fn fetch_url(&self) -> &'static str {
match self {
GPT2::GPT2 => "https://github.com/onnx/models/raw/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx",
GPT2::GPT2LmHead => "https://github.com/onnx/models/raw/main/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx"
}
}
}
impl From<MachineComprehension> for OnnxModel {
fn from(model: MachineComprehension) -> Self {
OnnxModel::Language(Language::MachineComprehension(model))
}
}
impl From<RoBERTa> for OnnxModel {
fn from(model: RoBERTa) -> Self {
OnnxModel::Language(Language::MachineComprehension(MachineComprehension::RoBERTa(model)))
}
}
impl From<GPT2> for OnnxModel {
fn from(model: GPT2) -> Self {
OnnxModel::Language(Language::MachineComprehension(MachineComprehension::GPT2(model)))
}
}

34
src/download/vision.rs Normal file
View File

@@ -0,0 +1,34 @@
use super::ModelUrl;
pub mod body_face_gesture_analysis;
pub mod domain_based_image_classification;
pub mod image_classification;
pub mod image_manipulation;
pub mod object_detection_image_segmentation;
pub use body_face_gesture_analysis::BodyFaceGestureAnalysis;
pub use domain_based_image_classification::DomainBasedImageClassification;
pub use image_classification::ImageClassification;
pub use image_manipulation::ImageManipulation;
pub use object_detection_image_segmentation::ObjectDetectionImageSegmentation;
#[derive(Debug, Clone)]
pub enum Vision {
BodyFaceGestureAnalysis(BodyFaceGestureAnalysis),
DomainBasedImageClassification(DomainBasedImageClassification),
ImageClassification(ImageClassification),
ImageManipulation(ImageManipulation),
ObjectDetectionImageSegmentation(ObjectDetectionImageSegmentation)
}
impl ModelUrl for Vision {
fn fetch_url(&self) -> &'static str {
match self {
Vision::DomainBasedImageClassification(v) => v.fetch_url(),
Vision::ImageClassification(v) => v.fetch_url(),
Vision::ImageManipulation(v) => v.fetch_url(),
Vision::ObjectDetectionImageSegmentation(v) => v.fetch_url(),
Vision::BodyFaceGestureAnalysis(v) => v.fetch_url()
}
}
}

View File

@@ -0,0 +1,27 @@
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
#[derive(Debug, Clone)]
pub enum BodyFaceGestureAnalysis {
/// A CNN based model for face recognition which learns discriminative features of faces and produces embeddings for
/// input face images.
ArcFace,
/// Deep CNN for emotion recognition trained on images of faces.
EmotionFerPlus
}
impl ModelUrl for BodyFaceGestureAnalysis {
fn fetch_url(&self) -> &'static str {
match self {
BodyFaceGestureAnalysis::ArcFace => "https://github.com/onnx/models/raw/main/vision/body_analysis/arcface/model/arcfaceresnet100-8.onnx",
BodyFaceGestureAnalysis::EmotionFerPlus => {
"https://github.com/onnx/models/raw/main/vision/body_analysis/emotion_ferplus/model/emotion-ferplus-8.onnx"
}
}
}
}
impl From<BodyFaceGestureAnalysis> for OnnxModel {
fn from(model: BodyFaceGestureAnalysis) -> Self {
OnnxModel::Vision(Vision::BodyFaceGestureAnalysis(model))
}
}

View File

@@ -0,0 +1,21 @@
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
#[derive(Debug, Clone)]
pub enum DomainBasedImageClassification {
/// Handwritten digit prediction using CNN.
Mnist
}
impl ModelUrl for DomainBasedImageClassification {
fn fetch_url(&self) -> &'static str {
match self {
DomainBasedImageClassification::Mnist => "https://github.com/onnx/models/raw/main/vision/classification/mnist/model/mnist-8.onnx"
}
}
}
impl From<DomainBasedImageClassification> for OnnxModel {
fn from(model: DomainBasedImageClassification) -> Self {
OnnxModel::Vision(Vision::DomainBasedImageClassification(model))
}
}

View File

@@ -0,0 +1,236 @@
#![allow(clippy::upper_case_acronyms)]
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
#[derive(Debug, Clone)]
pub enum ImageClassification {
/// Image classification aimed for mobile targets.
///
/// > MobileNet models perform image classification - they take images as input and classify the major
/// > object in the image into a set of pre-defined classes. They are trained on ImageNet dataset which
/// > contains images from 1000 classes. MobileNet models are also very efficient in terms of speed and
/// > size and hence are ideal for embedded and mobile applications.
MobileNet,
/// Image classification, trained on ImageNet with 1000 classes.
///
/// > ResNet models provide very high accuracies with affordable model sizes. They are ideal for cases when
/// > high accuracy of classification is required.
ResNet(ResNet),
/// A small CNN with AlexNet level accuracy on ImageNet with 50x fewer parameters.
///
/// > SqueezeNet is a small CNN which achieves AlexNet level accuracy on ImageNet with 50x fewer parameters.
/// > SqueezeNet requires less communication across servers during distributed training, less bandwidth to
/// > export a new model from the cloud to an autonomous car and more feasible to deploy on FPGAs and other
/// > hardware with limited memory.
SqueezeNet,
/// Image classification, trained on ImageNet with 1000 classes.
///
/// > VGG models provide very high accuracies but at the cost of increased model sizes. They are ideal for
/// > cases when high accuracy of classification is essential and there are limited constraints on model sizes.
Vgg(Vgg),
/// Convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition
/// Challenge in 2012.
AlexNet,
/// Convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition
/// Challenge in 2014.
GoogleNet,
/// Variant of AlexNet, it's the name of a convolutional neural network for classification, which competed in the
/// ImageNet Large Scale Visual Recognition Challenge in 2012.
CaffeNet,
/// Convolutional neural network for detection.
///
/// > This model was made by transplanting the R-CNN SVM classifiers into a fc-rcnn classification layer.
RcnnIlsvrc13,
/// Convolutional neural network for classification.
DenseNet121,
/// Google's Inception
Inception(InceptionVersion),
/// Computationally efficient CNN architecture designed specifically for mobile devices with very limited computing
/// power.
ShuffleNet(ShuffleNetVersion),
/// Deep convolutional networks for classification.
///
/// > This model's 4th layer has 512 maps instead of 1024 maps mentioned in the paper.
ZFNet512,
/// Image classification model that achieves state-of-the-art accuracy.
///
/// > It is designed to run on mobile CPU, GPU, and EdgeTPU devices, allowing for applications on mobile and loT,
/// where computational resources are limited.
EfficientNetLite4
}
#[derive(Debug, Clone)]
pub enum InceptionVersion {
V1,
V2
}
#[derive(Debug, Clone)]
pub enum ResNet {
V1(ResNetV1),
V2(ResNetV2)
}
#[derive(Debug, Clone)]
pub enum ResNetV1 {
ResNet18,
ResNet34,
ResNet50,
ResNet101,
ResNet152
}
#[derive(Debug, Clone)]
pub enum ResNetV2 {
ResNet18,
ResNet34,
ResNet50,
ResNet101,
ResNet152
}
#[derive(Debug, Clone)]
pub enum Vgg {
/// VGG with 16 convolutional layers
Vgg16,
/// VGG with 16 convolutional layers, with batch normalization applied after each convolutional layer.
///
/// The batch normalization leads to better convergence and slightly better accuracies.
Vgg16Bn,
/// VGG with 19 convolutional layers
Vgg19,
/// VGG with 19 convolutional layers, with batch normalization applied after each convolutional layer.
///
/// The batch normalization leads to better convergence and slightly better accuracies.
Vgg19Bn
}
/// Computationally efficient CNN architecture designed specifically for mobile devices with very limited computing
/// power.
#[derive(Debug, Clone)]
pub enum ShuffleNetVersion {
V1,
/// ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff
/// used for image classification.
V2
}
impl ModelUrl for ImageClassification {
fn fetch_url(&self) -> &'static str {
match self {
ImageClassification::MobileNet => "https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx",
ImageClassification::SqueezeNet => "https://github.com/onnx/models/raw/main/vision/classification/squeezenet/model/squeezenet1.1-7.onnx",
ImageClassification::Inception(version) => version.fetch_url(),
ImageClassification::ResNet(version) => version.fetch_url(),
ImageClassification::Vgg(variant) => variant.fetch_url(),
ImageClassification::AlexNet => "https://github.com/onnx/models/raw/main/vision/classification/alexnet/model/bvlcalexnet-9.onnx",
ImageClassification::GoogleNet => {
"https://github.com/onnx/models/raw/main/vision/classification/inception_and_googlenet/googlenet/model/googlenet-9.onnx"
}
ImageClassification::CaffeNet => "https://github.com/onnx/models/raw/main/vision/classification/caffenet/model/caffenet-9.onnx",
ImageClassification::RcnnIlsvrc13 => "https://github.com/onnx/models/raw/main/vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.onnx",
ImageClassification::DenseNet121 => "https://github.com/onnx/models/raw/main/vision/classification/densenet-121/model/densenet-9.onnx",
ImageClassification::ShuffleNet(version) => version.fetch_url(),
ImageClassification::ZFNet512 => "https://github.com/onnx/models/raw/main/vision/classification/zfnet-512/model/zfnet512-9.onnx",
ImageClassification::EfficientNetLite4 => {
"https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4.onnx"
}
}
}
}
impl ModelUrl for InceptionVersion {
fn fetch_url(&self) -> &'static str {
match self {
InceptionVersion::V1 => {
"https://github.com/onnx/models/raw/main/vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-9.onnx"
}
InceptionVersion::V2 => {
"https://github.com/onnx/models/raw/main/vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.onnx"
}
}
}
}
impl ModelUrl for ResNet {
fn fetch_url(&self) -> &'static str {
match self {
ResNet::V1(variant) => variant.fetch_url(),
ResNet::V2(variant) => variant.fetch_url()
}
}
}
impl ModelUrl for ResNetV1 {
fn fetch_url(&self) -> &'static str {
match self {
ResNetV1::ResNet18 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v1-7.onnx",
ResNetV1::ResNet34 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet34-v1-7.onnx",
ResNetV1::ResNet50 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v1-7.onnx",
ResNetV1::ResNet101 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet101-v1-7.onnx",
ResNetV1::ResNet152 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet152-v1-7.onnx"
}
}
}
impl ModelUrl for ResNetV2 {
fn fetch_url(&self) -> &'static str {
match self {
ResNetV2::ResNet18 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx",
ResNetV2::ResNet34 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet34-v2-7.onnx",
ResNetV2::ResNet50 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v2-7.onnx",
ResNetV2::ResNet101 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet101-v2-7.onnx",
ResNetV2::ResNet152 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet152-v2-7.onnx"
}
}
}
impl ModelUrl for Vgg {
fn fetch_url(&self) -> &'static str {
match self {
Vgg::Vgg16 => "https://github.com/onnx/models/raw/main/vision/classification/vgg/model/vgg16-7.onnx",
Vgg::Vgg16Bn => "https://github.com/onnx/models/raw/main/vision/classification/vgg/model/vgg16-bn-7.onnx",
Vgg::Vgg19 => "https://github.com/onnx/models/raw/main/vision/classification/vgg/model/vgg19-7.onnx",
Vgg::Vgg19Bn => "https://github.com/onnx/models/raw/main/vision/classification/vgg/model/vgg19-bn-7.onnx"
}
}
}
impl ModelUrl for ShuffleNetVersion {
fn fetch_url(&self) -> &'static str {
match self {
ShuffleNetVersion::V1 => "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx",
ShuffleNetVersion::V2 => "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-v2-10.onnx"
}
}
}
impl From<ImageClassification> for OnnxModel {
fn from(model: ImageClassification) -> Self {
OnnxModel::Vision(Vision::ImageClassification(model))
}
}
impl From<ResNet> for OnnxModel {
fn from(variant: ResNet) -> Self {
OnnxModel::Vision(Vision::ImageClassification(ImageClassification::ResNet(variant)))
}
}
impl From<Vgg> for OnnxModel {
fn from(variant: Vgg) -> Self {
OnnxModel::Vision(Vision::ImageClassification(ImageClassification::Vgg(variant)))
}
}
impl From<InceptionVersion> for OnnxModel {
fn from(variant: InceptionVersion) -> Self {
OnnxModel::Vision(Vision::ImageClassification(ImageClassification::Inception(variant)))
}
}
impl From<ShuffleNetVersion> for OnnxModel {
fn from(variant: ShuffleNetVersion) -> Self {
OnnxModel::Vision(Vision::ImageClassification(ImageClassification::ShuffleNet(variant)))
}
}

View File

@@ -0,0 +1,68 @@
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
/// Image Manipulation
///
/// > Image manipulation models use neural networks to transform input images to modified output images. Some
/// > popular models in this category involve style transfer or enhancing images by increasing resolution.
#[derive(Debug, Clone)]
pub enum ImageManipulation {
/// Super Resolution
///
/// > The Super Resolution machine learning model sharpens and upscales the input image to refine the
/// > details and improve quality.
SuperResolution,
/// Fast Neural Style Transfer
///
/// > This artistic style transfer model mixes the content of an image with the style of another image.
/// > Examples of the styles can be seen
/// > [in this PyTorch example](https://github.com/pytorch/examples/tree/master/fast_neural_style#models).
FastNeuralStyleTransfer(FastNeuralStyleTransferStyle)
}
#[derive(Debug, Clone)]
pub enum FastNeuralStyleTransferStyle {
Mosaic,
Candy,
RainPrincess,
Udnie,
Pointilism
}
impl ModelUrl for ImageManipulation {
fn fetch_url(&self) -> &'static str {
match self {
ImageManipulation::SuperResolution => {
"https://github.com/onnx/models/raw/main/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx"
}
ImageManipulation::FastNeuralStyleTransfer(style) => style.fetch_url()
}
}
}
impl ModelUrl for FastNeuralStyleTransferStyle {
fn fetch_url(&self) -> &'static str {
match self {
FastNeuralStyleTransferStyle::Mosaic => "https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/mosaic-9.onnx",
FastNeuralStyleTransferStyle::Candy => "https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/candy-9.onnx",
FastNeuralStyleTransferStyle::RainPrincess => {
"https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx"
}
FastNeuralStyleTransferStyle::Udnie => "https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/udnie-9.onnx",
FastNeuralStyleTransferStyle::Pointilism => {
"https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/pointilism-9.onnx"
}
}
}
}
impl From<ImageManipulation> for OnnxModel {
fn from(model: ImageManipulation) -> Self {
OnnxModel::Vision(Vision::ImageManipulation(model))
}
}
impl From<FastNeuralStyleTransferStyle> for OnnxModel {
fn from(style: FastNeuralStyleTransferStyle) -> Self {
OnnxModel::Vision(Vision::ImageManipulation(ImageManipulation::FastNeuralStyleTransfer(style)))
}
}

View File

@@ -0,0 +1,96 @@
#![allow(clippy::upper_case_acronyms)]
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
/// Object Detection & Image Segmentation
///
/// > Object detection models detect the presence of multiple objects in an image and segment out areas of the
/// > image where the objects are detected. Semantic segmentation models partition an input image by labeling each pixel
/// > into a set of pre-defined categories.
#[derive(Debug, Clone)]
pub enum ObjectDetectionImageSegmentation {
/// A real-time CNN for object detection that detects 20 different classes. A smaller version of the
/// more complex full YOLOv2 network.
TinyYoloV2,
/// Single Stage Detector: real-time CNN for object detection that detects 80 different classes.
Ssd,
/// A variant of MobileNet that uses the Single Shot Detector (SSD) model framework. The model detects 80
/// different object classes and locates up to 10 objects in an image.
SSDMobileNetV1,
/// Increases efficiency from R-CNN by connecting a RPN with a CNN to create a single, unified network for
/// object detection that detects 80 different classes.
FasterRcnn,
/// A real-time neural network for object instance segmentation that detects 80 different classes. Extends
/// Faster R-CNN as each of the 300 elected ROIs go through 3 parallel branches of the network: label
/// prediction, bounding box prediction and mask prediction.
MaskRcnn,
/// A real-time dense detector network for object detection that addresses class imbalance through Focal Loss.
/// RetinaNet is able to match the speed of previous one-stage detectors and defines the state-of-the-art in
/// two-stage detectors (surpassing R-CNN).
RetinaNet,
/// A CNN model for real-time object detection system that can detect over 9000 object categories. It uses a
/// single network evaluation, enabling it to be more than 1000x faster than R-CNN and 100x faster than
/// Faster R-CNN.
YoloV2,
/// A CNN model for real-time object detection system that can detect over 9000 object categories. It uses
/// a single network evaluation, enabling it to be more than 1000x faster than R-CNN and 100x faster than
/// Faster R-CNN. This model is trained with COCO dataset and contains 80 classes.
YoloV2Coco,
/// A deep CNN model for real-time object detection that detects 80 different classes. A little bigger than
/// YOLOv2 but still very fast. As accurate as SSD but 3 times faster.
YoloV3,
/// A smaller version of YOLOv3 model.
TinyYoloV3,
/// Optimizes the speed and accuracy of object detection. Two times faster than EfficientDet. It improves
/// YOLOv3's AP and FPS by 10% and 12%, respectively, with mAP50 of 52.32 on the COCO 2017 dataset and
/// FPS of 41.7 on Tesla 100.
YoloV4,
/// Deep CNN based pixel-wise semantic segmentation model with >80% mIOU (mean Intersection Over Union).
/// Trained on cityscapes dataset, which can be effectively implemented in self driving vehicle systems.
Duc
}
impl ModelUrl for ObjectDetectionImageSegmentation {
fn fetch_url(&self) -> &'static str {
match self {
ObjectDetectionImageSegmentation::TinyYoloV2 => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/tiny-yolov2/model/tinyyolov2-8.onnx"
}
ObjectDetectionImageSegmentation::Ssd => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/ssd/model/ssd-10.onnx",
ObjectDetectionImageSegmentation::SSDMobileNetV1 => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.onnx"
}
ObjectDetectionImageSegmentation::FasterRcnn => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/faster-rcnn/model/FasterRCNN-10.onnx"
}
ObjectDetectionImageSegmentation::MaskRcnn => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/mask-rcnn/model/MaskRCNN-10.onnx"
}
ObjectDetectionImageSegmentation::RetinaNet => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/retinanet/model/retinanet-9.onnx"
}
ObjectDetectionImageSegmentation::YoloV2 => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov2/model/yolov2-voc-8.onnx"
}
ObjectDetectionImageSegmentation::YoloV2Coco => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov2-coco/model/yolov2-coco-9.onnx"
}
ObjectDetectionImageSegmentation::YoloV3 => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov3/model/yolov3-10.onnx"
}
ObjectDetectionImageSegmentation::TinyYoloV3 => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/tiny-yolov3/model/tiny-yolov3-11.onnx"
}
ObjectDetectionImageSegmentation::YoloV4 => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov4/model/yolov4.onnx",
ObjectDetectionImageSegmentation::Duc => {
"https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/duc/model/ResNet101-DUC-7.onnx"
}
}
}
}
impl From<ObjectDetectionImageSegmentation> for OnnxModel {
fn from(model: ObjectDetectionImageSegmentation) -> Self {
OnnxModel::Vision(Vision::ObjectDetectionImageSegmentation(model))
}
}

351
src/environment.rs Normal file
View File

@@ -0,0 +1,351 @@
use std::{
ffi::CString,
sync::{atomic::AtomicPtr, Arc, Mutex}
};
use lazy_static::lazy_static;
use tracing::{debug, error, warn};
use super::{
custom_logger,
error::{status_to_result, OrtError, OrtResult},
ort, ortsys, sys, ExecutionProvider, LoggingLevel
};
lazy_static! {
static ref G_ENV: Arc<Mutex<EnvironmentSingleton>> = Arc::new(Mutex::new(EnvironmentSingleton {
name: String::from("uninitialized"),
env_ptr: AtomicPtr::new(std::ptr::null_mut())
}));
}
#[derive(Debug)]
struct EnvironmentSingleton {
name: String,
env_ptr: AtomicPtr<sys::OrtEnv>
}
/// An [`Environment`] is the main entry point of the ONNX Runtime.
///
/// Only one ONNX environment can be created per process. A singleton (through `lazy_static!()`) is used to enforce
/// this.
///
/// Once an environment is created, a [`super::Session`] can be obtained from it.
///
/// **NOTE**: While the [`Environment`] constructor takes a `name` parameter to name the environment, only the first
/// name will be considered if many environments are created.
///
/// # Example
///
/// ```no_run
/// # use std::error::Error;
/// # use ort::{Environment, LoggingLevel};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let environment = Environment::builder().with_name("test").with_log_level(LoggingLevel::Verbose).build()?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct Environment {
env: Arc<Mutex<EnvironmentSingleton>>,
pub(crate) execution_providers: Vec<ExecutionProvider>
}
impl Environment {
/// Create a new environment builder using default values
/// (name: `default`, log level: [`LoggingLevel::Warning`])
pub fn builder() -> EnvBuilder {
EnvBuilder {
name: "default".into(),
log_level: LoggingLevel::Warning,
execution_providers: Vec::new()
}
}
/// Return the name of the current environment
pub fn name(&self) -> String {
self.env.lock().unwrap().name.to_string()
}
pub(crate) fn env_ptr(&self) -> *const sys::OrtEnv {
*self.env.lock().unwrap().env_ptr.get_mut()
}
fn new(name: String, log_level: LoggingLevel, execution_providers: Vec<ExecutionProvider>) -> OrtResult<Environment> {
// NOTE: Because 'G_ENV' is a lazy_static, locking it will, initially, create
// a new Arc<Mutex<EnvironmentSingleton>> with a strong count of 1.
// Cloning it to embed it inside the 'Environment' to return
// will thus increase the strong count to 2.
let mut environment_guard = G_ENV.lock().expect("Failed to acquire lock: another thread panicked?");
let g_env_ptr = environment_guard.env_ptr.get_mut();
if g_env_ptr.is_null() {
debug!("Environment not yet initialized, creating a new one.");
let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(name.clone()).unwrap();
let create_env_with_custom_logger = ortsys![CreateEnvWithCustomLogger];
let status = unsafe { create_env_with_custom_logger(logging_function, logger_param, log_level.into(), cname.as_ptr(), &mut env_ptr) };
status_to_result(status).map_err(OrtError::CreateEnvironment)?;
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created.");
*g_env_ptr = env_ptr;
environment_guard.name = name;
// NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one.
// If this 'Environment' is the only one in the process, the strong count
// will be 2:
// * one lazy_static 'G_ENV'
// * one inside the 'Environment' returned
Ok(Environment {
env: G_ENV.clone(),
execution_providers
})
} else {
warn!(
name = environment_guard.name.as_str(),
env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(),
"Environment already initialized, reusing it.",
);
// NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one.
// If this 'Environment' is the only one in the process, the strong count
// will be 2:
// * one lazy_static 'G_ENV'
// * one inside the 'Environment' returned
Ok(Environment {
env: G_ENV.clone(),
execution_providers
})
}
}
}
impl Drop for Environment {
#[tracing::instrument]
fn drop(&mut self) {
debug!(global_arc_count = Arc::strong_count(&G_ENV), "Dropping the Environment.",);
let mut environment_guard = self.env.lock().expect("Failed to acquire lock: another thread panicked?");
// NOTE: If we drop an 'Environment' we (obviously) have _at least_
// one 'G_ENV' strong count (the one in the 'env' member).
// There is also the "original" 'G_ENV' which is a the lazy_static global.
// If there is no other environment, the strong count should be two and we
// can properly free the sys::OrtEnv pointer.
if Arc::strong_count(&G_ENV) == 2 {
let release_env = ort().ReleaseEnv.unwrap();
let env_ptr: *mut sys::OrtEnv = *environment_guard.env_ptr.get_mut();
debug!(global_arc_count = Arc::strong_count(&G_ENV), "Releasing the Environment.",);
assert_ne!(env_ptr, std::ptr::null_mut());
if env_ptr.is_null() {
error!("Environment pointer is null, not dropping!");
} else {
unsafe { release_env(env_ptr) };
}
environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut());
environment_guard.name = String::from("uninitialized");
}
}
}
/// Struct used to build an environment [`Environment`].
///
/// This is ONNX Runtime's main entry point. An environment _must_ be created as the first step. An [`Environment`] can
/// only be built using `EnvBuilder` to configure it.
///
/// Libraries using `ort` should **not** create an environment, as only one is allowed per process. Instead, allow the
/// user to pass their own environment to the library.
///
/// **NOTE**: If the same configuration method (for example [`EnvBuilder::with_name()`] is called multiple times, the
/// last value will have precedence.
pub struct EnvBuilder {
name: String,
log_level: LoggingLevel,
execution_providers: Vec<ExecutionProvider>
}
impl EnvBuilder {
/// Configure the environment with a given name
///
/// **NOTE**: Since ONNX can only define one environment per process, creating multiple environments using multiple
/// [`EnvBuilder`]s will end up re-using the same environment internally; a new one will _not_ be created. New
/// parameters will be ignored.
pub fn with_name<S>(mut self, name: S) -> EnvBuilder
where
S: Into<String>
{
self.name = name.into();
self
}
/// Configure the environment with a given log level
///
/// **NOTE**: Since ONNX can only define one environment per process, creating multiple environments using multiple
/// [`EnvBuilder`]s will end up re-using the same environment internally; a new one will _not_ be created. New
/// parameters will be ignored.
pub fn with_log_level(mut self, log_level: LoggingLevel) -> EnvBuilder {
self.log_level = log_level;
self
}
/// Configures a list of execution providers sessions created under this environment will use by default. Sessions
/// may override these via [`SessionBuilder::with_execution_providers()`].
///
/// Execution providers are loaded in the order they are provided until a suitable execution provider is found. Most
/// execution providers will silently fail if they are unavailable or misconfigured (see notes below), however, some
/// may log to the console, which is sadly unavoidable. The CPU execution provider is always available, so always
/// put it last in the list (though it is not required).
///
/// Execution providers will only work if the corresponding `onnxep-*` feature is enabled and ONNX Runtime was built
/// with support for the corresponding execution provider. Execution providers that do not have their corresponding
/// feature enabled are currently ignored.
///
/// Execution provider options can be specified in the second argument. Refer to ONNX Runtime's
/// [execution provider docs](https://onnxruntime.ai/docs/execution-providers/) for configuration options. In most
/// cases, passing `None` to configure with no options is suitable.
///
/// It is recommended to enable the `cuda` EP for x86 platforms and the `acl` EP for ARM platforms for the best
/// performance, though this does mean you'll have to build ONNX Runtime for these targets. Microsoft's prebuilt
/// binaries are built with CUDA and TensorRT support, if you built `ort` with the `onnxep-cuda` or
/// `onnxep-tensorrt` features enabled.
///
/// Supported execution providers:
/// - `cpu`: Default CPU/MLAS execution provider. Available on all platforms.
/// - `acl`: Arm Compute Library
/// - `cuda`: NVIDIA CUDA/cuDNN
/// - `tensorrt`: NVIDIA TensorRT
///
/// ## Notes
///
/// - Using the CUDA/TensorRT execution providers **can terminate the process if the CUDA/TensorRT installation is
/// misconfigured**. Configuring the execution provider will seem to work, but when you attempt to run a session,
/// it will hard crash the process with a "stack buffer overrun" error. This can occur when CUDA/TensorRT is
/// missing a DLL such as `zlibwapi.dll`. To prevent your app from crashing, you can check to see if you can load
/// `zlibwapi.dll` before enabling the CUDA/TensorRT execution providers.
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProvider]>) -> EnvBuilder {
self.execution_providers = execution_providers.as_ref().to_vec();
self
}
/// Commit the configuration to a new [`Environment`].
pub fn build(self) -> OrtResult<Environment> {
Environment::new(self.name, self.log_level, self.execution_providers)
}
}
#[cfg(test)]
mod tests {
use std::sync::{RwLock, RwLockWriteGuard};
use test_log::test;
use super::*;
impl G_ENV {
fn is_initialized(&self) -> bool {
Arc::strong_count(self) >= 2
}
fn env_ptr(&self) -> *const sys::OrtEnv {
*self.lock().unwrap().env_ptr.get_mut()
}
}
struct ConcurrentTestRun {
lock: Arc<RwLock<()>>
}
lazy_static! {
static ref CONCURRENT_TEST_RUN: ConcurrentTestRun = ConcurrentTestRun { lock: Arc::new(RwLock::new(())) };
}
impl CONCURRENT_TEST_RUN {
fn single_test_run(&self) -> RwLockWriteGuard<()> {
self.lock.write().unwrap()
}
}
#[test]
fn env_is_initialized() {
let _run_lock = CONCURRENT_TEST_RUN.single_test_run();
assert!(!G_ENV.is_initialized());
assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());
let env = Environment::builder()
.with_name("env_is_initialized")
.with_log_level(LoggingLevel::Warning)
.build()
.unwrap();
assert!(G_ENV.is_initialized());
assert_ne!(G_ENV.env_ptr(), std::ptr::null_mut());
std::mem::drop(env);
assert!(!G_ENV.is_initialized());
assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());
}
#[ignore]
#[test]
fn sequential_environment_creation() {
let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();
let mut prev_env_ptr = G_ENV.env_ptr();
for i in 0..10 {
let name = format!("sequential_environment_creation: {}", i);
let env = Environment::builder()
.with_name(name.clone())
.with_log_level(LoggingLevel::Warning)
.build()
.unwrap();
let next_env_ptr = G_ENV.env_ptr();
assert_ne!(next_env_ptr, prev_env_ptr);
prev_env_ptr = next_env_ptr;
assert_eq!(env.name(), name);
}
}
#[test]
fn concurrent_environment_creations() {
let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();
let initial_name = String::from("concurrent_environment_creation");
let main_env = Environment::new(initial_name.clone(), LoggingLevel::Warning, Vec::new()).unwrap();
let main_env_ptr = main_env.env_ptr() as usize;
assert_eq!(main_env.name(), initial_name);
assert_eq!(main_env.env_ptr() as usize, main_env_ptr);
assert!(
(0..10)
.map(|t| {
let initial_name_cloned = initial_name.clone();
std::thread::spawn(move || {
let name = format!("concurrent_environment_creation: {}", t);
let env = Environment::builder()
.with_name(name)
.with_log_level(LoggingLevel::Warning)
.build()
.unwrap();
assert_eq!(env.name(), initial_name_cloned);
assert_eq!(env.env_ptr() as usize, main_env_ptr);
})
})
.map(|child| child.join())
.all(|r| std::result::Result::is_ok(&r))
);
}
}

238
src/error.rs Normal file
View File

@@ -0,0 +1,238 @@
use std::{io, path::PathBuf, string};
use thiserror::Error;
use super::{char_p_to_string, ort, sys, tensor::TensorElementDataType};
/// Type alias for the Result type returned by ORT functions.
pub type OrtResult<T> = std::result::Result<T, OrtError>;
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtError {
#[error("Failed to construct Rust String")]
FfiStringConversion(OrtApiError),
/// An error occurred while creating an ONNX environment.
#[error("Failed to create ONNX Runtime environment: {0}")]
CreateEnvironment(OrtApiError),
/// Error occurred when creating ONNX session options.
#[error("Failed to create ONNX Runtime session options: {0}")]
CreateSessionOptions(OrtApiError),
/// Error occurred when creating an ONNX session.
#[error("Failed to create ONNX Runtime session: {0}")]
CreateSession(OrtApiError),
/// Error occurred when creating an ONNX allocator.
#[error("Failed to get ONNX allocator: {0}")]
GetAllocator(OrtApiError),
/// Error occurred when counting ONNX session input/output count.
#[error("Failed to get input or output count: {0}")]
GetInOutCount(OrtApiError),
/// Error occurred when getting ONNX input name.
#[error("Failed to get input name: {0}")]
GetInputName(OrtApiError),
/// Error occurred when getting ONNX type information
#[error("Failed to get type info: {0}")]
GetTypeInfo(OrtApiError),
/// Error occurred when casting ONNX type information to tensor information
#[error("Failed to cast type info to tensor info: {0}")]
CastTypeInfoToTensorInfo(OrtApiError),
/// Error occurred when getting tensor elements type
#[error("Failed to get tensor element type: {0}")]
GetTensorElementType(OrtApiError),
/// Error occurred when getting ONNX dimensions count
#[error("Failed to get dimensions count: {0}")]
GetDimensionsCount(OrtApiError),
/// Error occurred when getting ONNX dimensions
#[error("Failed to get dimensions: {0}")]
GetDimensions(OrtApiError),
/// Error occurred when getting string length
#[error("Failed to get string tensor length: {0}")]
GetStringTensorDataLength(OrtApiError),
/// Error occurred when getting tensor element count
#[error("Failed to get tensor element count: {0}")]
GetTensorShapeElementCount(OrtApiError),
/// Error occurred when creating CPU memory information
#[error("Failed to create CPU memory info: {0}")]
CreateCpuMemoryInfo(OrtApiError),
/// Error occurred when creating ONNX tensor
#[error("Failed to create tensor: {0}")]
CreateTensor(OrtApiError),
/// Error occurred when creating ONNX tensor with specific data
#[error("Failed to create tensor with data: {0}")]
CreateTensorWithData(OrtApiError),
/// Error occurred when filling a tensor with string data
#[error("Failed to fill string tensor: {0}")]
FillStringTensor(OrtApiError),
/// Error occurred when checking if ONNX tensor was properly initialized
#[error("Failed to check if tensor is a tensor or was properly initialized: {0}")]
FailedTensorCheck(OrtApiError),
/// Error occurred when getting tensor type and shape
#[error("Failed to get tensor type and shape: {0}")]
GetTensorTypeAndShape(OrtApiError),
/// Error occurred when ONNX inference operation was called
#[error("Failed to run inference on model: {0}")]
SessionRun(OrtApiError),
/// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView`.
#[error("Failed to get tensor data: {0}")]
GetTensorMutableData(OrtApiError),
/// Error occurred when extracting string data from an ONNX tensor
#[error("Failed to get tensor string data: {0}")]
GetStringTensorContent(OrtApiError),
/// Error occurred when converting data to a String
#[error("Data was not UTF-8: {0}")]
StringFromUtf8Error(#[from] string::FromUtf8Error),
/// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models).
#[error("Failed to download ONNX model: {0}")]
DownloadError(#[from] OrtDownloadError),
/// Dimensions of input data and the ONNX model do not match.
#[error("Dimensions do not match: {0:?}")]
NonMatchingDimensions(NonMatchingDimensionsError),
/// File does not exist
#[error("File `{filename:?}` does not exist")]
FileDoesNotExist {
/// Path which does not exists
filename: PathBuf
},
/// Path is invalid UTF-8
#[error("Path `{path:?}` cannot be converted to UTF-8")]
NonUtf8Path {
/// Path with invalid UTF-8
path: PathBuf
},
/// Attempt to build a Rust `CString` from a null pointer
#[error("Failed to build CString when original contains null: {0}")]
FfiStringNull(#[from] std::ffi::NulError),
#[error("{0} pointer should be null")]
/// ORT pointer should have been null
PointerShouldBeNull(String),
/// ORT pointer should not have been null
#[error("{0} pointer should not be null")]
PointerShouldNotBeNull(String),
/// The runtime type was undefined.
#[error("Undefined tensor element type")]
UndefinedTensorElementType,
/// Could not retrieve model metadata.
#[error("Failed to retrieve model metadata: {0}")]
GetModelMetadata(OrtApiError),
/// The user tried to extract the wrong type of tensor from the underlying data
#[error("Data type mismatch: was {:?}, tried to convert to {:?}", actual, requested)]
DataTypeMismatch {
/// The actual type of the ort output
actual: TensorElementDataType,
/// The type corresponding to the attempted conversion into a Rust type, not equal to `actual`
requested: TensorElementDataType
}
}
/// Error used when the input dimensions defined in the model and passed from an inference call do not match.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
/// Number of inputs from model does not match the number of inputs from inference call.
#[error(
"Non-matching number of inputs: {inference_input_count:?} provided vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})"
)]
InputsCount {
/// Number of input dimensions used by inference call
inference_input_count: usize,
/// Number of input dimensions defined in model
model_input_count: usize,
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>
},
/// Inputs length from model does not match the expected input from inference call
#[error("Different input lengths; expected input: {model_input:?}, received input: {inference_input:?}")]
InputsLength {
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>
}
}
/// Error details when ONNX C API returns an error.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtApiError {
/// Details about the error.
#[error("{0}")]
Msg(String),
/// Converting the ONNX error message to UTF-8 failed.
#[error("an error occurred, but ort failed to convert the error message to UTF-8")]
IntoStringError(std::ffi::IntoStringError)
}
/// Error from downloading pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models).
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtDownloadError {
/// Generic input/output error
#[error("Error reading file: {0}")]
IoError(#[from] io::Error),
/// Download error by ureq
#[cfg(feature = "fetch-models")]
#[error("Error downloading to file: {0}")]
FetchError(#[from] Box<ureq::Error>),
/// Error getting Content-Length from HTTP GET request.
#[error("Error getting Content-Length from HTTP GET")]
ContentLengthError,
/// Mismatch between amount of downloaded and expected bytes.
#[error("Error copying data to file: expected {expected} length, but got {io}")]
CopyError {
/// Expected amount of bytes to download
expected: u64,
/// Number of bytes read from network and written to file
io: u64
}
}
/// Wrapper type around ONNX's `OrtStatus` pointer.
///
/// This wrapper exists to facilitate conversion from C raw pointers to Rust error types.
pub struct OrtStatusWrapper(*mut sys::OrtStatus);
impl From<*mut sys::OrtStatus> for OrtStatusWrapper {
fn from(status: *mut sys::OrtStatus) -> Self {
OrtStatusWrapper(status)
}
}
pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &str) -> OrtResult<()> {
ptr.is_null().then_some(()).ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
}
pub(crate) fn assert_non_null_pointer<T>(ptr: *const T, name: &str) -> OrtResult<()> {
(!ptr.is_null())
.then_some(())
.ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
}
impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
fn from(status: OrtStatusWrapper) -> Self {
if status.0.is_null() {
Ok(())
} else {
let raw: *const std::os::raw::c_char = unsafe { ort().GetErrorMessage.unwrap()(status.0) };
match char_p_to_string(raw) {
Ok(msg) => Err(OrtApiError::Msg(msg)),
Err(err) => match err {
OrtError::FfiStringConversion(OrtApiError::IntoStringError(e)) => Err(OrtApiError::IntoStringError(e)),
_ => unreachable!()
}
}
}
}
}
impl Drop for OrtStatusWrapper {
fn drop(&mut self) {
unsafe { ort().ReleaseStatus.unwrap()(self.0) }
}
}
pub(crate) fn status_to_result(status: *mut sys::OrtStatus) -> std::result::Result<(), OrtApiError> {
let status_wrapper: OrtStatusWrapper = status.into();
status_wrapper.into()
}

179
src/execution_providers.rs Normal file
View File

@@ -0,0 +1,179 @@
#![allow(unused_imports)]
use std::{collections::HashMap, ffi::CString, os::raw::c_char};
use super::{error::status_to_result, ortsys, sys};
extern "C" {
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_CPU(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr;
#[cfg(feature = "acl")]
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr;
}
#[derive(Debug, Clone)]
pub struct ExecutionProvider {
provider: String,
options: HashMap<String, String>
}
macro_rules! ep_providers {
($($fn_name:ident = $name:expr),*) => {
$(
/// Creates a new `
#[doc = $name]
#[doc = "` configuration object."]
pub fn $fn_name() -> Self {
Self::new($name)
}
)*
}
}
macro_rules! ep_if_available {
($($fn_name:ident($original:ident): $name:expr),*) => {
$(
/// Creates a new
#[doc = $name]
#[doc = " execution provider if available, otherwise falling back to CPU."]
pub fn $fn_name() -> Self {
let o = Self::$original();
if o.is_available() { o } else { Self::cpu() }
}
)*
}
}
macro_rules! ep_options {
($(
$(#[$meta:meta])*
pub fn $fn_name:ident($opt_type:ty) = $option_name:ident;
)*) => {
$(
$(#[$meta])*
pub fn $fn_name(mut self, v: $opt_type) -> Self {
self = self.with(stringify!($option_name), v.to_string());
self
}
)*
}
}
impl ExecutionProvider {
pub fn new(provider: impl Into<String>) -> Self {
Self {
provider: provider.into(),
options: HashMap::new()
}
}
ep_providers! {
acl = "AclExecutionProvider",
cuda = "CUDAExecutionProvider",
tensorrt = "TensorRTExecutionProvider",
cpu = "CPUExecutionProvider"
}
pub fn is_available(&self) -> bool {
let mut providers: *mut *mut c_char = std::ptr::null_mut();
let mut num_providers = 0;
if status_to_result(ortsys![unsafe GetAvailableProviders(&mut providers, &mut num_providers)]).is_err() {
return false;
}
for i in 0..num_providers {
let avail = unsafe { std::ffi::CStr::from_ptr(*providers.offset(i as isize)) }
.to_string_lossy()
.into_owned();
if self.provider == avail {
return true;
}
}
false
}
ep_if_available! {
tensorrt_if_available(tensorrt): "TensorRT",
cuda_if_available(cuda): "CUDA",
acl_if_available(acl): "ACL"
}
/// Configure this execution provider with the given option.
pub fn with(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
self.options.insert(k.into(), v.into());
self
}
ep_options! {
/// Whether or not to use CPU arena allocator.
pub fn with_use_arena(bool) = use_arena;
}
}
pub(crate) fn apply_execution_providers(options: *mut sys::OrtSessionOptions, execution_providers: impl AsRef<[ExecutionProvider]>) {
for ep in execution_providers.as_ref() {
let init_args = ep.options.clone();
match ep.provider.as_str() {
#[cfg(feature = "acl")]
"AclExecutionProvider" => {
let use_arena = init_args.get("use_arena").map(|s| s.parse::<bool>().unwrap_or(false)).unwrap_or(false);
let status = unsafe { OrtSessionOptionsAppendExecutionProvider_ACL(options, use_arena.into()) };
if status_to_result(status).is_ok() {
return; // EP found
}
}
"CPUExecutionProvider" => {
let use_arena = init_args.get("use_arena").map(|s| s.parse::<bool>().unwrap_or(false)).unwrap_or(false);
let status = unsafe { OrtSessionOptionsAppendExecutionProvider_CPU(options, use_arena.into()) };
if status_to_result(status).is_ok() {
return; // EP found
}
}
#[cfg(feature = "cuda")]
"CUDAExecutionProvider" => {
let mut cuda_options: *mut sys::OrtCUDAProviderOptionsV2 = std::ptr::null_mut();
if status_to_result(ortsys![unsafe CreateCUDAProviderOptions(&mut cuda_options)]).is_err() {
continue; // next EP
}
let keys: Vec<CString> = init_args.keys().map(|k| CString::new(k.as_str()).unwrap()).collect();
let values: Vec<CString> = init_args.values().map(|v| CString::new(v.as_str()).unwrap()).collect();
assert_eq!(keys.len(), values.len()); // sanity check
let key_ptrs: Vec<*const c_char> = keys.iter().map(|k| k.as_ptr()).collect();
let value_ptrs: Vec<*const c_char> = values.iter().map(|v| v.as_ptr()).collect();
let status = ortsys![unsafe UpdateCUDAProviderOptions(cuda_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), keys.len())];
if status_to_result(status).is_err() {
ortsys![unsafe ReleaseCUDAProviderOptions(cuda_options)];
continue; // next EP
}
let status = ortsys![unsafe SessionOptionsAppendExecutionProvider_CUDA_V2(options, cuda_options)];
ortsys![unsafe ReleaseCUDAProviderOptions(cuda_options)];
if status_to_result(status).is_ok() {
return; // EP found
}
}
#[cfg(feature = "tensorrt")]
"TensorRTExecutionProvider" => {
let mut tensorrt_options: *mut sys::OrtTensorRTProviderOptionsV2 = std::ptr::null_mut();
if status_to_result(ortsys![unsafe CreateTensorRTProviderOptions(&mut tensorrt_options)]).is_err() {
continue; // next EP
}
let keys: Vec<CString> = init_args.keys().map(|k| CString::new(k.as_str()).unwrap()).collect();
let values: Vec<CString> = init_args.values().map(|v| CString::new(v.as_str()).unwrap()).collect();
assert_eq!(keys.len(), values.len()); // sanity check
let key_ptrs: Vec<*const c_char> = keys.iter().map(|k| k.as_ptr()).collect();
let value_ptrs: Vec<*const c_char> = values.iter().map(|v| v.as_ptr()).collect();
let status = ortsys![unsafe UpdateTensorRTProviderOptions(tensorrt_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), keys.len())];
if status_to_result(status).is_err() {
ortsys![unsafe ReleaseTensorRTProviderOptions(tensorrt_options)];
continue; // next EP
}
let status = ortsys![unsafe SessionOptionsAppendExecutionProvider_TensorRT_V2(options, tensorrt_options)];
ortsys![unsafe ReleaseTensorRTProviderOptions(tensorrt_options)];
if status_to_result(status).is_ok() {
return; // EP found
}
}
_ => {}
};
}
}

348
src/lib.rs Normal file
View File

@@ -0,0 +1,348 @@
pub mod download;
pub mod environment;
pub mod error;
pub mod execution_providers;
pub mod memory;
pub mod metadata;
pub mod session;
pub mod sys;
pub mod tensor;
use std::{
ffi::{self, CStr},
os::raw::c_char,
ptr,
sync::{atomic::AtomicPtr, Arc, Mutex}
};
pub use environment::Environment;
pub use error::{OrtApiError, OrtError, OrtResult};
pub use execution_providers::ExecutionProvider;
use lazy_static::lazy_static;
pub use session::{Session, SessionBuilder};
use self::sys::OnnxEnumInt;
macro_rules! extern_system_fn {
($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "C" fn $($tt)*);
($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "C" fn $($tt)*);
($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "C" fn $($tt)*);
}
pub(crate) use extern_system_fn;
lazy_static! {
pub(crate) static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
assert_ne!(base, ptr::null());
let get_api: extern_system_fn! { unsafe fn(u32) -> *const sys::OrtApi } = unsafe { (*base).GetApi.unwrap() };
let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
};
}
pub fn ort() -> sys::OrtApi {
let mut api_ref = G_ORT_API.lock().expect("failed to acquire OrtApi lock; another thread panicked?");
let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut;
assert_ne!(api_ptr_mut, ptr::null_mut());
unsafe { *api_ptr_mut }
}
macro_rules! ortsys {
($method:tt) => {
$crate::ort().$method.unwrap()
};
(unsafe $method:tt) => {
unsafe { $crate::ort().$method.unwrap() }
};
($method:tt($($n:expr),+ $(,)?)) => {
$crate::ort().$method.unwrap()($($n),+)
};
(unsafe $method:tt($($n:expr),+ $(,)?)) => {
unsafe { $crate::ort().$method.unwrap()($($n),+) }
};
($method:tt($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::ort().$method.unwrap()($($n),+);
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:tt($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
unsafe { $crate::ort().$method.unwrap()($($n),+) };
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
($method:tt($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
$crate::error::status_to_result($crate::ort().$method.unwrap()($($n),+)).map_err($err)?;
};
(unsafe $method:tt($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
$crate::error::status_to_result(unsafe { $crate::ort().$method.unwrap()($($n),+) }).map_err($err)?;
};
($method:tt($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::error::status_to_result($crate::ort().$method.unwrap()($($n),+)).map_err($err)?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:tt($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::error::status_to_result(unsafe { $crate::ort().$method.unwrap()($($n),+) }).map_err($err)?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
}
macro_rules! ortfree {
(unsafe $allocator_ptr:expr, $ptr:tt) => {
unsafe { (*$allocator_ptr).Free.unwrap()($allocator_ptr, $ptr as *mut std::ffi::c_void) }
};
($allocator_ptr:expr, $ptr:tt) => {
(*$allocator_ptr).Free.unwrap()($allocator_ptr, $ptr as *mut std::ffi::c_void)
};
}
pub(crate) use ortfree;
pub(crate) use ortsys;
pub(crate) fn char_p_to_string(raw: *const c_char) -> OrtResult<String> {
let c_string = unsafe { CStr::from_ptr(raw as *mut c_char).to_owned() };
match c_string.into_string() {
Ok(string) => Ok(string),
Err(e) => Err(OrtApiError::IntoStringError(e))
}
.map_err(OrtError::FfiStringConversion)
}
/// ONNX's logger sends the code location where the log occurred, which will be parsed into this struct.
#[derive(Debug)]
struct CodeLocation<'a> {
file: &'a str,
line: &'a str,
function: &'a str
}
impl<'a> From<&'a str> for CodeLocation<'a> {
fn from(code_location: &'a str) -> Self {
let mut splitter = code_location.split(' ');
let file_and_line = splitter.next().unwrap_or("<unknown file>:<unknown line>");
let function = splitter.next().unwrap_or("<unknown function>");
let mut file_and_line_splitter = file_and_line.split(':');
let file = file_and_line_splitter.next().unwrap_or("<unknown file>");
let line = file_and_line_splitter.next().unwrap_or("<unknown line>");
CodeLocation { file, line, function }
}
}
extern_system_fn! {
/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: sys::OrtLoggingLevel, category: *const c_char, log_id: *const c_char, code_location: *const c_char, message: *const c_char) {
use tracing::{span, Level, trace, debug, warn, info, error};
let log_level = match severity {
sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO => Level::DEBUG,
sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING => Level::INFO,
sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR => Level::WARN,
sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL => Level::ERROR,
_ => Level::TRACE
};
assert_ne!(category, ptr::null());
let category = unsafe { CStr::from_ptr(category) };
assert_ne!(code_location, ptr::null());
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("unknown");
assert_ne!(message, ptr::null());
let message = unsafe { CStr::from_ptr(message) };
assert_ne!(log_id, ptr::null());
let log_id = unsafe { CStr::from_ptr(log_id) };
let code_location = CodeLocation::from(code_location);
let span = span!(
Level::TRACE,
"ort",
category = category.to_str().unwrap_or("<unknown>"),
file = code_location.file,
line = code_location.line,
function = code_location.function,
log_id = log_id.to_str().unwrap_or("<unknown>")
);
let _enter = span.enter();
match log_level {
Level::TRACE => trace!("{:?}", message),
Level::DEBUG => debug!("{:?}", message),
Level::INFO => info!("{:?}", message),
Level::WARN => warn!("{:?}", message),
Level::ERROR => error!("{:?}", message)
}
}
}
/// ONNX Runtime logging level.
#[derive(Debug)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum LoggingLevel {
Verbose = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE as OnnxEnumInt,
Info = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO as OnnxEnumInt,
Warning = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING as OnnxEnumInt,
Error = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR as OnnxEnumInt,
Fatal = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL as OnnxEnumInt
}
impl From<LoggingLevel> for sys::OrtLoggingLevel {
fn from(logging_level: LoggingLevel) -> Self {
match logging_level {
LoggingLevel::Verbose => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE,
LoggingLevel::Info => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO,
LoggingLevel::Warning => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING,
LoggingLevel::Error => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR,
LoggingLevel::Fatal => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL
}
}
}
/// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially
/// graph-level transformations, ranging from small graph simplifications and node eliminations to more complex node
/// fusions and layout optimizations.
///
/// Graph optimizations are divided in several categories (or levels) based on their complexity and functionality. They
/// can be performed either online or offline. In online mode, the optimizations are done before performing the
/// inference, while in offline mode, the runtime saves the optimized graph to disk (most commonly used when converting
/// an ONNX model to an ONNX Runtime model).
///
/// The optimizations belonging to one level are performed after the optimizations of the previous level have been
/// applied (e.g., extended optimizations are applied after basic optimizations have been applied).
///
/// **All optimizations are enabled by default.**
///
/// # Online/offline mode
/// All optimizations can be performed either online or offline. In online mode, when initializing an inference session,
/// we also apply all enabled graph optimizations before performing model inference. Applying all optimizations each
/// time we initiate a session can add overhead to the model startup time (especially for complex models), which can be
/// critical in production scenarios. This is where the offline mode can bring a lot of benefit. In offline mode, after
/// performing graph optimizations, ONNX Runtime serializes the resulting model to disk. Subsequently, we can reduce
/// startup time by using the already optimized model and disabling all optimizations.
///
/// ## Notes:
/// - When running in offline mode, make sure to use the exact same options (e.g., execution providers, optimization
/// level) and hardware as the target machine that the model inference will run on (e.g., you cannot run a model
/// pre-optimized for a GPU execution provider on a machine that is equipped only with CPU).
/// - When layout optimizations are enabled, the offline mode can only be used on compatible hardware to the environment
/// when the offline model is saved. For example, if model has layout optimized for AVX2, the offline model would
/// require CPUs that support AVX2.
#[derive(Debug)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum GraphOptimizationLevel {
Disable = sys::GraphOptimizationLevel_ORT_DISABLE_ALL as OnnxEnumInt,
/// Level 1 includes semantics-preserving graph rewrites which remove redundant nodes and redundant computation.
/// They run before graph partitioning and thus apply to all the execution providers. Available basic/level 1 graph
/// optimizations are as follows:
///
/// - Constant Folding: Statically computes parts of the graph that rely only on constant initializers. This
/// eliminates the need to compute them during runtime.
/// - Redundant node eliminations: Remove all redundant nodes without changing the graph structure. The following
/// such optimizations are currently supported:
/// * Identity Elimination
/// * Slice Elimination
/// * Unsqueeze Elimination
/// * Dropout Elimination
/// - Semantics-preserving node fusions : Fuse/fold multiple nodes into a single node. For example, Conv Add fusion
/// folds the Add operator as the bias of the Conv operator. The following such optimizations are currently
/// supported:
/// * Conv Add Fusion
/// * Conv Mul Fusion
/// * Conv BatchNorm Fusion
/// * Relu Clip Fusion
/// * Reshape Fusion
Level1 = sys::GraphOptimizationLevel_ORT_ENABLE_BASIC as OnnxEnumInt,
#[rustfmt::skip]
/// Level 2 optimizations include complex node fusions. They are run after graph partitioning and are only applied to
/// the nodes assigned to the CPU or CUDA execution provider. Available extended/level 2 graph optimizations are as follows:
///
/// | Optimization | EPs | Comments |
/// |:------------------------------- |:--------- |:------------------------------------------------------------------------------ |
/// | GEMM Activation Fusion | CPU | |
/// | Matmul Add Fusion | CPU | |
/// | Conv Activation Fusion | CPU | |
/// | GELU Fusion | CPU, CUDA | |
/// | Layer Normalization Fusion | CPU, CUDA | |
/// | BERT Embedding Layer Fusion | CPU, CUDA | Fuses BERT embedding layers, layer normalization, & attention mask length |
/// | Attention Fusion* | CPU, CUDA | |
/// | Skip Layer Normalization Fusion | CPU, CUDA | Fuse bias of fully connected layers, skip connections, and layer normalization |
/// | Bias GELU Fusion | CPU, CUDA | Fuse bias of fully connected layers & GELU activation |
/// | GELU Approximation* | CUDA | Disabled by default; enable with `OrtSessionOptions::EnableGeluApproximation` |
///
/// > **NOTE**: To optimize performance of the BERT model, approximation is used in GELU Approximation and Attention
/// Fusion for the CUDA execution provider. The impact on accuracy is negligible based on our evaluation; F1 score
/// for a BERT model on SQuAD v1.1 is almost the same (87.05 vs 87.03).
Level2 = sys::GraphOptimizationLevel_ORT_ENABLE_EXTENDED as OnnxEnumInt,
/// Level 3 optimizations include memory layout optimizations, which may optimize the graph to use the NCHWc memory
/// layout rather than NCHW to improve spatial locality for some targets.
Level3 = sys::GraphOptimizationLevel_ORT_ENABLE_ALL as OnnxEnumInt
}
impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
fn from(val: GraphOptimizationLevel) -> Self {
match val {
GraphOptimizationLevel::Disable => sys::GraphOptimizationLevel_ORT_DISABLE_ALL,
GraphOptimizationLevel::Level1 => sys::GraphOptimizationLevel_ORT_ENABLE_BASIC,
GraphOptimizationLevel::Level2 => sys::GraphOptimizationLevel_ORT_ENABLE_EXTENDED,
GraphOptimizationLevel::Level3 => sys::GraphOptimizationLevel_ORT_ENABLE_ALL
}
}
}
/// Allocator type
#[derive(Debug, Clone)]
#[repr(i32)]
pub enum AllocatorType {
/// Device allocator
Device = sys::OrtAllocatorType_OrtDeviceAllocator,
/// Arena allocator
Arena = sys::OrtAllocatorType_OrtArenaAllocator
}
impl From<AllocatorType> for sys::OrtAllocatorType {
fn from(val: AllocatorType) -> Self {
match val {
AllocatorType::Device => sys::OrtAllocatorType_OrtDeviceAllocator,
AllocatorType::Arena => sys::OrtAllocatorType_OrtArenaAllocator
}
}
}
/// Memory type
#[derive(Debug, Clone)]
#[repr(i32)]
pub enum MemType {
CPUInput = sys::OrtMemType_OrtMemTypeCPUInput,
CPUOutput = sys::OrtMemType_OrtMemTypeCPUOutput,
/// Default memory type
Default = sys::OrtMemType_OrtMemTypeDefault
}
impl MemType {
pub const CPU: MemType = MemType::CPUOutput;
}
impl From<MemType> for sys::OrtMemType {
fn from(val: MemType) -> Self {
match val {
MemType::CPUInput => sys::OrtMemType_OrtMemTypeCPUInput,
MemType::CPUOutput => sys::OrtMemType_OrtMemTypeCPUOutput,
MemType::Default => sys::OrtMemType_OrtMemTypeDefault
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_char_p_to_string() {
let s = ffi::CString::new("foo").unwrap();
let ptr = s.as_c_str().as_ptr();
assert_eq!("foo", char_p_to_string(ptr).unwrap());
}
}

49
src/memory.rs Normal file
View File

@@ -0,0 +1,49 @@
use tracing::{debug, error};
use super::{error::OrtResult, ortsys, sys, AllocatorType, MemType};
#[derive(Debug)]
pub(crate) struct MemoryInfo {
pub ptr: *mut sys::OrtMemoryInfo
}
impl MemoryInfo {
#[tracing::instrument]
pub fn new(allocator: AllocatorType, memory_type: MemType) -> OrtResult<Self> {
debug!("Creating new OrtMemoryInfo.");
let mut memory_info_ptr: *mut sys::OrtMemoryInfo = std::ptr::null_mut();
ortsys![
unsafe CreateCpuMemoryInfo(allocator.into(), memory_type.into(), &mut memory_info_ptr);
nonNull(memory_info_ptr)
];
Ok(Self { ptr: memory_info_ptr })
}
}
impl Drop for MemoryInfo {
#[tracing::instrument]
fn drop(&mut self) {
if self.ptr.is_null() {
error!("OrtMemoryInfo pointer is null, not dropping.");
} else {
debug!("Dropping OrtMemoryInfo");
ortsys![unsafe ReleaseMemoryInfo(self.ptr)];
}
self.ptr = std::ptr::null_mut();
}
}
#[cfg(test)]
mod tests {
use test_log::test;
use super::*;
#[test]
fn create_memory_info() {
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap();
std::mem::drop(memory_info);
}
}

70
src/metadata.rs Normal file
View File

@@ -0,0 +1,70 @@
#![allow(clippy::tabs_in_doc_comments)]
use std::{ffi::CString, os::raw::c_char};
use super::{char_p_to_string, error::OrtResult, ortfree, ortsys, sys, OrtError};
pub struct Metadata {
metadata_ptr: *mut sys::OrtModelMetadata,
allocator_ptr: *mut sys::OrtAllocator
}
impl Metadata {
pub(crate) fn new(metadata_ptr: *mut sys::OrtModelMetadata, allocator_ptr: *mut sys::OrtAllocator) -> Self {
Metadata { metadata_ptr, allocator_ptr }
}
pub fn description(&self) -> OrtResult<String> {
let mut str_bytes: *mut c_char = std::ptr::null_mut();
ortsys![unsafe ModelMetadataGetDescription(self.metadata_ptr, self.allocator_ptr, &mut str_bytes) -> OrtError::GetModelMetadata; nonNull(str_bytes)];
let value = char_p_to_string(str_bytes)?;
ortfree!(unsafe self.allocator_ptr, str_bytes);
Ok(value)
}
pub fn producer(&self) -> OrtResult<String> {
let mut str_bytes: *mut c_char = std::ptr::null_mut();
ortsys![unsafe ModelMetadataGetProducerName(self.metadata_ptr, self.allocator_ptr, &mut str_bytes) -> OrtError::GetModelMetadata; nonNull(str_bytes)];
let value = char_p_to_string(str_bytes)?;
ortfree!(unsafe self.allocator_ptr, str_bytes);
Ok(value)
}
pub fn name(&self) -> OrtResult<String> {
let mut str_bytes: *mut c_char = std::ptr::null_mut();
ortsys![unsafe ModelMetadataGetGraphName(self.metadata_ptr, self.allocator_ptr, &mut str_bytes) -> OrtError::GetModelMetadata; nonNull(str_bytes)];
let value = char_p_to_string(str_bytes)?;
ortfree!(unsafe self.allocator_ptr, str_bytes);
Ok(value)
}
pub fn version(&self) -> OrtResult<i64> {
let mut ver = 0i64;
ortsys![unsafe ModelMetadataGetVersion(self.metadata_ptr, &mut ver) -> OrtError::GetModelMetadata];
Ok(ver)
}
pub fn custom(&self, key: &str) -> OrtResult<Option<String>> {
let mut str_bytes: *mut c_char = std::ptr::null_mut();
let key_str = CString::new(key)?;
ortsys![unsafe ModelMetadataLookupCustomMetadataMap(self.metadata_ptr, self.allocator_ptr, key_str.as_ptr(), &mut str_bytes) -> OrtError::GetModelMetadata];
if !str_bytes.is_null() {
unsafe {
let value = char_p_to_string(str_bytes)?;
ortfree!(self.allocator_ptr, str_bytes);
Ok(Some(value))
}
} else {
Ok(None)
}
}
}
impl Drop for Metadata {
fn drop(&mut self) {
ortsys![unsafe ReleaseModelMetadata(self.metadata_ptr)];
}
}

838
src/session.rs Normal file
View File

@@ -0,0 +1,838 @@
#![allow(clippy::tabs_in_doc_comments)]
#[cfg(feature = "fetch-models")]
use std::env;
#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
#[cfg(target_family = "windows")]
use std::os::windows::ffi::OsStrExt;
use std::{
ffi::CString,
fmt::{self, Debug},
os::raw::c_char,
path::Path,
sync::Arc
};
use ndarray::IxDyn;
use tracing::{debug, error};
use super::{
char_p_to_string,
environment::Environment,
error::{assert_non_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError, OrtApiError, OrtError, OrtResult},
execution_providers::{apply_execution_providers, ExecutionProvider},
extern_system_fn,
memory::MemoryInfo,
metadata::Metadata,
ort, ortsys, sys,
tensor::{
type_dynamic_tensor::{InputOrtTensor, InputTensor},
DynOrtTensor, TensorElementDataType
},
AllocatorType, GraphOptimizationLevel, MemType
};
#[cfg(feature = "fetch-models")]
use super::{download::OnnxModel, error::OrtDownloadError};
/// Type used to create a session using the _builder pattern_.
///
/// A `SessionBuilder` is created by calling the [`Environment::session()`] method on the environment.
///
/// Once created, you can use the different methods to configure the session.
///
/// Once configured, use the [`SessionBuilder::with_model_from_file()`] method to "commit" the builder configuration
/// into a [`Session`].
///
/// # Example
///
/// ```no_run
/// # use std::{error::Error, sync::Arc};
/// # use ort::{Environment, LoggingLevel, GraphOptimizationLevel, SessionBuilder};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let environment = Arc::new(Environment::builder().with_name("test").with_log_level(LoggingLevel::Verbose).build()?);
/// let mut session = SessionBuilder::new(&environment)?
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
/// .with_intra_threads(1)?
/// .with_model_from_file("squeezenet.onnx")?;
/// # Ok(())
/// # }
/// ```
pub struct SessionBuilder {
env: Arc<Environment>,
session_options_ptr: *mut sys::OrtSessionOptions,
allocator: AllocatorType,
memory_type: MemType,
custom_runtime_handles: Vec<*mut std::os::raw::c_void>,
execution_providers: Vec<ExecutionProvider>
}
impl Debug for SessionBuilder {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
f.debug_struct("SessionBuilder")
.field("env", &self.env.name())
.field("allocator", &self.allocator)
.field("memory_type", &self.memory_type)
.finish()
}
}
impl Drop for SessionBuilder {
#[tracing::instrument]
fn drop(&mut self) {
for &handle in self.custom_runtime_handles.iter() {
close_lib_handle(handle);
}
if self.session_options_ptr.is_null() {
error!("Session options pointer is null, not dropping");
} else {
debug!("Dropping the session options.");
ortsys![unsafe ReleaseSessionOptions(self.session_options_ptr)];
}
}
}
impl SessionBuilder {
pub fn new(env: &Arc<Environment>) -> OrtResult<Self> {
let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut();
ortsys![unsafe CreateSessionOptions(&mut session_options_ptr) -> OrtError::CreateSessionOptions; nonNull(session_options_ptr)];
Ok(Self {
env: Arc::clone(env),
session_options_ptr,
allocator: AllocatorType::Arena,
memory_type: MemType::Default,
custom_runtime_handles: Vec::new(),
execution_providers: Vec::new()
})
}
/// Configures a list of execution providers to attempt to use for the session.
///
/// Execution providers are loaded in the order they are provided until a suitable execution provider is found. Most
/// execution providers will silently fail if they are unavailable or misconfigured (see notes below), however, some
/// may log to the console, which is sadly unavoidable. The CPU execution provider is always available, so always
/// put it last in the list (though it is not required).
///
/// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built
/// with support for the corresponding execution provider. Execution providers that do not have their corresponding
/// feature enabled are currently ignored.
///
/// Execution provider options can be specified in the second argument. Refer to ONNX Runtime's
/// [execution provider docs](https://onnxruntime.ai/docs/execution-providers/) for configuration options. In most
/// cases, passing `None` to configure with no options is suitable.
///
/// It is recommended to enable the `cuda` EP for x86 platforms and the `acl` EP for ARM platforms for the best
/// performance, though this does mean you'll have to build ONNX Runtime for these targets. Microsoft's prebuilt
/// binaries are built with CUDA and TensorRT support, if you built `ort` with the `cuda` or `tensorrt` features
/// enabled.
///
/// Supported execution providers:
/// - `cpu`: Default CPU/MLAS execution provider. Available on all platforms.
/// - `acl`: Arm Compute Library
/// - `cuda`: NVIDIA CUDA/cuDNN
/// - `tensorrt`: NVIDIA TensorRT
///
/// ## Notes
///
/// - **Use of [`SessionBuilder::with_execution_providers`] in a library is discouraged.** Execution providers
/// should always be configurable by the user, in case an execution provider is misconfigured and causes the
/// application to crash (see notes below). Instead, your library should accept an [`Environment`] from the user
/// rather than creating its own. This way, the user can configure execution providers for **all** modules that
/// use it.
/// - Using the CUDA/TensorRT execution providers **can terminate the process if the CUDA/TensorRT installation is
/// misconfigured**. Configuring the execution provider will seem to work, but when you attempt to run a session,
/// it will hard crash the process with a "stack buffer overrun" error. This can occur when CUDA/TensorRT is
/// missing a DLL such as `zlibwapi.dll`. To prevent your app from crashing, you can check to see if you can load
/// `zlibwapi.dll` before enabling the CUDA/TensorRT execution providers.
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProvider]>) -> OrtResult<Self> {
self.execution_providers = execution_providers.as_ref().to_vec();
Ok(self)
}
/// Configure the session to use a number of threads to parallelize the execution within nodes. If ONNX Runtime was
/// built with OpenMP (as is the case with Microsoft's prebuilt binaries), this will have no effect on the number of
/// threads used. Instead, you can configure the number of threads OpenMP uses via the `OMP_NUM_THREADS` environment
/// variable.
///
/// For configuring the number of threads used when the session execution mode is set to `Parallel`, see
/// [`SessionBuilder::with_inter_threads()`].
pub fn with_intra_threads(self, num_threads: i16) -> OrtResult<Self> {
// We use a u16 in the builder to cover the 16-bits positive values of a i32.
let num_threads = num_threads as i32;
ortsys![unsafe SetIntraOpNumThreads(self.session_options_ptr, num_threads) -> OrtError::CreateSessionOptions];
Ok(self)
}
/// Configure the session to use a number of threads to parallelize the execution of the graph. If nodes can be run
/// in parallel, this sets the maximum number of threads to use to run them in parallel.
///
/// This has no effect when the session execution mode is set to `Sequential`.
///
/// For configuring the number of threads used to parallelize the execution within nodes, see
/// [`SessionBuilder::with_intra_threads()`].
pub fn with_inter_threads(self, num_threads: i16) -> OrtResult<Self> {
// We use a u16 in the builder to cover the 16-bits positive values of a i32.
let num_threads = num_threads as i32;
ortsys![unsafe SetInterOpNumThreads(self.session_options_ptr, num_threads) -> OrtError::CreateSessionOptions];
Ok(self)
}
/// Enable/disable the parallel execution mode for this session. By default, this is disabled.
///
/// Parallel execution can improve performance for models with many branches, at the cost of higher memory usage.
/// You can configure the amount of threads used to parallelize the execution of the graph via
/// [`SessionBuilder::with_inter_threads()`].
pub fn with_parallel_execution(self, parallel_execution: bool) -> OrtResult<Self> {
let execution_mode = if parallel_execution {
sys::ExecutionMode_ORT_PARALLEL
} else {
sys::ExecutionMode_ORT_SEQUENTIAL
};
ortsys![unsafe SetSessionExecutionMode(self.session_options_ptr, execution_mode) -> OrtError::CreateSessionOptions];
Ok(self)
}
/// Set the session's optimization level. See [`GraphOptimizationLevel`] for more information on the different
/// optimization levels.
pub fn with_optimization_level(self, opt_level: GraphOptimizationLevel) -> OrtResult<Self> {
ortsys![unsafe SetSessionGraphOptimizationLevel(self.session_options_ptr, opt_level.into()) -> OrtError::CreateSessionOptions];
Ok(self)
}
/// Set the session's allocator. Defaults to [`AllocatorType::Arena`].
pub fn with_allocator(mut self, allocator: AllocatorType) -> OrtResult<Self> {
self.allocator = allocator;
Ok(self)
}
/// Set the session's memory type. Defaults to [`MemType::Default`].
pub fn with_memory_type(mut self, memory_type: MemType) -> OrtResult<Self> {
self.memory_type = memory_type;
Ok(self)
}
/// Registers a custom operator library with the given library path in the session.
pub fn with_custom_op_lib(mut self, lib_path: &str) -> OrtResult<Self> {
let path_cstr = CString::new(lib_path)?;
let mut handle: *mut ::std::os::raw::c_void = std::ptr::null_mut();
let status = ortsys![unsafe RegisterCustomOpsLibrary(self.session_options_ptr, path_cstr.as_ptr(), &mut handle)];
// per RegisterCustomOpsLibrary docs, release handle if there was an error and the handle
// is non-null
match status_to_result(status).map_err(OrtError::CreateSessionOptions) {
Ok(_) => {}
Err(e) => {
if !handle.is_null() {
// handle was written to, should release it
close_lib_handle(handle);
}
return Err(e);
}
}
self.custom_runtime_handles.push(handle);
Ok(self)
}
/// Downloads a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models) and builds the session.
#[cfg(feature = "fetch-models")]
pub fn with_model_downloaded<M>(self, model: M) -> OrtResult<Session>
where
M: Into<OnnxModel>
{
self.with_model_downloaded_monomorphized(model.into())
}
#[cfg(feature = "fetch-models")]
fn with_model_downloaded_monomorphized(self, model: OnnxModel) -> OrtResult<Session> {
let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?;
let downloaded_path = model.download_to(download_dir)?;
self.with_model_from_file(downloaded_path)
}
// TODO: Add all functions changing the options.
// See all OrtApi methods taking a `options: *mut OrtSessionOptions`.
/// Loads an ONNX model from a file and builds the session.
pub fn with_model_from_file<P>(self, model_filepath_ref: P) -> OrtResult<Session>
where
P: AsRef<Path>
{
let model_filepath = model_filepath_ref.as_ref();
if !model_filepath.exists() {
return Err(OrtError::FileDoesNotExist {
filename: model_filepath.to_path_buf()
});
}
// Build an OsString, then a vector of bytes to pass to C
let model_path = std::ffi::OsString::from(model_filepath);
#[cfg(target_family = "windows")]
let model_path: Vec<u16> = model_path
.encode_wide()
.chain(std::iter::once(0)) // Make sure we have a null terminated string
.collect();
#[cfg(not(target_family = "windows"))]
let model_path: Vec<std::os::raw::c_char> = model_path
.as_bytes()
.iter()
.chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
.map(|b| *b as std::os::raw::c_char)
.collect();
apply_execution_providers(
self.session_options_ptr,
self.execution_providers
.iter()
.chain(&self.env.execution_providers)
.cloned()
.collect::<Vec<_>>()
);
let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> OrtError::CreateSession; nonNull(session_ptr)];
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr) -> OrtError::GetAllocator; nonNull(allocator_ptr)];
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
// Extract input and output properties
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
.collect::<OrtResult<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
.collect::<OrtResult<Vec<Output>>>()?;
Ok(Session {
env: Arc::clone(&self.env),
session_ptr,
allocator_ptr,
memory_info,
inputs,
outputs
})
}
/// Load an ONNX graph from memory and commit the session
pub fn with_model_from_memory<B>(self, model_bytes: B) -> OrtResult<Session>
where
B: AsRef<[u8]>
{
self.with_model_from_memory_monomorphized(model_bytes.as_ref())
}
fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> OrtResult<Session> {
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
apply_execution_providers(
self.session_options_ptr,
self.execution_providers
.iter()
.chain(&self.env.execution_providers)
.cloned()
.collect::<Vec<_>>()
);
let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
let model_data_length = model_bytes.len();
ortsys![
unsafe CreateSessionFromArray(env_ptr, model_data, model_data_length as _, self.session_options_ptr, &mut session_ptr) -> OrtError::CreateSession;
nonNull(session_ptr)
];
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr) -> OrtError::GetAllocator; nonNull(allocator_ptr)];
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
// Extract input and output properties
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
.collect::<OrtResult<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
.collect::<OrtResult<Vec<Output>>>()?;
Ok(Session {
env: Arc::clone(&self.env),
session_ptr,
allocator_ptr,
memory_info,
inputs,
outputs
})
}
}
/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html)
#[derive(Debug)]
pub struct Session {
#[allow(dead_code)]
env: Arc<Environment>,
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
memory_info: MemoryInfo,
/// Information about the ONNX's inputs as stored in loaded file
pub inputs: Vec<Input>,
/// Information about the ONNX's outputs as stored in loaded file
pub outputs: Vec<Output>
}
/// Information about an ONNX's input as stored in loaded file
#[derive(Debug)]
pub struct Input {
/// Name of the input layer
pub name: String,
/// Type of the input layer's elements
pub input_type: TensorElementDataType,
/// Shape of the input layer
///
/// C API uses a i64 for the dimensions. We use an unsigned of the same range of the positive values.
pub dimensions: Vec<Option<u32>>
}
/// Information about an ONNX's output as stored in loaded file
#[derive(Debug)]
pub struct Output {
/// Name of the output layer
pub name: String,
/// Type of the output layer's elements
pub output_type: TensorElementDataType,
/// Shape of the output layer
///
/// C API uses a i64 for the dimensions. We use an unsigned of the same range of the positive values.
pub dimensions: Vec<Option<u32>>
}
impl Input {
/// Return an iterator over the shape elements of the input layer
///
/// Note: The member [`Input::dimensions`](struct.Input.html#structfield.dimensions)
/// stores `u32` (since ONNX uses `i64` but which cannot be negative) so the
/// iterator converts to `usize`.
pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
self.dimensions.iter().map(|d| d.map(|d2| d2 as usize))
}
}
impl Output {
/// Return an iterator over the shape elements of the output layer
///
/// Note: The member [`Output::dimensions`](struct.Output.html#structfield.dimensions)
/// stores `u32` (since ONNX uses `i64` but which cannot be negative) so the
/// iterator converts to `usize`.
pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
self.dimensions.iter().map(|d| d.map(|d2| d2 as usize))
}
}
impl Drop for Session {
#[tracing::instrument]
fn drop(&mut self) {
debug!("Dropping the session.");
if self.session_ptr.is_null() {
error!("Session pointer is null, not dropping.");
} else {
ortsys![unsafe ReleaseSession(self.session_ptr)];
}
self.session_ptr = std::ptr::null_mut();
self.allocator_ptr = std::ptr::null_mut();
}
}
impl Session {
/// Run the input data through the ONNX graph, performing inference.
///
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
/// used for the input data here.
pub fn run<'s, 'm>(&'s mut self, input_arrays: impl AsRef<[InputTensor]>) -> OrtResult<Vec<DynOrtTensor<'m, IxDyn>>>
where
's: 'm // 's outlives 'm (session outlives memory info)
{
let input_arrays = input_arrays.as_ref();
self.validate_input_shapes(input_arrays)?;
// Build arguments to Run()
let input_names_ptr: Vec<*const c_char> = self
.inputs
.iter()
.map(|input| input.name.clone())
.map(|n| CString::new(n).unwrap())
.map(|n| n.into_raw() as *const c_char)
.collect();
let output_names_cstring: Vec<CString> = self
.outputs
.iter()
.map(|output| output.name.clone())
.map(|n| CString::new(n).unwrap())
.collect();
let output_names_ptr: Vec<*const c_char> = output_names_cstring.iter().map(|n| n.as_ptr() as *const c_char).collect();
let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()];
// The C API expects pointers for the arrays (pointers to C-arrays)
let input_ort_tensors: Vec<InputOrtTensor> = input_arrays
.iter()
.map(|input_tensor| InputOrtTensor::from_input_tensor(&self.memory_info, self.allocator_ptr, input_tensor))
.collect::<OrtResult<Vec<InputOrtTensor>>>()?;
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors.iter().map(|input_array_ort| input_array_ort.c_ptr()).collect();
let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();
ortsys![
unsafe Run(
self.session_ptr,
run_options_ptr,
input_names_ptr.as_ptr(),
input_ort_values.as_ptr(),
input_ort_values.len() as _,
output_names_ptr.as_ptr(),
output_names_ptr.len() as _,
output_tensor_ptrs.as_mut_ptr()
) -> OrtError::SessionRun
];
let memory_info_ref = &self.memory_info;
let outputs: OrtResult<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> = output_tensor_ptrs
.into_iter()
.map(|tensor_ptr| {
let (dims, data_type, len) = unsafe {
call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
get_tensor_dimensions(tensor_info_ptr)
.map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>())
.and_then(|dims| extract_data_type(tensor_info_ptr).map(|data_type| (dims, data_type)))
.and_then(|(dims, data_type)| {
let mut len = 0;
ortsys![GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> OrtError::GetTensorShapeElementCount];
Ok((dims, data_type, len))
})
})
}?;
Ok(DynOrtTensor::new(tensor_ptr, memory_info_ref, ndarray::IxDyn(&dims), len as _, data_type))
})
.collect();
// Reconvert to CString so drop impl is called and memory is freed
let cstrings: OrtResult<Vec<CString>> = input_names_ptr
.into_iter()
.map(|p| {
assert_non_null_pointer(p, "c_char for CString")?;
unsafe { Ok(CString::from_raw(p as *mut c_char)) }
})
.collect();
cstrings?;
outputs
}
fn validate_input_shapes(&mut self, input_arrays: impl AsRef<[InputTensor]>) -> OrtResult<()> {
// ******************************************************************
// FIXME: Properly handle errors here
// Make sure all dimensions match (except dynamic ones)
let input_arrays = input_arrays.as_ref();
// Verify length of inputs
if input_arrays.len() != self.inputs.len() {
error!("Non-matching number of inputs: {} (inference) vs {} (model)", input_arrays.len(), self.inputs.len());
return Err(OrtError::NonMatchingDimensions(NonMatchingDimensionsError::InputsCount {
inference_input_count: 0,
model_input_count: 0,
inference_input: input_arrays.iter().map(|input_array| input_array.shape().to_vec()).collect(),
model_input: self.inputs.iter().map(|input| input.dimensions.clone()).collect()
}));
}
// Verify length of each individual inputs
let inputs_different_length = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| match l {
InputTensor::FloatTensor(input) => input.shape().len() != r.dimensions.len(),
#[cfg(feature = "half")]
InputTensor::Float16Tensor(input) => input.shape().len() != r.dimensions.len(),
#[cfg(feature = "half")]
InputTensor::Bfloat16Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Uint8Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Int8Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Uint16Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Int16Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Int32Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Int64Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::DoubleTensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Uint32Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Uint64Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::StringTensor(input) => input.shape().len() != r.dimensions.len()
});
if inputs_different_length {
error!("Different input lengths: {:?} vs {:?}", self.inputs, input_arrays);
return Err(OrtError::NonMatchingDimensions(NonMatchingDimensionsError::InputsLength {
inference_input: input_arrays.iter().map(|input_array| input_array.shape().to_vec()).collect(),
model_input: self.inputs.iter().map(|input| input.dimensions.clone()).collect()
}));
}
// Verify shape of each individual inputs
let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| {
let l_shape = l.shape();
let r_shape = r.dimensions.as_slice();
l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
Some(r3) => *r3 as usize != *l2,
None => false // None means dynamic size; in that case shape always match
})
});
if inputs_different_shape {
error!("Different input lengths: {:?} vs {:?}", self.inputs, input_arrays);
return Err(OrtError::NonMatchingDimensions(NonMatchingDimensionsError::InputsLength {
inference_input: input_arrays.iter().map(|input_array| input_array.shape().to_vec()).collect(),
model_input: self.inputs.iter().map(|input| input.dimensions.clone()).collect()
}));
}
Ok(())
}
pub fn metadata(&self) -> OrtResult<Metadata> {
let mut metadata_ptr: *mut sys::OrtModelMetadata = std::ptr::null_mut();
ortsys![unsafe SessionGetModelMetadata(self.session_ptr, &mut metadata_ptr) -> OrtError::GetModelMetadata; nonNull(metadata_ptr)];
Ok(Metadata::new(metadata_ptr, self.allocator_ptr))
}
}
unsafe fn get_tensor_dimensions(tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo) -> OrtResult<Vec<i64>> {
let mut num_dims = 0;
ortsys![GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> OrtError::GetDimensionsCount];
assert_ne!(num_dims, 0);
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
ortsys![GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims) -> OrtError::GetDimensions];
Ok(node_dims)
}
unsafe fn extract_data_type(tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo) -> OrtResult<TensorElementDataType> {
let mut type_sys = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![GetTensorElementType(tensor_info_ptr, &mut type_sys) -> OrtError::GetTensorElementType];
assert_ne!(type_sys, sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
// This transmute should be safe since its value is read from GetTensorElementType, which we must trust
Ok(std::mem::transmute(type_sys))
}
/// Calls the provided closure with the result of `GetTensorTypeAndShape`, deallocating the
/// resulting `*OrtTensorTypeAndShapeInfo` before returning.
unsafe fn call_with_tensor_info<F, T>(tensor_ptr: *const sys::OrtValue, mut f: F) -> OrtResult<T>
where
F: FnMut(*const sys::OrtTensorTypeAndShapeInfo) -> OrtResult<T>
{
let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![GetTensorTypeAndShape(tensor_ptr, &mut tensor_info_ptr) -> OrtError::GetTensorTypeAndShape];
let res = f(tensor_info_ptr);
ortsys![ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
res
}
#[cfg(unix)]
fn close_lib_handle(handle: *mut std::os::raw::c_void) {
unsafe { libc::dlclose(handle) };
}
#[cfg(windows)]
fn close_lib_handle(handle: *mut std::os::raw::c_void) {
unsafe { winapi::um::libloaderapi::FreeLibrary(handle as winapi::shared::minwindef::HINSTANCE) };
}
/// This module contains dangerous functions working on raw pointers.
/// Those functions are only to be used from inside the
/// `SessionBuilder::with_model_from_file()` method.
mod dangerous {
use super::*;
use crate::{ortfree, tensor::TensorElementDataType};
pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> OrtResult<usize> {
let f = ort().SessionGetInputCount.unwrap();
extract_io_count(f, session_ptr)
}
pub(super) fn extract_outputs_count(session_ptr: *mut sys::OrtSession) -> OrtResult<usize> {
let f = ort().SessionGetOutputCount.unwrap();
extract_io_count(f, session_ptr)
}
#[cfg(target_arch = "x86_64")]
fn extract_io_count(
f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut usize) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession
) -> OrtResult<usize> {
let mut num_nodes = 0;
let status = unsafe { f(session_ptr, &mut num_nodes) };
status_to_result(status).map_err(OrtError::GetInOutCount)?;
assert_null_pointer(status, "SessionStatus")?;
(num_nodes != 0)
.then_some(())
.ok_or_else(|| OrtError::GetInOutCount(OrtApiError::Msg("No nodes in model".to_owned())))?;
Ok(num_nodes)
}
#[cfg(target_arch = "aarch64")]
fn extract_io_count(
f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut u64) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession
) -> OrtResult<usize> {
let mut num_nodes = 0;
let status = unsafe { f(session_ptr, &mut num_nodes) };
status_to_result(status).map_err(OrtError::GetInOutCount)?;
assert_null_pointer(status, "SessionStatus")?;
(num_nodes != 0)
.then_some(())
.ok_or_else(|| OrtError::GetInOutCount(OrtApiError::Msg("No nodes in model".to_owned())))?;
Ok(num_nodes as _)
}
fn extract_input_name(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<String> {
let f = ort().SessionGetInputName.unwrap();
extract_io_name(f, session_ptr, allocator_ptr, i)
}
fn extract_output_name(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<String> {
let f = ort().SessionGetOutputName.unwrap();
extract_io_name(f, session_ptr, allocator_ptr, i)
}
#[cfg(target_arch = "x86_64")]
fn extract_io_name(
f: extern_system_fn! { unsafe fn(
*const sys::OrtSession,
usize,
*mut sys::OrtAllocator,
*mut *mut c_char,
) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
i: usize
) -> OrtResult<String> {
let mut name_bytes: *mut c_char = std::ptr::null_mut();
let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
status_to_result(status).map_err(OrtError::GetInputName)?;
assert_non_null_pointer(name_bytes, "InputName")?;
let name = char_p_to_string(name_bytes)?;
ortfree!(unsafe allocator_ptr, name_bytes);
Ok(name)
}
#[cfg(target_arch = "aarch64")]
fn extract_io_name(
f: extern_system_fn! { unsafe fn(
*const sys::OrtSession,
u64,
*mut sys::OrtAllocator,
*mut *mut c_char,
) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
i: usize
) -> OrtResult<String> {
let mut name_bytes: *mut c_char = std::ptr::null_mut();
let status = unsafe { f(session_ptr, i as _, allocator_ptr, &mut name_bytes) };
status_to_result(status).map_err(OrtError::GetInputName)?;
assert_non_null_pointer(name_bytes, "InputName")?;
let name = char_p_to_string(name_bytes)?;
ortfree!(unsafe allocator_ptr, name_bytes);
Ok(name)
}
pub(super) fn extract_input(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<Input> {
let input_name = extract_input_name(session_ptr, allocator_ptr, i)?;
let f = ort().SessionGetInputTypeInfo.unwrap();
let (input_type, dimensions) = extract_io(f, session_ptr, i)?;
Ok(Input {
name: input_name,
input_type,
dimensions
})
}
pub(super) fn extract_output(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<Output> {
let output_name = extract_output_name(session_ptr, allocator_ptr, i)?;
let f = ort().SessionGetOutputTypeInfo.unwrap();
let (output_type, dimensions) = extract_io(f, session_ptr, i)?;
Ok(Output {
name: output_name,
output_type,
dimensions
})
}
#[cfg(target_arch = "x86_64")]
fn extract_io(
f: extern_system_fn! { unsafe fn(
*const sys::OrtSession,
usize,
*mut *mut sys::OrtTypeInfo,
) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession,
i: usize
) -> OrtResult<(TensorElementDataType, Vec<Option<u32>>)> {
let mut typeinfo_ptr: *mut sys::OrtTypeInfo = std::ptr::null_mut();
let status = unsafe { f(session_ptr, i, &mut typeinfo_ptr) };
status_to_result(status).map_err(OrtError::GetTypeInfo)?;
assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?;
let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut tensor_info_ptr) -> OrtError::CastTypeInfoToTensorInfo; nonNull(tensor_info_ptr)];
let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? };
let node_dims = unsafe { get_tensor_dimensions(tensor_info_ptr)? };
ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)];
Ok((io_type, node_dims.into_iter().map(|d| if d == -1 { None } else { Some(d as u32) }).collect()))
}
#[cfg(target_arch = "aarch64")]
fn extract_io(
f: extern_system_fn! { unsafe fn(
*const sys::OrtSession,
u64,
*mut *mut sys::OrtTypeInfo,
) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession,
i: usize
) -> OrtResult<(TensorElementDataType, Vec<Option<u32>>)> {
let mut typeinfo_ptr: *mut sys::OrtTypeInfo = std::ptr::null_mut();
let status = unsafe { f(session_ptr, i as _, &mut typeinfo_ptr) };
status_to_result(status).map_err(OrtError::GetTypeInfo)?;
assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?;
let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut tensor_info_ptr) -> OrtError::CastTypeInfoToTensorInfo; nonNull(tensor_info_ptr)];
let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? };
let node_dims = unsafe { get_tensor_dimensions(tensor_info_ptr)? };
ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)];
Ok((io_type, node_dims.into_iter().map(|d| if d == -1 { None } else { Some(d as u32) }).collect()))
}
}

15
src/sys.rs Normal file
View File

@@ -0,0 +1,15 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
// Disable clippy and `u128` not being FFI-safe
#![allow(clippy::all)]
#![allow(improper_ctypes)]
// bindgen-rs generates test code that dereferences null pointers
#![allow(deref_nullptr)]
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/bindings/bindings.rs"));
#[cfg(target_os = "windows")]
pub type OnnxEnumInt = i32;
#[cfg(not(target_os = "windows"))]
pub type OnnxEnumInt = u32;

309
src/tensor.rs Normal file
View File

@@ -0,0 +1,309 @@
//! Module containing tensor types.
//!
//! Two main types of tensors are available.
//!
//! The first one, [`OrtTensor`], is an _owned_ tensor that is backed by [`ndarray`](https://crates.io/crates/ndarray).
//! This kind of tensor is used to pass input data for the inference.
//!
//! The second one, [`OrtOwnedTensor`], is used internally to pass to the ONNX Runtime inference execution to place its
//! output values. Once "extracted" from the runtime environment, this tensor will contain an [`ndarray::ArrayView`]
//! containing _a view_ of the data. When going out of scope, this tensor will free the required memory on the C side.
//!
//! **NOTE**: Tensors are not meant to be created directly. When performing inference, the [`Session::run`] method takes
//! an `ndarray::Array` as input (taking ownership of it) and will convert it internally to an [`OrtTensor`]. After
//! inference, a [`OrtOwnedTensor`] will be returned by the method which can be derefed into its internal
//! [`ndarray::ArrayView`].
pub mod ndarray_tensor;
pub mod ort_owned_tensor;
pub mod ort_tensor;
pub mod type_dynamic_tensor;
use std::{ffi, fmt, ptr, rc, result, string};
pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor};
pub use ort_tensor::OrtTensor;
pub use type_dynamic_tensor::FromArray;
pub use type_dynamic_tensor::InputTensor;
use super::{
ortsys,
sys::{self as sys, OnnxEnumInt},
tensor::ort_owned_tensor::TensorPointerHolder,
OrtError, OrtResult
};
/// Enum mapping ONNX Runtime's supported tensor data types.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum TensorElementDataType {
/// 32-bit floating point number, equivalent to Rust's `f32`.
Float32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
/// Unsigned 8-bit integer, equivalent to Rust's `u8`.
Uint8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
/// Signed 8-bit integer, equivalent to Rust's `i8`.
Int8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
/// Unsigned 16-bit integer, equivalent to Rust's `u16`.
Uint16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
/// Signed 16-bit integer, equivalent to Rust's `i16`.
Int16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
/// Signed 32-bit integer, equivalent to Rust's `i32`.
Int32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
/// Signed 64-bit integer, equivalent to Rust's `i64`.
Int64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
/// String, equivalent to Rust's `String`.
String = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
/// Boolean, equivalent to Rust's `bool`.
Bool = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
#[cfg(feature = "half")]
/// 16-bit floating point number, equivalent to `half::f16` (requires the `half` crate).
Float16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
/// 64-bit floating point number, equivalent to Rust's `f64`. Also known as `double`.
Float64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
/// Unsigned 32-bit integer, equivalent to Rust's `u32`.
Uint32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
/// Unsigned 64-bit integer, equivalent to Rust's `u64`.
Uint64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
// /// Complex 64-bit floating point number, equivalent to Rust's `num_complex::Complex<f64>`.
// Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt,
// TODO: `num_complex` crate doesn't support i128 provided by the `decimal` crate.
// /// Complex 128-bit floating point number, equivalent to Rust's `num_complex::Complex<f128>`.
// Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt,
/// Brain 16-bit floating point number, equivalent to `half::bf16` (requires the `half` crate).
#[cfg(feature = "half")]
Bfloat16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt
}
impl From<TensorElementDataType> for sys::ONNXTensorElementDataType {
fn from(val: TensorElementDataType) -> Self {
match val {
TensorElementDataType::Float32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
TensorElementDataType::Uint8 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
TensorElementDataType::Int8 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
TensorElementDataType::Uint16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
TensorElementDataType::Int16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
TensorElementDataType::Int32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
TensorElementDataType::Int64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
TensorElementDataType::String => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
TensorElementDataType::Bool => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
#[cfg(feature = "half")]
TensorElementDataType::Float16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
TensorElementDataType::Float64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
TensorElementDataType::Uint32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
TensorElementDataType::Uint64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
// TensorElementDataType::Complex64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64,
// TensorElementDataType::Complex128 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128,
#[cfg(feature = "half")]
TensorElementDataType::Bfloat16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
}
}
}
/// Trait used to map Rust types (for example `f32`) to ONNX tensor element data types (for example `Float`).
pub trait IntoTensorElementDataType {
/// Returns the ONNX tensor element data type corresponding to the given Rust type.
fn tensor_element_data_type() -> TensorElementDataType;
/// If the type is `String`, returns `Some` with UTF-8 contents, else `None`.
fn try_utf8_bytes(&self) -> Option<&[u8]>;
}
macro_rules! impl_type_trait {
($type_:ty, $variant:ident) => {
impl IntoTensorElementDataType for $type_ {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::$variant
}
fn try_utf8_bytes(&self) -> Option<&[u8]> {
None
}
}
};
}
impl_type_trait!(f32, Float32);
impl_type_trait!(u8, Uint8);
impl_type_trait!(i8, Int8);
impl_type_trait!(u16, Uint16);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
impl_type_trait!(bool, Bool);
#[cfg(feature = "half")]
impl_type_trait!(half::f16, Float16);
impl_type_trait!(f64, Float64);
impl_type_trait!(u32, Uint32);
impl_type_trait!(u64, Uint64);
// impl_type_trait!(num_complex::Complex<f64>, Complex64);
// impl_type_trait!(num_complex::Complex<f128>, Complex128);
#[cfg(feature = "half")]
impl_type_trait!(half::bf16, Bfloat16);
/// Adapter for common Rust string types to ONNX strings.
///
/// It should be easy to use both [`String`] and `&str` as [TensorElementDataType::String] data, but
/// we can't define an automatic implementation for anything that implements [`AsRef<str>`] as it
/// would conflict with the implementations of [IntoTensorElementDataType] for primitive numeric
/// types (which might implement [`AsRef<str>`] at some point in the future).
pub trait Utf8Data {
fn utf8_bytes(&self) -> &[u8];
}
impl Utf8Data for String {
fn utf8_bytes(&self) -> &[u8] {
self.as_bytes()
}
}
impl<'a> Utf8Data for &'a str {
fn utf8_bytes(&self) -> &[u8] {
self.as_bytes()
}
}
impl<T: Utf8Data> IntoTensorElementDataType for T {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::String
}
fn try_utf8_bytes(&self) -> Option<&[u8]> {
Some(self.utf8_bytes())
}
}
/// Trait used to map ONNX Runtime types to Rust types.
pub trait TensorDataToType: Sized + fmt::Debug + Clone {
/// The tensor element type that this type can extract from.
fn tensor_element_data_type() -> TensorElementDataType;
/// Extract an `ArrayView` from the ORT-owned tensor.
fn extract_data<'t, D>(shape: D, tensor_element_len: usize, tensor_ptr: rc::Rc<TensorPointerHolder>) -> OrtResult<TensorData<'t, Self, D>>
where
D: ndarray::Dimension;
}
/// Represents the possible ways tensor data can be accessed.
///
/// This should only be used internally.
#[derive(Debug)]
pub enum TensorData<'t, T, D>
where
D: ndarray::Dimension
{
/// Data residing in ONNX Runtime's tensor, in which case the `'t` lifetime is what makes this valid.
/// This is used for data types whose in-memory form from ONNX Runtime is compatible with Rust's, like
/// primitive numeric types.
TensorPtr {
/// The pointer ONNX Runtime produced. Kept alive so that `array_view` is valid.
ptr: rc::Rc<TensorPointerHolder>,
/// A view into `ptr`.
array_view: ndarray::ArrayView<'t, T, D>
},
/// String data is output differently by ONNX, and is of course also variable size, so it cannot
/// use the same simple pointer representation.
// Since `'t` outlives this struct, the 't lifetime is more than we need, but no harm done.
Strings {
/// Owned Strings copied out of ONNX Runtime's output.
strings: ndarray::Array<T, D>
}
}
/// Implements [`TensorDataToType`] for primitives which can use `GetTensorMutableData`.
macro_rules! impl_prim_type_from_ort_trait {
($type_: ty, $variant: ident) => {
impl TensorDataToType for $type_ {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::$variant
}
fn extract_data<'t, D>(shape: D, _tensor_element_len: usize, tensor_ptr: rc::Rc<TensorPointerHolder>) -> OrtResult<TensorData<'t, Self, D>>
where
D: ndarray::Dimension
{
extract_primitive_array(shape, tensor_ptr.tensor_ptr).map(|v| TensorData::TensorPtr { ptr: tensor_ptr, array_view: v })
}
}
};
}
/// Construct an [`ndarray::ArrayView`] for an ORT tensor.
///
/// Only to be used on types whose Rust in-memory representation matches ONNX Runtime's (e.g. primitive numeric types
/// like u32)
fn extract_primitive_array<'t, D, T: TensorDataToType>(shape: D, tensor: *mut sys::OrtValue) -> OrtResult<ndarray::ArrayView<'t, T, D>>
where
D: ndarray::Dimension
{
// Get pointer to output tensor values
let mut output_array_ptr: *mut T = ptr::null_mut();
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr as *mut *mut std::ffi::c_void;
ortsys![unsafe GetTensorMutableData(tensor, output_array_ptr_ptr_void) -> OrtError::GetTensorMutableData; nonNull(output_array_ptr)];
let array_view = unsafe { ndarray::ArrayView::from_shape_ptr(shape, output_array_ptr) };
Ok(array_view)
}
#[cfg(feature = "half")]
impl_prim_type_from_ort_trait!(half::f16, Float16);
#[cfg(feature = "half")]
impl_prim_type_from_ort_trait!(half::bf16, Bfloat16);
impl_prim_type_from_ort_trait!(f32, Float32);
impl_prim_type_from_ort_trait!(f64, Float64);
impl_prim_type_from_ort_trait!(u8, Uint8);
impl_prim_type_from_ort_trait!(u16, Uint16);
impl_prim_type_from_ort_trait!(u32, Uint32);
impl_prim_type_from_ort_trait!(u64, Uint64);
impl_prim_type_from_ort_trait!(i8, Int8);
impl_prim_type_from_ort_trait!(i16, Int16);
impl_prim_type_from_ort_trait!(i32, Int32);
impl_prim_type_from_ort_trait!(i64, Int64);
impl_prim_type_from_ort_trait!(bool, Bool);
impl TensorDataToType for String {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::String
}
fn extract_data<'t, D: ndarray::Dimension>(
shape: D,
tensor_element_len: usize,
tensor_ptr: rc::Rc<TensorPointerHolder>
) -> OrtResult<TensorData<'t, Self, D>> {
// Total length of string data, not including \0 suffix
let mut total_length = 0;
ortsys![unsafe GetStringTensorDataLength(tensor_ptr.tensor_ptr, &mut total_length) -> OrtError::GetStringTensorDataLength];
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
// don't seem to be written to in practice either.
// If the string data actually did go farther, it would panic below when using the offset
// data to get slices for each string.
let mut string_contents = vec![0u8; total_length as _];
// one extra slot so that the total length can go in the last one, making all per-string
// length calculations easy
let mut offsets = vec![0; tensor_element_len + 1];
ortsys![unsafe GetStringTensorContent(tensor_ptr.tensor_ptr, string_contents.as_mut_ptr() as *mut ffi::c_void, total_length, offsets.as_mut_ptr(), tensor_element_len as _) -> OrtError::GetStringTensorContent];
// final offset = overall length so that per-string length calculations work for the last string
debug_assert_eq!(0, offsets[tensor_element_len]);
offsets[tensor_element_len] = total_length;
let strings = offsets
// offsets has 1 extra offset past the end so that all windows work
.windows(2)
.map(|w| {
let slice = &string_contents[w[0] as _..w[1] as _];
String::from_utf8(slice.into())
})
.collect::<result::Result<Vec<String>, string::FromUtf8Error>>()
.map_err(OrtError::StringFromUtf8Error)?;
let array = ndarray::Array::from_shape_vec(shape, strings).expect("Shape extracted from tensor didn't match tensor contents");
Ok(TensorData::Strings { strings: array })
}
}

View File

@@ -0,0 +1,125 @@
use ndarray::{Array, ArrayBase};
/// Trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)
/// with useful tensor operations.
///
/// # Generic
///
/// The trait is generic over:
/// * `S`: [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)'s data container
/// * `T`: Type contained inside the tensor (for example `f32`)
/// * `D`: Tensor's dimension ([`ndarray::Dimension`](https://docs.rs/ndarray/latest/ndarray/trait.Dimension.html))
pub trait NdArrayTensor<S, T, D> {
/// Calculate the [softmax](https://en.wikipedia.org/wiki/Softmax_function) of the tensor along a given axis
///
/// # Trait Bounds
///
/// The function is generic and thus has some trait bounds:
/// * `D: ndarray::RemoveAxis`: The summation over an axis reduces the dimension of the tensor. A 0-D tensor thus
/// cannot have a softmax calculated.
/// * `S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>`: The storage of the tensor can be an owned
/// array ([`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)) or an array view
/// ([`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)).
/// * `<S as ndarray::RawData>::Elem: std::clone::Clone`: The elements of the tensor must be `Clone`.
/// * `T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign`: The elements of the tensor must be workable
/// as floats and must support `-=` and `/=` operations.
fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
where
D: ndarray::RemoveAxis,
S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>,
<S as ndarray::RawData>::Elem: std::clone::Clone,
T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign;
}
impl<S, T, D> NdArrayTensor<S, T, D> for ArrayBase<S, D>
where
D: ndarray::RemoveAxis,
S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>,
<S as ndarray::RawData>::Elem: std::clone::Clone,
T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign
{
fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> {
let mut new_array: Array<T, D> = self.to_owned();
// FIXME: Change to non-overflowing formula
// e = np.exp(A - np.sum(A, axis=1, keepdims=True))
// np.exp(a) / np.sum(np.exp(a))
new_array.map_inplace(|v| *v = v.exp());
let sum = new_array.sum_axis(axis).insert_axis(axis);
new_array /= &sum;
new_array
}
}
#[cfg(test)]
mod tests {
use ndarray::{arr1, arr2, arr3};
use test_log::test;
use super::*;
#[test]
fn softmax_1d() {
let array = arr1(&[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]);
let expected_softmax = arr1(&[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]);
let softmax = array.softmax(ndarray::Axis(0));
assert_eq!(softmax.shape(), expected_softmax.shape());
let diff = softmax - expected_softmax;
assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
}
#[test]
fn softmax_2d() {
let array = arr2(&[[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0], [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]]);
let expected_softmax = arr2(&[
[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813],
[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]
]);
let softmax = array.softmax(ndarray::Axis(1));
assert_eq!(softmax.shape(), expected_softmax.shape());
let diff = softmax - expected_softmax;
assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
}
#[test]
fn softmax_3d() {
let array = arr3(&[
[[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0], [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]],
[[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0], [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]],
[[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0], [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]]
]);
let expected_softmax = arr3(&[
[
[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813],
[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]
],
[
[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813],
[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]
],
[
[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813],
[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]
]
]);
let softmax = array.softmax(ndarray::Axis(2));
assert_eq!(softmax.shape(), expected_softmax.shape());
let diff = softmax - expected_softmax;
assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
}
}

View File

@@ -0,0 +1,186 @@
use std::{fmt::Debug, ops::Deref, ptr, rc};
use ndarray::ArrayView;
use tracing::debug;
use super::{TensorData, TensorDataToType, TensorElementDataType};
use crate::{memory::MemoryInfo, ortsys, sys, OrtError, OrtResult};
/// A wrapper around a tensor produced by onnxruntime inference.
///
/// Since different outputs for the same model can have different types, this type is used to allow
/// the user to dynamically query each output's type and extract the appropriate tensor type with
/// [try_extract].
#[derive(Debug)]
pub struct DynOrtTensor<'m, D>
where
D: ndarray::Dimension
{
tensor_ptr_holder: rc::Rc<TensorPointerHolder>,
#[allow(dead_code)]
memory_info: &'m MemoryInfo,
shape: D,
tensor_element_len: usize,
data_type: TensorElementDataType
}
impl<'m, D> DynOrtTensor<'m, D>
where
D: ndarray::Dimension
{
pub(crate) fn new(
tensor_ptr: *mut sys::OrtValue,
memory_info: &'m MemoryInfo,
shape: D,
tensor_element_len: usize,
data_type: TensorElementDataType
) -> DynOrtTensor<'m, D> {
DynOrtTensor {
tensor_ptr_holder: rc::Rc::from(TensorPointerHolder { tensor_ptr }),
memory_info,
shape,
tensor_element_len,
data_type
}
}
/// The ONNX data type this tensor contains.
pub fn data_type(&self) -> TensorElementDataType {
self.data_type
}
/// Extract a tensor containing `T`.
///
/// Where the type permits it, the tensor will be a view into existing memory.
///
/// # Errors
///
/// An error will be returned if `T`'s ONNX type doesn't match this tensor's type, or if an
/// onnxruntime error occurs.
pub fn try_extract<'t, T>(&self) -> OrtResult<OrtOwnedTensor<'t, T, D>>
where
T: TensorDataToType + Clone + Debug,
'm: 't, // mem info outlives tensor
D: 't
{
if self.data_type != T::tensor_element_data_type() {
Err(OrtError::DataTypeMismatch {
actual: self.data_type,
requested: T::tensor_element_data_type()
})
} else {
// Note: Both tensor and array will point to the same data, nothing is copied.
// As such, there is no need to free the pointer used to create the ArrayView.
assert_ne!(self.tensor_ptr_holder.tensor_ptr, ptr::null_mut());
let mut is_tensor = 0;
ortsys![unsafe IsTensor(self.tensor_ptr_holder.tensor_ptr, &mut is_tensor) -> OrtError::FailedTensorCheck];
assert_eq!(is_tensor, 1);
let data = T::extract_data(self.shape.clone(), self.tensor_element_len, rc::Rc::clone(&self.tensor_ptr_holder))?;
Ok(OrtOwnedTensor { data })
}
}
}
/// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference.
///
/// This tensor type is returned by the [`Session::run()`](../session/struct.Session.html#method.run) method.
/// It is not meant to be created directly.
///
/// The tensor hosts an [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)
/// of the data on the C side. This allows manipulation on the Rust side using `ndarray` without copying the data.
///
/// `OrtOwnedTensor` implements the [`std::deref::Deref`](#impl-Deref) trait for ergonomic access to
/// the underlying [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
#[derive(Debug)]
pub struct OrtOwnedTensor<'t, T, D>
where
T: TensorDataToType,
D: ndarray::Dimension
{
data: TensorData<'t, T, D>
}
impl<'t, T, D> OrtOwnedTensor<'t, T, D>
where
T: TensorDataToType,
D: ndarray::Dimension + 't
{
/// Produce a [`ViewHolder`] for the underlying data.
pub fn view<'s>(&'s self) -> ViewHolder<'s, T, D>
where
't: 's // tensor ptr can outlive the TensorData
{
ViewHolder::new(&self.data)
}
}
/// An intermediate step on the way to an ArrayView.
// Since Deref has to produce a reference, and the referent can't be a local in deref(), it must
// be a field in a struct. This struct exists only to hold that field.
// Its lifetime 's is bound to the TensorData its view was created around, not the underlying tensor
// pointer, since in the case of strings the data is the Array in the TensorData, not the pointer.
pub struct ViewHolder<'s, T, D>
where
T: TensorDataToType,
D: ndarray::Dimension
{
array_view: ndarray::ArrayView<'s, T, D>
}
impl<'s, T, D> ViewHolder<'s, T, D>
where
T: TensorDataToType,
D: ndarray::Dimension
{
fn new<'t>(data: &'s TensorData<'t, T, D>) -> ViewHolder<'s, T, D>
where
't: 's // underlying tensor ptr lives at least as long as TensorData
{
match data {
TensorData::TensorPtr { array_view, .. } => ViewHolder {
// we already have a view, but creating a view from a view is cheap
array_view: array_view.view()
},
TensorData::Strings { strings } => ViewHolder {
// This view creation has to happen here, not at new()'s callsite, because
// a field can't be a reference to another field in the same struct. Thus, we have
// this separate struct to hold the view that refers to the `Array`.
array_view: strings.view()
}
}
}
}
impl<'t, T, D> Deref for ViewHolder<'t, T, D>
where
T: TensorDataToType,
D: ndarray::Dimension
{
type Target = ArrayView<'t, T, D>;
fn deref(&self) -> &Self::Target {
&self.array_view
}
}
/// Holds on to a tensor pointer until dropped.
///
/// This allows for creating an [`OrtOwnedTensor`] from a [`DynOrtTensor`] without consuming `self`, which would prevent
/// retrying extraction and avoids awkward interaction with the outputs `Vec`. It also avoids requiring `OrtOwnedTensor`
/// to keep a reference to `DynOrtTensor`, which would be inconvenient.
#[derive(Debug)]
pub struct TensorPointerHolder {
pub(crate) tensor_ptr: *mut sys::OrtValue
}
impl Drop for TensorPointerHolder {
#[tracing::instrument]
fn drop(&mut self) {
debug!("Dropping OrtOwnedTensor.");
ortsys![unsafe ReleaseValue(self.tensor_ptr)];
self.tensor_ptr = ptr::null_mut();
}
}

263
src/tensor/ort_tensor.rs Normal file
View File

@@ -0,0 +1,263 @@
use std::{ffi, fmt::Debug, ops::Deref};
use ndarray::Array;
use tracing::{debug, error};
use crate::{
error::assert_non_null_pointer,
memory::MemoryInfo,
ortsys, sys,
tensor::{ndarray_tensor::NdArrayTensor, IntoTensorElementDataType, TensorElementDataType},
OrtError, OrtResult
};
/// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
///
/// This tensor bounds the ONNX Runtime to `ndarray`; it is used to copy an
/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) to the runtime's memory.
///
/// **NOTE**: The type is not meant to be used directly, use an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
/// instead.
#[derive(Debug)]
pub struct OrtTensor<'t, T, D>
where
T: IntoTensorElementDataType + Debug + Clone,
D: ndarray::Dimension
{
pub(crate) c_ptr: *mut sys::OrtValue,
array: Array<T, D>,
#[allow(dead_code)]
memory_info: &'t MemoryInfo
}
impl<'t, T, D> OrtTensor<'t, T, D>
where
T: IntoTensorElementDataType + Debug + Clone,
D: ndarray::Dimension
{
pub(crate) fn from_array<'m>(memory_info: &'m MemoryInfo, allocator_ptr: *mut sys::OrtAllocator, array: &Array<T, D>) -> OrtResult<OrtTensor<'t, T, D>>
where
'm: 't // 'm outlives 't
{
let mut array = array.to_owned();
// where onnxruntime will write the tensor data to
let mut tensor_ptr: *mut sys::OrtValue = std::ptr::null_mut();
let tensor_ptr_ptr: *mut *mut sys::OrtValue = &mut tensor_ptr;
let shape: Vec<i64> = array.shape().iter().map(|d: &usize| *d as i64).collect();
let shape_ptr: *const i64 = shape.as_ptr();
let shape_len = array.shape().len();
match T::tensor_element_data_type() {
TensorElementDataType::Float32
| TensorElementDataType::Uint8
| TensorElementDataType::Int8
| TensorElementDataType::Uint16
| TensorElementDataType::Int16
| TensorElementDataType::Int32
| TensorElementDataType::Int64
| TensorElementDataType::Float64
| TensorElementDataType::Uint32
| TensorElementDataType::Uint64 => {
// primitive data is already suitably laid out in memory; provide it to
// onnxruntime as is
let tensor_values_ptr: *mut std::ffi::c_void = array.as_mut_ptr() as *mut std::ffi::c_void;
assert_non_null_pointer(tensor_values_ptr, "TensorValues")?;
ortsys![
unsafe CreateTensorWithDataAsOrtValue(
memory_info.ptr,
tensor_values_ptr,
(array.len() * std::mem::size_of::<T>()) as _,
shape_ptr,
shape_len as _,
T::tensor_element_data_type().into(),
tensor_ptr_ptr
) -> OrtError::CreateTensorWithData;
nonNull(tensor_ptr)
];
let mut is_tensor = 0;
ortsys![unsafe IsTensor(tensor_ptr, &mut is_tensor) -> OrtError::FailedTensorCheck];
assert_eq!(is_tensor, 1);
}
#[cfg(feature = "half")]
TensorElementDataType::Bfloat16 | TensorElementDataType::Float16 => {
// f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime
let tensor_values_ptr: *mut std::ffi::c_void = array.as_mut_ptr() as *mut std::ffi::c_void;
assert_non_null_pointer(tensor_values_ptr, "TensorValues")?;
ortsys![
unsafe CreateTensorWithDataAsOrtValue(
memory_info.ptr,
tensor_values_ptr,
array.len() * std::mem::size_of::<T>(),
shape_ptr,
shape_len,
T::tensor_element_data_type().into(),
tensor_ptr_ptr
) -> OrtError::CreateTensorWithData;
nonNull(tensor_ptr)
];
let mut is_tensor = 0;
ortsys![unsafe IsTensor(tensor_ptr, &mut is_tensor) -> OrtError::FailedTensorCheck];
assert_eq!(is_tensor, 1);
}
TensorElementDataType::String => {
// create tensor without data -- data is filled in later
ortsys![
unsafe CreateTensorAsOrtValue(allocator_ptr, shape_ptr, shape_len as _, T::tensor_element_data_type().into(), tensor_ptr_ptr)
-> OrtError::CreateTensor
];
// create null-terminated copies of each string, as per `FillStringTensor` docs
let null_terminated_copies: Vec<ffi::CString> = array
.iter()
.map(|elt| {
let slice = elt.try_utf8_bytes().expect("String data type must provide utf8 bytes");
ffi::CString::new(slice)
})
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(OrtError::FfiStringNull)?;
let string_pointers = null_terminated_copies.iter().map(|cstring| cstring.as_ptr()).collect::<Vec<_>>();
ortsys![unsafe FillStringTensor(tensor_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> OrtError::FillStringTensor];
}
_ => unimplemented!("Tensor element data type {:?} not yet implemented", T::tensor_element_data_type())
}
assert_non_null_pointer(tensor_ptr, "Tensor")?;
Ok(OrtTensor {
c_ptr: tensor_ptr,
array,
memory_info
})
}
}
impl<'t, T, D> Deref for OrtTensor<'t, T, D>
where
T: IntoTensorElementDataType + Debug + Clone,
D: ndarray::Dimension
{
type Target = Array<T, D>;
fn deref(&self) -> &Self::Target {
&self.array
}
}
impl<'t, T, D> Drop for OrtTensor<'t, T, D>
where
T: IntoTensorElementDataType + Debug + Clone,
D: ndarray::Dimension
{
#[tracing::instrument]
fn drop(&mut self) {
// We need to let the C part free
debug!("Dropping Tensor.");
if self.c_ptr.is_null() {
error!("Null pointer, not calling free.");
} else {
ortsys![unsafe ReleaseValue(self.c_ptr)];
}
self.c_ptr = std::ptr::null_mut();
}
}
impl<'t, T, D> OrtTensor<'t, T, D>
where
T: IntoTensorElementDataType + Debug + Clone,
D: ndarray::Dimension
{
/// Apply a softmax on the specified axis
pub fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
where
D: ndarray::RemoveAxis,
T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign
{
self.array.softmax(axis)
}
}
#[cfg(test)]
mod tests {
use std::ptr;
use ndarray::{arr0, arr1, arr2, arr3};
use test_log::test;
use super::*;
use crate::{AllocatorType, MemType};
#[test]
fn orttensor_from_array_0d_i32() -> OrtResult<()> {
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
let array = arr0::<i32>(123);
let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), &array)?;
let expected_shape: &[usize] = &[];
assert_eq!(tensor.shape(), expected_shape);
Ok(())
}
#[test]
fn orttensor_from_array_1d_i32() -> OrtResult<()> {
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
let array = arr1(&[1_i32, 2, 3, 4, 5, 6]);
let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), &array)?;
let expected_shape: &[usize] = &[6];
assert_eq!(tensor.shape(), expected_shape);
Ok(())
}
#[test]
fn orttensor_from_array_2d_i32() -> OrtResult<()> {
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
let array = arr2(&[[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]);
let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), &array)?;
assert_eq!(tensor.shape(), &[2, 6]);
Ok(())
}
#[test]
fn orttensor_from_array_3d_i32() -> OrtResult<()> {
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
let array = arr3(&[
[[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]],
[[13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24]],
[[25, 26, 27, 28, 29, 30], [31, 32, 33, 34, 35, 36]]
]);
let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), &array)?;
assert_eq!(tensor.shape(), &[3, 2, 6]);
Ok(())
}
#[test]
fn orttensor_from_array_1d_string() -> OrtResult<()> {
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
let array = arr1(&[String::from("foo"), String::from("bar"), String::from("baz")]);
let tensor = OrtTensor::from_array(&memory_info, ort_default_allocator()?, &array)?;
assert_eq!(tensor.shape(), &[3]);
Ok(())
}
#[test]
fn orttensor_from_array_3d_str() -> OrtResult<()> {
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
let array = arr3(&[[["1", "2", "3"], ["4", "5", "6"]], [["7", "8", "9"], ["10", "11", "12"]]]);
let tensor = OrtTensor::from_array(&memory_info, ort_default_allocator()?, &array)?;
assert_eq!(tensor.shape(), &[2, 2, 3]);
Ok(())
}
fn ort_default_allocator() -> OrtResult<*mut sys::OrtAllocator> {
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
// this default non-arena allocator doesn't need to be deallocated
ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr) -> OrtError::GetAllocator];
Ok(allocator_ptr)
}
}

View File

@@ -0,0 +1,157 @@
use std::fmt::Debug;
use ndarray::{Array, IxDyn};
use crate::{memory::MemoryInfo, sys, tensor::OrtTensor, OrtResult};
/// Trait used for constructing inputs with multiple element types from [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
pub trait FromArray<T> {
/// Wrap [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) into enum with specific dtype variants.
fn from_array(array: Array<T, IxDyn>) -> InputTensor;
}
macro_rules! impl_convert_trait {
($type_:ty, $variant:expr) => {
impl FromArray<$type_> for InputTensor {
fn from_array(array: Array<$type_, IxDyn>) -> InputTensor {
$variant(array)
}
}
};
}
/// Input tensor enum with tensor element type as a variant.
///
/// Required for supplying inputs with different types
#[derive(Debug)]
#[allow(missing_docs)]
pub enum InputTensor {
FloatTensor(Array<f32, IxDyn>),
#[cfg(feature = "half")]
Float16Tensor(Array<half::f16, IxDyn>),
#[cfg(feature = "half")]
Bfloat16Tensor(Array<half::bf16, IxDyn>),
Uint8Tensor(Array<u8, IxDyn>),
Int8Tensor(Array<i8, IxDyn>),
Uint16Tensor(Array<u16, IxDyn>),
Int16Tensor(Array<i16, IxDyn>),
Int32Tensor(Array<i32, IxDyn>),
Int64Tensor(Array<i64, IxDyn>),
DoubleTensor(Array<f64, IxDyn>),
Uint32Tensor(Array<u32, IxDyn>),
Uint64Tensor(Array<u64, IxDyn>),
StringTensor(Array<String, IxDyn>)
}
/// This tensor is used to copy an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
/// from InputTensor to the runtime's memory with support to multiple input tensor types.
///
/// **NOTE**: The type is not meant to be used directly, use an InputTensor constructed from
/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) instead.
#[derive(Debug)]
#[allow(missing_docs)]
pub enum InputOrtTensor<'t> {
FloatTensor(OrtTensor<'t, f32, IxDyn>),
#[cfg(feature = "half")]
Float16Tensor(OrtTensor<'t, half::f16, IxDyn>),
#[cfg(feature = "half")]
Bfloat16Tensor(OrtTensor<'t, half::bf16, IxDyn>),
Uint8Tensor(OrtTensor<'t, u8, IxDyn>),
Int8Tensor(OrtTensor<'t, i8, IxDyn>),
Uint16Tensor(OrtTensor<'t, u16, IxDyn>),
Int16Tensor(OrtTensor<'t, i16, IxDyn>),
Int32Tensor(OrtTensor<'t, i32, IxDyn>),
Int64Tensor(OrtTensor<'t, i64, IxDyn>),
DoubleTensor(OrtTensor<'t, f64, IxDyn>),
Uint32Tensor(OrtTensor<'t, u32, IxDyn>),
Uint64Tensor(OrtTensor<'t, u64, IxDyn>),
StringTensor(OrtTensor<'t, String, IxDyn>)
}
impl InputTensor {
/// Get shape of the underlying array.
pub fn shape(&self) -> &[usize] {
match self {
InputTensor::FloatTensor(x) => x.shape(),
#[cfg(feature = "half")]
InputTensor::Float16Tensor(x) => x.shape(),
#[cfg(feature = "half")]
InputTensor::Bfloat16Tensor(x) => x.shape(),
InputTensor::Uint8Tensor(x) => x.shape(),
InputTensor::Int8Tensor(x) => x.shape(),
InputTensor::Uint16Tensor(x) => x.shape(),
InputTensor::Int16Tensor(x) => x.shape(),
InputTensor::Int32Tensor(x) => x.shape(),
InputTensor::Int64Tensor(x) => x.shape(),
InputTensor::DoubleTensor(x) => x.shape(),
InputTensor::Uint32Tensor(x) => x.shape(),
InputTensor::Uint64Tensor(x) => x.shape(),
InputTensor::StringTensor(x) => x.shape()
}
}
}
impl_convert_trait!(f32, InputTensor::FloatTensor);
#[cfg(feature = "half")]
impl_convert_trait!(half::f16, InputTensor::Float16Tensor);
#[cfg(feature = "half")]
impl_convert_trait!(half::bf16, InputTensor::Bfloat16Tensor);
impl_convert_trait!(u8, InputTensor::Uint8Tensor);
impl_convert_trait!(i8, InputTensor::Int8Tensor);
impl_convert_trait!(u16, InputTensor::Uint16Tensor);
impl_convert_trait!(i16, InputTensor::Int16Tensor);
impl_convert_trait!(i32, InputTensor::Int32Tensor);
impl_convert_trait!(i64, InputTensor::Int64Tensor);
impl_convert_trait!(f64, InputTensor::DoubleTensor);
impl_convert_trait!(u32, InputTensor::Uint32Tensor);
impl_convert_trait!(u64, InputTensor::Uint64Tensor);
impl_convert_trait!(String, InputTensor::StringTensor);
impl<'t> InputOrtTensor<'t> {
pub(crate) fn from_input_tensor<'m, 'i>(
memory_info: &'m MemoryInfo,
allocator_ptr: *mut sys::OrtAllocator,
input_tensor: &'i InputTensor
) -> OrtResult<InputOrtTensor<'t>>
where
'm: 't
{
match input_tensor {
InputTensor::FloatTensor(array) => Ok(InputOrtTensor::FloatTensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
#[cfg(feature = "half")]
InputTensor::Float16Tensor(array) => Ok(InputOrtTensor::Float16Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
#[cfg(feature = "half")]
InputTensor::Bfloat16Tensor(array) => Ok(InputOrtTensor::Bfloat16Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::Uint8Tensor(array) => Ok(InputOrtTensor::Uint8Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::Int8Tensor(array) => Ok(InputOrtTensor::Int8Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::Uint16Tensor(array) => Ok(InputOrtTensor::Uint16Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::Int16Tensor(array) => Ok(InputOrtTensor::Int16Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::Int32Tensor(array) => Ok(InputOrtTensor::Int32Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::Int64Tensor(array) => Ok(InputOrtTensor::Int64Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::DoubleTensor(array) => Ok(InputOrtTensor::DoubleTensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::Uint32Tensor(array) => Ok(InputOrtTensor::Uint32Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::Uint64Tensor(array) => Ok(InputOrtTensor::Uint64Tensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?)),
InputTensor::StringTensor(array) => Ok(InputOrtTensor::StringTensor(OrtTensor::from_array(memory_info, allocator_ptr, array)?))
}
}
pub(crate) fn c_ptr(&self) -> *const sys::OrtValue {
match self {
InputOrtTensor::FloatTensor(x) => x.c_ptr,
#[cfg(feature = "half")]
InputOrtTensor::Float16Tensor(x) => x.c_ptr,
#[cfg(feature = "half")]
InputOrtTensor::Bfloat16Tensor(x) => x.c_ptr,
InputOrtTensor::Uint8Tensor(x) => x.c_ptr,
InputOrtTensor::Int8Tensor(x) => x.c_ptr,
InputOrtTensor::Uint16Tensor(x) => x.c_ptr,
InputOrtTensor::Int16Tensor(x) => x.c_ptr,
InputOrtTensor::Int32Tensor(x) => x.c_ptr,
InputOrtTensor::Int64Tensor(x) => x.c_ptr,
InputOrtTensor::DoubleTensor(x) => x.c_ptr,
InputOrtTensor::Uint32Tensor(x) => x.c_ptr,
InputOrtTensor::Uint64Tensor(x) => x.c_ptr,
InputOrtTensor::StringTensor(x) => x.c_ptr
}
}
}

1
src/wrapper.h Normal file
View File

@@ -0,0 +1 @@
#include <onnxruntime_c_api.h>

50001
tests/data/gpt2/merges.txt Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

BIN
tests/data/mnist_5.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 555 B

BIN
tests/data/mushroom.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

BIN
tests/data/upsample.onnx Normal file

Binary file not shown.

263
tests/onnx.rs Normal file
View File

@@ -0,0 +1,263 @@
use std::{
fs,
io::{self, BufRead, BufReader},
path::Path,
time::Duration
};
use ort::error::OrtDownloadError;
mod download {
use std::sync::Arc;
use image::{imageops::FilterType, ImageBuffer, Luma, Pixel, Rgb};
use ndarray::s;
use ort::{
download::vision::{DomainBasedImageClassification, ImageClassification},
environment::Environment,
tensor::{ndarray_tensor::NdArrayTensor, DynOrtTensor, FromArray, InputTensor, OrtOwnedTensor},
GraphOptimizationLevel, LoggingLevel, OrtResult, SessionBuilder
};
use test_log::test;
use super::*;
#[test]
fn squeezenet_mushroom() -> OrtResult<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";
let environment = Arc::new(
Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
);
let mut session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_downloaded(ImageClassification::SqueezeNet)
.expect("Could not download model from file");
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "main");
assert_eq!(metadata.producer()?, "");
let class_labels = get_imagenet_labels()?;
let input0_shape: Vec<usize> = session.inputs[0].dimensions().map(|d| d.unwrap()).collect();
let output0_shape: Vec<usize> = session.outputs[0].dimensions().map(|d| d.unwrap()).collect();
assert_eq!(input0_shape, [1, 3, 224, 224]);
assert_eq!(output0_shape, [1, 1000]);
// Load image and resize to model's shape, converting to RGB format
let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))
.unwrap()
.resize(input0_shape[2] as u32, input0_shape[3] as u32, FilterType::Nearest)
.to_rgb8();
// Python:
// # image[y, x, RGB]
// # x==0 --> left
// # y==0 --> top
// See https://github.com/onnx/models/blob/master/vision/classification/imagenet_inference.ipynb
// for pre-processing image.
// WARNING: Note order of declaration of arguments: (_,c,j,i)
let mut array = ndarray::Array::from_shape_fn((1, 3, 224, 224), |(_, c, j, i)| {
let pixel = image_buffer.get_pixel(i as u32, j as u32);
let channels = pixel.channels();
// range [0, 255] -> range [0, 1]
(channels[c] as f32) / 255.0
});
// Normalize channels to mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]
let mean = [0.485, 0.456, 0.406];
let std = [0.229, 0.224, 0.225];
for c in 0..3 {
let mut channel_array = array.slice_mut(s![0, c, .., ..]);
channel_array -= mean[c];
channel_array /= std[c];
}
// Batch of 1
let input_tensor_values = [InputTensor::from_array(array.into_dyn())];
// Perform the inference
let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = session.run(input_tensor_values)?;
// Downloaded model does not have a softmax as final layer; call softmax on second axis
// and iterate on resulting probabilities, creating an index to later access labels.
let output: OrtOwnedTensor<_, _> = outputs[0].try_extract()?;
let mut probabilities: Vec<(usize, f32)> = output.view().softmax(ndarray::Axis(1)).iter().copied().enumerate().collect::<Vec<_>>();
// Sort probabilities so highest is at beginning of vector.
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
assert_eq!(class_labels[probabilities[0].0], "n07734744 mushroom", "Expecting class for {} to be a mushroom", IMAGE_TO_LOAD);
assert_eq!(probabilities[0].0, 947, "Expecting class for {} to be a mushroom (index 947 in labels file)", IMAGE_TO_LOAD);
Ok(())
}
#[test]
fn mnist_5() -> OrtResult<()> {
const IMAGE_TO_LOAD: &str = "mnist_5.jpg";
let environment = Arc::new(
Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
);
let mut session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_downloaded(DomainBasedImageClassification::Mnist)
.expect("Could not download model from file");
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "CNTKGraph");
assert_eq!(metadata.producer()?, "CNTK");
let input0_shape: Vec<usize> = session.inputs[0].dimensions().map(|d| d.unwrap()).collect();
let output0_shape: Vec<usize> = session.outputs[0].dimensions().map(|d| d.unwrap()).collect();
assert_eq!(input0_shape, [1, 1, 28, 28]);
assert_eq!(output0_shape, [1, 10]);
// Load image and resize to model's shape, converting to RGB format
let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))
.unwrap()
.resize(input0_shape[2] as u32, input0_shape[3] as u32, FilterType::Nearest)
.to_luma8();
let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| {
let pixel = image_buffer.get_pixel(i as u32, j as u32);
let channels = pixel.channels();
// range [0, 255] -> range [0, 1]
(channels[c] as f32) / 255.0
});
// Batch of 1
let input_tensor_values = [InputTensor::from_array(array.into_dyn())];
// Perform the inference
let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = session.run(input_tensor_values)?;
let output: OrtOwnedTensor<_, _> = outputs[0].try_extract()?;
let mut probabilities: Vec<(usize, f32)> = output.view().softmax(ndarray::Axis(1)).iter().copied().enumerate().collect::<Vec<_>>();
// Sort probabilities so highest is at beginning of vector.
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
assert_eq!(probabilities[0].0, 5, "Expecting class for {} is '5' (not {})", IMAGE_TO_LOAD, probabilities[0].0);
Ok(())
}
/// This test verifies that dynamically sized inputs and outputs work. It loads and runs
/// upsample.onnx, which was produced via:
///
/// ```python
/// import subprocess
/// from tensorflow import keras
///
/// m = keras.Sequential([
/// keras.layers.UpSampling2D(size=2)
/// ])
/// m.build(input_shape=(None, None, None, 3))
/// m.summary()
/// m.save('saved_model')
///
/// subprocess.check_call([
/// 'python', '-m', 'tf2onnx.convert',
/// '--saved-model', 'saved_model',
/// '--opset', '12',
/// '--output', 'upsample.onnx'
/// ])
/// ```
#[test]
fn upsample() -> OrtResult<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";
let environment = Arc::new(
Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()?
);
let mut session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx"))
.expect("Could not open model from file");
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "tf2onnx");
assert_eq!(metadata.producer()?, "tf2onnx");
assert_eq!(session.inputs[0].dimensions().collect::<Vec<_>>(), [None, None, None, Some(3)]);
assert_eq!(session.outputs[0].dimensions().collect::<Vec<_>>(), [None, None, None, Some(3)]);
// Load image, converting to RGB format
let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))
.unwrap()
.to_rgb8();
let array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| {
let pixel = image_buffer.get_pixel(i as u32, j as u32);
let channels = pixel.channels();
// range [0, 255] -> range [0, 1]
(channels[c] as f32) / 255.0
});
// Just one input
let input_tensor_values = [InputTensor::from_array(array.into_dyn())];
// Perform the inference
let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = session.run(input_tensor_values)?;
assert_eq!(outputs.len(), 1);
let output: OrtOwnedTensor<'_, f32, ndarray::Dim<ndarray::IxDynImpl>> = outputs[0].try_extract()?;
// The image should have doubled in size
assert_eq!(output.view().shape(), [1, 448, 448, 3]);
Ok(())
}
}
fn get_imagenet_labels() -> Result<Vec<String>, OrtDownloadError> {
// Download the ImageNet class labels, matching SqueezeNet's classes.
let labels_path = Path::new(env!("CARGO_TARGET_TMPDIR")).join("synset.txt");
if !labels_path.exists() {
let url = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt";
println!("Downloading {:?} to {:?}...", url, labels_path);
let resp = ureq::get(url)
.timeout(Duration::from_secs(180)) // 3 minutes
.call()
.map_err(Box::new)
.map_err(OrtDownloadError::FetchError)?;
assert!(resp.has("Content-Length"));
let len = resp.header("Content-Length").and_then(|s| s.parse::<usize>().ok()).unwrap();
println!("Downloading {} bytes...", len);
let mut reader = resp.into_reader();
let f = fs::File::create(&labels_path).unwrap();
let mut writer = io::BufWriter::new(f);
let bytes_io_count = io::copy(&mut reader, &mut writer).unwrap();
assert_eq!(bytes_io_count, len as u64);
}
let file = BufReader::new(fs::File::open(labels_path).unwrap());
file.lines().map(|line| line.map_err(OrtDownloadError::IoError)).collect()
}

View File

@@ -0,0 +1,7 @@
SET(CMAKE_SYSTEM_NAME Linux)
SET(CMAKE_C_COMPILER aarch64-linux-gnu-gcc)
SET(CMAKE_CXX_COMPILER aarch64-linux-gnu-g++)
SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
SET(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
SET(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)

View File

@@ -0,0 +1,7 @@
SET(CMAKE_SYSTEM_NAME Windows)
SET(CMAKE_C_COMPILER x86_64-w64-mingw32-clang)
SET(CMAKE_CXX_COMPILER x86_64-w64-mingw32-clang++)
SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
SET(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
SET(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)

View File

@@ -0,0 +1,7 @@
SET(CMAKE_SYSTEM_NAME Windows)
SET(CMAKE_C_COMPILER x86_64-w64-mingw32-gcc)
SET(CMAKE_CXX_COMPILER x86_64-w64-mingw32-g++)
SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
SET(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
SET(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)