mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: ort-web (#450)
This commit is contained in:
@@ -5,6 +5,7 @@ default-members = [ '.' ]
|
||||
exclude = [
|
||||
'backends/candle',
|
||||
'backends/tract',
|
||||
'backends/web',
|
||||
'examples/async-gpt2-api',
|
||||
'examples/cudarc',
|
||||
'examples/custom-ops',
|
||||
|
||||
3
backends/web/.vscode/settings.json
vendored
Normal file
3
backends/web/.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"rust-analyzer.cargo.target": "wasm32-unknown-unknown"
|
||||
}
|
||||
44
backends/web/Cargo.toml
Normal file
44
backends/web/Cargo.toml
Normal file
@@ -0,0 +1,44 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "ort-web"
|
||||
description = "ONNX Runtime on the web 🌐 - An alternative backend for ort"
|
||||
version = "0.1.0+1.23"
|
||||
edition = "2024"
|
||||
rust-version = "1.88"
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = "https://github.com/pykeio/ort"
|
||||
homepage = "https://ort.pyke.io/backends/web"
|
||||
keywords = [ "machine-learning", "ai", "ml", "web", "wasm" ]
|
||||
categories = [ "algorithms", "mathematics", "science", "web-programming", "wasm" ]
|
||||
authors = [
|
||||
"pyke.io <contact@pyke.io>"
|
||||
]
|
||||
|
||||
[lib]
|
||||
name = "ort_web"
|
||||
path = "lib.rs"
|
||||
|
||||
[dependencies]
|
||||
js-sys = "0.3"
|
||||
ort = { path = "../../", version = "=2.0.0-rc.10", default-features = false, features = [ "alternative-backend" ] }
|
||||
ort-sys = { path = "../../ort-sys", version = "=2.0.0-rc.10", default-features = false, features = [ "disable-linking" ] }
|
||||
serde = { version = "1.0", features = [ "derive" ] }
|
||||
serde-wasm-bindgen = "0.6"
|
||||
wasm-bindgen = "0.2"
|
||||
wasm-bindgen-futures = "0.4"
|
||||
|
||||
[dependencies.web-sys]
|
||||
version = "0.3"
|
||||
features = [
|
||||
"console",
|
||||
"ImageData",
|
||||
"HtmlImageElement",
|
||||
"ImageBitmap",
|
||||
"WebGlTexture",
|
||||
"GpuBuffer"
|
||||
]
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = [ 'cfg(web_sys_unstable_apis)' ] }
|
||||
152
backends/web/_loader.js
Normal file
152
backends/web/_loader.js
Normal file
@@ -0,0 +1,152 @@
|
||||
const INIT_SYMBOL = Symbol('@ort-web.init');
|
||||
|
||||
const FEATURES_NONE = 0;
|
||||
const FEATURES_WEBGL = 1 << 0;
|
||||
const FEATURES_WEBGPU = 1 << 1;
|
||||
const FEATURES_ALL = FEATURES_WEBGL | FEATURES_WEBGPU;
|
||||
|
||||
/**
|
||||
* @typedef {Object} Dist
|
||||
* @property {string} baseUrl
|
||||
* @property {string} scriptName
|
||||
* @property {string | null} [binaryName]
|
||||
* @property {string | null} [wrapperName] defaults to `binaryName` s/\.wasm$/.mjs
|
||||
* @property {Record<'main' | 'wrapper' | 'binary', string> | null} integrities
|
||||
*/
|
||||
|
||||
const DEFAULT_DIST_BASE = 'https://cdn.pyke.io/0/pyke:ort-rs/web@1.23.0/';
|
||||
|
||||
/** @type {Record<number, Dist>} */
|
||||
const DEFAULT_DIST = {
|
||||
[FEATURES_NONE]: {
|
||||
baseUrl: DEFAULT_DIST_BASE,
|
||||
scriptName: 'ort.wasm.min.js',
|
||||
binaryName: 'ort-wasm-simd-threaded.wasm',
|
||||
integrities: {
|
||||
main: 'Uvpo3KshAzID7bmsY+Pz2/tiNWwl6Y5XeDTPpktDx73e0o/1TdssZDScTVHxpLYv',
|
||||
wrapper: 'Y/ZaWdP4FERyRvi+anEVDVDDhMJKldzf33TRb2MiCALo054swqCUe6aM/tD8XL6g',
|
||||
binary: '9UMXJFWi2zyn9PbGgXmJjEYM4hu8T8zmqmgxX6zQ08ZmNBOso3IT0cTp3M3oU7DU'
|
||||
}
|
||||
},
|
||||
[FEATURES_WEBGL]: {
|
||||
baseUrl: DEFAULT_DIST_BASE,
|
||||
scriptName: 'ort.webgl.min.js',
|
||||
binaryName: 'ort-wasm-simd-threaded.wasm',
|
||||
integrities: {
|
||||
main: 'pD9jsAlDhP5yhHaVikKM6mXw/E4HPB+4kc/rf3lrMctGWwT0XpIxiTdH/XDHR7Pr',
|
||||
wrapper: 'Y/ZaWdP4FERyRvi+anEVDVDDhMJKldzf33TRb2MiCALo054swqCUe6aM/tD8XL6g',
|
||||
binary: '9UMXJFWi2zyn9PbGgXmJjEYM4hu8T8zmqmgxX6zQ08ZmNBOso3IT0cTp3M3oU7DU'
|
||||
}
|
||||
},
|
||||
[FEATURES_WEBGPU]: {
|
||||
baseUrl: DEFAULT_DIST_BASE,
|
||||
scriptName: 'ort.webgpu.min.js',
|
||||
binaryName: 'ort-wasm-simd-threaded.jsep.wasm',
|
||||
integrities: {
|
||||
main: 'rY/SpyGuo298HuKPNCTIhlm3xc022++95XwJnuGVpKaW4yEzMTTDvgXoRQdiicvj',
|
||||
wrapper: 'Liv6LVoHkWBuJEPAGGmpzPGesXdc9YN5Eu0UaA9a9qChwB0H21V86UFBLhnIBieb',
|
||||
binary: 'jVPVL8reOtRz4+v3ZZAWg8bO5m7HGJr7tsMxmvNae28TztYbHZIk8JXHeZ/82yST'
|
||||
}
|
||||
},
|
||||
[FEATURES_ALL]: {
|
||||
baseUrl: DEFAULT_DIST_BASE,
|
||||
scriptName: 'ort.all.min.js',
|
||||
binaryName: 'ort-wasm-simd-threaded.jsep.wasm',
|
||||
integrities: {
|
||||
main: 'VVNyVdgdgHOM/8agRDy7rVx66N+/9T1vkYzwYtSS/u36YVzaln3cMtxt24ozySvr',
|
||||
wrapper: 'Liv6LVoHkWBuJEPAGGmpzPGesXdc9YN5Eu0UaA9a9qChwB0H21V86UFBLhnIBieb',
|
||||
binary: 'jVPVL8reOtRz4+v3ZZAWg8bO5m7HGJr7tsMxmvNae28TztYbHZIk8JXHeZ/82yST'
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {string} url
|
||||
* @param {'fetch' | 'script' | 'module'} as
|
||||
* @param {string} [type]
|
||||
* @param {string | null} [integrity]
|
||||
*/
|
||||
function preload(url, as, type, integrity) {
|
||||
const el = document.createElement('link');
|
||||
el.href = url;
|
||||
if (as !== 'module') {
|
||||
el.rel = 'preload';
|
||||
el.setAttribute('as', as);
|
||||
} else {
|
||||
el.rel = 'modulepreload';
|
||||
}
|
||||
if (type) {
|
||||
el.setAttribute('type', type);
|
||||
}
|
||||
if (integrity) {
|
||||
el.setAttribute('integrity', `sha384-${integrity}`);
|
||||
}
|
||||
el.setAttribute('crossorigin', 'anonymous');
|
||||
document.head.appendChild(el);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {number} features
|
||||
* @param {Dist} [dist]
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
export function initRuntime(features, dist) {
|
||||
if ('ort' in window && /** @type {any} */(window).ort[INIT_SYMBOL]) {
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
|
||||
if (!dist) {
|
||||
if (!(features in DEFAULT_DIST)) {
|
||||
return Promise.reject(new Error('Unsupported feature set'));
|
||||
}
|
||||
|
||||
dist = DEFAULT_DIST[features];
|
||||
}
|
||||
|
||||
/** @param {string} file */
|
||||
const relative = file => new URL(file, dist.baseUrl).toString();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
// since the order is load main script -> imports wrapper script -> fetches wasm, now would be a good time to
|
||||
// start fetching those
|
||||
if (dist.binaryName) {
|
||||
preload(
|
||||
relative(dist.binaryName),
|
||||
'fetch',
|
||||
'application/wasm',
|
||||
dist.integrities && dist.integrities.binary
|
||||
);
|
||||
preload(
|
||||
relative(dist.wrapperName || dist.binaryName.replace(/\.wasm$/, '.mjs')),
|
||||
'module',
|
||||
undefined,
|
||||
dist.integrities && dist.integrities.wrapper
|
||||
);
|
||||
}
|
||||
|
||||
const script = document.createElement('script');
|
||||
script.src = new URL(dist.scriptName, dist.baseUrl).toString();
|
||||
if (dist.integrities && dist.integrities.main) {
|
||||
script.setAttribute('integrity', `sha384-${dist.integrities && dist.integrities.main}`);
|
||||
}
|
||||
script.setAttribute('crossorigin', 'anonymous');
|
||||
script.addEventListener('load', () => {
|
||||
if (!('ort' in window)) {
|
||||
return reject(new Error('script loaded but ort not defined'));
|
||||
}
|
||||
|
||||
Object.defineProperty(window.ort, INIT_SYMBOL, {
|
||||
value: true,
|
||||
configurable: false,
|
||||
enumerable: false,
|
||||
writable: false
|
||||
});
|
||||
|
||||
resolve(true);
|
||||
});
|
||||
script.addEventListener('error', e => {
|
||||
reject(e.error);
|
||||
});
|
||||
document.head.appendChild(script);
|
||||
});
|
||||
}
|
||||
52
backends/web/_telemetry.js
Normal file
52
backends/web/_telemetry.js
Normal file
@@ -0,0 +1,52 @@
|
||||
const EVENT_URL = 'https://signal.pyke.io/beacon/9f5be487-d137-455a-9938-2fc7ecaa9de3/vVOv73JqP3iYRqXMBNm';
|
||||
|
||||
const IS_LOCALHOST = /^localhost$|^127(\.[0-9]+){0,2}\.[0-9]+$|^\[::1?\]$/;
|
||||
|
||||
/** @param {Uint8Array<ArrayBuffer>} payload */
|
||||
function track(payload) {
|
||||
if (IS_LOCALHOST.test(location.hostname) || location.protocol === 'file:') {
|
||||
return false;
|
||||
}
|
||||
if (navigator.webdriver || 'Cypress' in window) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return navigator.sendBeacon(EVENT_URL, payload.buffer);
|
||||
}
|
||||
|
||||
/** @param {Uint8Array<ArrayBuffer>[]} chunks */
|
||||
function concat(...chunks) {
|
||||
const concatenated = new Uint8Array(chunks.reduce((a, b) => a + b.byteLength, 0));
|
||||
let offset = 0;
|
||||
for (const chunk of chunks) {
|
||||
concatenated.set(chunk, offset);
|
||||
offset += chunk.byteLength;
|
||||
}
|
||||
return concatenated;
|
||||
}
|
||||
|
||||
/** @param {number} x */
|
||||
function asUint32(x) {
|
||||
const view = new DataView(new ArrayBuffer(4));
|
||||
view.setUint32(0, x, true);
|
||||
return new Uint8Array(view.buffer);
|
||||
}
|
||||
|
||||
const encoder = new TextEncoder();
|
||||
|
||||
let hasInitializedSession = false;
|
||||
export function trackSessionInit() {
|
||||
if (hasInitializedSession) {
|
||||
return true;
|
||||
}
|
||||
|
||||
hasInitializedSession = true;
|
||||
|
||||
const hostname = location.hostname;
|
||||
return track(concat(
|
||||
new Uint8Array([ 0x01 ]),
|
||||
new Uint8Array([ 0x90, 0x63, 0x8A, 0xE7 ]),
|
||||
asUint32(hostname.length),
|
||||
encoder.encode(hostname)
|
||||
));
|
||||
}
|
||||
662
backends/web/api.rs
Normal file
662
backends/web/api.rs
Normal file
@@ -0,0 +1,662 @@
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
ffi::CString,
|
||||
format,
|
||||
string::{String, ToString},
|
||||
vec::Vec
|
||||
};
|
||||
use core::{
|
||||
ffi::{self, CStr},
|
||||
future::Future,
|
||||
pin::Pin
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use ort_sys::{stub::Error, *};
|
||||
|
||||
use crate::{
|
||||
binding,
|
||||
env::{Environment, TelemetryEvent},
|
||||
memory::{Allocator, MemoryInfo},
|
||||
session::{RunOptions, Session, SessionOptions},
|
||||
tensor::{SyncDirection, Tensor, TensorData, TypeInfo, create_buffer, onnx_to_dtype},
|
||||
util::value_to_string
|
||||
};
|
||||
|
||||
unsafe extern "system" fn CreateEnv(_log_severity_level: OrtLoggingLevel, _logid: *const ffi::c_char, out: *mut *mut OrtEnv) -> OrtStatusPtr {
|
||||
unsafe { out.write(Environment::new_sys()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CreateEnvWithCustomLogger(
|
||||
_logging_function: OrtLoggingFunction,
|
||||
_logger_param: *mut ffi::c_void,
|
||||
_log_severity_level: OrtLoggingLevel,
|
||||
_logid: *const ffi::c_char,
|
||||
out: *mut *mut OrtEnv
|
||||
) -> OrtStatusPtr {
|
||||
unsafe { out.write(Environment::new_sys()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn EnableTelemetryEvents(env: *const OrtEnv) -> OrtStatusPtr {
|
||||
let env = unsafe { Environment::cast_from_sys_mut(env.cast_mut()) };
|
||||
env.with_telemetry = true;
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn DisableTelemetryEvents(env: *const OrtEnv) -> OrtStatusPtr {
|
||||
let env = unsafe { Environment::cast_from_sys_mut(env.cast_mut()) };
|
||||
env.with_telemetry = false;
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe fn CreateSession(
|
||||
env: *const OrtEnv,
|
||||
model_path: &str,
|
||||
options: *const OrtSessionOptions,
|
||||
out: *mut *mut OrtSession
|
||||
) -> Pin<Box<dyn Future<Output = OrtStatusPtr>>> {
|
||||
let options = unsafe { &*options.cast::<SessionOptions>() };
|
||||
|
||||
let fut = Box::pin(async move {
|
||||
match Session::from_url(model_path, options).await {
|
||||
Ok(session) => {
|
||||
let ptr = (Box::leak(Box::new(session))) as *mut Session;
|
||||
unsafe { out.write(ptr.cast()) };
|
||||
|
||||
{
|
||||
let env = unsafe { Environment::cast_from_sys(env) };
|
||||
env.send_telemetry_event(TelemetryEvent::SessionInit);
|
||||
}
|
||||
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
Err(e) => e.into_sys()
|
||||
}
|
||||
}) as Pin<Box<dyn Future<Output = OrtStatusPtr>>>;
|
||||
unsafe { core::mem::transmute(fut) }
|
||||
}
|
||||
|
||||
unsafe fn CreateSessionFromArray(
|
||||
env: *const OrtEnv,
|
||||
model_data: &[u8],
|
||||
options: *const OrtSessionOptions,
|
||||
out: *mut *mut OrtSession
|
||||
) -> Pin<Box<dyn Future<Output = OrtStatusPtr>>> {
|
||||
let options = unsafe { &*options.cast::<SessionOptions>() };
|
||||
|
||||
let fut = Box::pin(async move {
|
||||
match Session::from_bytes(model_data, options).await {
|
||||
Ok(session) => {
|
||||
let ptr = (Box::leak(Box::new(session))) as *mut Session;
|
||||
unsafe { out.write(ptr.cast()) };
|
||||
|
||||
{
|
||||
let env = unsafe { Environment::cast_from_sys(env) };
|
||||
env.send_telemetry_event(TelemetryEvent::SessionInit);
|
||||
}
|
||||
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
Err(e) => e.into_sys()
|
||||
}
|
||||
}) as Pin<Box<dyn Future<Output = OrtStatusPtr>>>;
|
||||
unsafe { core::mem::transmute(fut) }
|
||||
}
|
||||
|
||||
unsafe extern "system" fn Run(
|
||||
_session: *mut OrtSession,
|
||||
_run_options: *const OrtRunOptions,
|
||||
_input_names: *const *const ::core::ffi::c_char,
|
||||
_inputs: *const *const OrtValue,
|
||||
_input_len: usize,
|
||||
_output_names: *const *const ::core::ffi::c_char,
|
||||
_output_names_len: usize,
|
||||
_output_ptrs: *mut *mut OrtValue
|
||||
) -> OrtStatusPtr {
|
||||
Error::new_sys(OrtErrorCode::ORT_FAIL, "Synchronous `Session::run` is not supported in ort-web; use `run_async()`.")
|
||||
}
|
||||
|
||||
unsafe fn RunAsync(
|
||||
session: *mut OrtSession,
|
||||
_run_options: *const OrtRunOptions,
|
||||
input_names: &[&str],
|
||||
inputs: &[*const OrtValue],
|
||||
output_names: &[&str],
|
||||
output_ptrs: &mut [*mut OrtValue]
|
||||
) -> Pin<Box<dyn Future<Output = OrtStatusPtr>>> {
|
||||
let session = unsafe { &*session.cast::<Session>() };
|
||||
|
||||
let fut = Box::pin(async move {
|
||||
let inputs = input_names
|
||||
.iter()
|
||||
.zip(inputs)
|
||||
.map(|(&name, &input)| (name, unsafe { &*input.cast::<Tensor>() }))
|
||||
.collect::<Vec<(&str, &Tensor)>>();
|
||||
|
||||
match session.js.run(inputs.into_iter()).await {
|
||||
Ok(outputs) => {
|
||||
let output_names: Vec<String> = output_names.iter().map(|&name| name.to_string()).collect();
|
||||
let output_view = unsafe { core::slice::from_raw_parts_mut(output_ptrs.as_mut_ptr().cast::<*mut Tensor>(), output_ptrs.len()) };
|
||||
|
||||
for (name, mut tensor) in outputs {
|
||||
if let Some(index) = output_names
|
||||
.iter()
|
||||
.zip(output_view.iter_mut())
|
||||
.find_map(|(o_name, output)| if name == *o_name { Some(output) } else { None })
|
||||
{
|
||||
if !session.disable_sync {
|
||||
if let Err(e) = tensor.sync(SyncDirection::Rust).await {
|
||||
return Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to synchronize output '{name}': {e}"));
|
||||
}
|
||||
}
|
||||
|
||||
*index = Box::leak(Box::new(tensor));
|
||||
}
|
||||
}
|
||||
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
Err(e) => Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to run session: {}", value_to_string(&e)))
|
||||
}
|
||||
}) as Pin<Box<dyn Future<Output = OrtStatusPtr>>>;
|
||||
unsafe { core::mem::transmute(fut) }
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CreateSessionOptions(options: *mut *mut OrtSessionOptions) -> OrtStatusPtr {
|
||||
unsafe { options.write((Box::leak(Box::new(SessionOptions::new())) as *mut SessionOptions).cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SessionOptionsAppendExecutionProvider(
|
||||
options: *mut OrtSessionOptions,
|
||||
provider_name: *const ::core::ffi::c_char,
|
||||
provider_options_keys: *const *const ::core::ffi::c_char,
|
||||
provider_options_values: *const *const ::core::ffi::c_char,
|
||||
num_keys: usize
|
||||
) -> OrtStatusPtr {
|
||||
let options = unsafe { &mut *options.cast::<SessionOptions>() };
|
||||
let execution_providers = options.js.execution_providers.get_or_insert_default();
|
||||
|
||||
let Ok(options) = unsafe { core::slice::from_raw_parts(provider_options_keys, num_keys) }
|
||||
.iter()
|
||||
.zip(unsafe { core::slice::from_raw_parts(provider_options_values, num_keys) }.iter())
|
||||
.map(|(k, v)| Ok((unsafe { CStr::from_ptr(*k) }.to_str()?, unsafe { CStr::from_ptr(*v) }.to_str()?)))
|
||||
.collect::<Result<HashMap<&str, &str>, core::str::Utf8Error>>()
|
||||
else {
|
||||
return Error::new_sys(OrtErrorCode::ORT_FAIL, "EP options contains invalid UTF-8");
|
||||
};
|
||||
|
||||
let provider_name = unsafe { CStr::from_ptr(provider_name) };
|
||||
match provider_name.to_string_lossy().as_ref() {
|
||||
"WASM" => {
|
||||
execution_providers.push(binding::ExecutionProvider::WASM);
|
||||
}
|
||||
"WebGL" => {
|
||||
execution_providers.push(binding::ExecutionProvider::WebGL);
|
||||
}
|
||||
"WebGPU" => {
|
||||
execution_providers.push(binding::ExecutionProvider::WebGPU {
|
||||
preferred_layout: match options.get("ep.webgpuexecutionprovider.preferredLayout") {
|
||||
Some(&"NHWC") => Some(binding::WebGPUPreferredLayout::NHWC),
|
||||
Some(&"NCHW") => Some(binding::WebGPUPreferredLayout::NCHW),
|
||||
_ => None
|
||||
}
|
||||
});
|
||||
}
|
||||
"WebNN" => {
|
||||
execution_providers.push(binding::ExecutionProvider::WebNN {
|
||||
power_preference: match options.get("powerPreference") {
|
||||
Some(&"default") => Some(binding::WebNNPowerPreference::Default),
|
||||
Some(&"high-performance") => Some(binding::WebNNPowerPreference::HighPerformance),
|
||||
Some(&"low-power") => Some(binding::WebNNPowerPreference::LowPower),
|
||||
_ => None
|
||||
},
|
||||
device_type: match options.get("deviceType") {
|
||||
Some(&"cpu") => Some(binding::WebNNDeviceType::CPU),
|
||||
Some(&"npu") => Some(binding::WebNNDeviceType::NPU),
|
||||
Some(&"gpu") => Some(binding::WebNNDeviceType::GPU),
|
||||
_ => None
|
||||
},
|
||||
num_threads: options.get("numThreads").and_then(|c| c.parse().ok())
|
||||
});
|
||||
}
|
||||
x => return Error::new_sys(OrtErrorCode::ORT_NOT_IMPLEMENTED, format!("Provider '{x}' not supported"))
|
||||
}
|
||||
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CloneSessionOptions(in_options: *const OrtSessionOptions, out_options: *mut *mut OrtSessionOptions) -> OrtStatusPtr {
|
||||
let options = unsafe { &*in_options.cast::<SessionOptions>() };
|
||||
unsafe { out_options.write((Box::leak(Box::new(options.clone())) as *mut SessionOptions).cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SessionGetInputCount(session: *const OrtSession, out: *mut usize) -> OrtStatusPtr {
|
||||
let session = unsafe { &*session.cast::<Session>() };
|
||||
unsafe { out.write(session.js.input_len()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SessionGetOutputCount(session: *const OrtSession, out: *mut usize) -> OrtStatusPtr {
|
||||
let session = unsafe { &*session.cast::<Session>() };
|
||||
unsafe { out.write(session.js.output_len()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SessionGetOverridableInitializerCount(_session: *const OrtSession, out: *mut usize) -> OrtStatusPtr {
|
||||
unsafe { out.write(0) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SessionGetInputTypeInfo(session: *const OrtSession, index: usize, type_info: *mut *mut OrtTypeInfo) -> OrtStatusPtr {
|
||||
let session = unsafe { &*session.cast::<Session>() };
|
||||
let metadata = session.js.input_metadata().remove(index);
|
||||
if !metadata.is_tensor {
|
||||
return Error::new_sys(OrtErrorCode::ORT_FAIL, "non-tensor types are not currently supported");
|
||||
}
|
||||
|
||||
unsafe { type_info.write(TypeInfo::new_sys_from_value_metadata(&metadata)) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SessionGetOutputTypeInfo(session: *const OrtSession, index: usize, type_info: *mut *mut OrtTypeInfo) -> OrtStatusPtr {
|
||||
let session = unsafe { &*session.cast::<Session>() };
|
||||
let metadata = session.js.output_metadata().remove(index);
|
||||
if !metadata.is_tensor {
|
||||
return Error::new_sys(OrtErrorCode::ORT_FAIL, "non-tensor types are not currently supported");
|
||||
}
|
||||
|
||||
unsafe { type_info.write(TypeInfo::new_sys_from_value_metadata(&metadata)) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SessionGetInputName(
|
||||
session: *const OrtSession,
|
||||
index: usize,
|
||||
_allocator: *mut OrtAllocator,
|
||||
value: *mut *mut ffi::c_char
|
||||
) -> OrtStatusPtr {
|
||||
let session = unsafe { &*session.cast::<Session>() };
|
||||
let name = CString::new(&*session.js.input_names().remove(index)).unwrap();
|
||||
unsafe { value.write(name.into_raw()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SessionGetOutputName(
|
||||
session: *const OrtSession,
|
||||
index: usize,
|
||||
_allocator: *mut OrtAllocator,
|
||||
value: *mut *mut ffi::c_char
|
||||
) -> OrtStatusPtr {
|
||||
let session = unsafe { &*session.cast::<Session>() };
|
||||
let name = CString::new(&*session.js.output_names().remove(index)).unwrap();
|
||||
unsafe { value.write(name.into_raw()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CreateRunOptions(out: *mut *mut OrtRunOptions) -> OrtStatusPtr {
|
||||
unsafe { out.write((Box::leak(Box::new(RunOptions::new())) as *mut RunOptions).cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CreateTensorAsOrtValue(
|
||||
_allocator: *mut OrtAllocator,
|
||||
shape: *const i64,
|
||||
shape_len: usize,
|
||||
type_: ONNXTensorElementDataType,
|
||||
out: *mut *mut OrtValue
|
||||
) -> OrtStatusPtr {
|
||||
let shape = unsafe { core::slice::from_raw_parts(shape, shape_len) }
|
||||
.iter()
|
||||
.map(|c| *c as i32)
|
||||
.collect::<Vec<_>>();
|
||||
let Some(dtype) = onnx_to_dtype(type_) else {
|
||||
return Error::new_sys(OrtErrorCode::ORT_FAIL, "unsupported dtype");
|
||||
};
|
||||
|
||||
match binding::Tensor::new_from_buffer(dtype, create_buffer(dtype, &shape), &shape) {
|
||||
Ok(tensor) => {
|
||||
unsafe { out.write((Box::leak(Box::new(Tensor::from_tensor(tensor))) as *mut Tensor).cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
Err(e) => Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to create tensor: {}", value_to_string(&e)))
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CreateTensorWithDataAsOrtValue(
|
||||
_info: *const OrtMemoryInfo,
|
||||
p_data: *mut ffi::c_void,
|
||||
p_data_len: usize,
|
||||
shape: *const i64,
|
||||
shape_len: usize,
|
||||
type_: ONNXTensorElementDataType,
|
||||
out: *mut *mut OrtValue
|
||||
) -> OrtStatusPtr {
|
||||
let shape = unsafe { core::slice::from_raw_parts(shape, shape_len) }
|
||||
.iter()
|
||||
.map(|c| *c as i32)
|
||||
.collect::<Vec<_>>();
|
||||
let Some(dtype) = onnx_to_dtype(type_) else {
|
||||
return Error::new_sys(OrtErrorCode::ORT_FAIL, "unsupported dtype");
|
||||
};
|
||||
|
||||
match unsafe { Tensor::from_ptr(dtype, p_data, p_data_len, &shape) } {
|
||||
Ok(tensor) => {
|
||||
unsafe { out.write((Box::leak(Box::new(tensor)) as *mut Tensor).cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
Err(e) => Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to create tensor: {}", value_to_string(&e)))
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "system" fn IsTensor(_value: *const OrtValue, out: *mut ffi::c_int) -> OrtStatusPtr {
|
||||
unsafe { out.write(1) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetTensorMutableData(value: *mut OrtValue, out: *mut *mut ffi::c_void) -> OrtStatusPtr {
|
||||
let tensor = unsafe { &mut *value.cast::<Tensor>() };
|
||||
match &mut tensor.data {
|
||||
TensorData::RustView { ptr, .. } => {
|
||||
unsafe { out.write(*ptr) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
TensorData::External { buffer } => {
|
||||
if let Some(buffer) = buffer {
|
||||
unsafe { out.write(buffer.as_mut_ptr().cast()) };
|
||||
OrtStatusPtr::default()
|
||||
} else {
|
||||
Error::new_sys(OrtErrorCode::ORT_FAIL, "External data is not synchronized; you should call `TensorExt::sync`.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CastTypeInfoToTensorInfo(type_info: *const OrtTypeInfo, out: *mut *const OrtTensorTypeAndShapeInfo) -> OrtStatusPtr {
|
||||
unsafe { out.write(type_info.cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetOnnxTypeFromTypeInfo(_type_info: *const OrtTypeInfo, out: *mut ONNXType) -> OrtStatusPtr {
|
||||
unsafe { out.write(ONNXType::ONNX_TYPE_TENSOR) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CreateTensorTypeAndShapeInfo(out: *mut *mut OrtTensorTypeAndShapeInfo) -> OrtStatusPtr {
|
||||
unsafe { out.write(TypeInfo::new_sys(binding::DataType::Float32, Vec::new()).cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SetTensorElementType(info: *mut OrtTensorTypeAndShapeInfo, type_: ONNXTensorElementDataType) -> OrtStatusPtr {
|
||||
let info = unsafe { &mut *info.cast::<TypeInfo>() };
|
||||
match onnx_to_dtype(type_) {
|
||||
Some(_) => {
|
||||
info.dtype = type_;
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
None => Error::new_sys(OrtErrorCode::ORT_FAIL, "Unsupported tensor data type")
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "system" fn SetDimensions(info: *mut OrtTensorTypeAndShapeInfo, dim_values: *const i64, dim_count: usize) -> OrtStatusPtr {
|
||||
let info = unsafe { &mut *info.cast::<TypeInfo>() };
|
||||
info.shape = unsafe { core::slice::from_raw_parts(dim_values.cast(), dim_count) }.to_vec();
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetTensorElementType(info: *const OrtTensorTypeAndShapeInfo, out: *mut ONNXTensorElementDataType) -> OrtStatusPtr {
|
||||
let info = unsafe { &*info.cast::<TypeInfo>() };
|
||||
unsafe { out.write(info.dtype) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetDimensionsCount(info: *const OrtTensorTypeAndShapeInfo, out: *mut usize) -> OrtStatusPtr {
|
||||
let info = unsafe { &*info.cast::<TypeInfo>() };
|
||||
unsafe { out.write(info.shape.len()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetDimensions(info: *const OrtTensorTypeAndShapeInfo, dim_values: *mut i64, dim_values_length: usize) -> OrtStatusPtr {
|
||||
let info = unsafe { &*info.cast::<TypeInfo>() };
|
||||
for (i, dim) in info.shape.iter().enumerate().take(dim_values_length) {
|
||||
unsafe { dim_values.add(i).write(*dim as _) };
|
||||
}
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetSymbolicDimensions(
|
||||
_info: *const OrtTensorTypeAndShapeInfo,
|
||||
dim_params: *mut *const ffi::c_char,
|
||||
dim_params_length: usize
|
||||
) -> OrtStatusPtr {
|
||||
for i in 0..dim_params_length {
|
||||
unsafe { dim_params.add(i).write(c"".as_ptr()) };
|
||||
}
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetTensorShapeElementCount(info: *const OrtTensorTypeAndShapeInfo, out: *mut usize) -> OrtStatusPtr {
|
||||
let info = unsafe { &*info.cast::<TypeInfo>() };
|
||||
let mut size = 1usize;
|
||||
for dim in &info.shape {
|
||||
size *= *dim as usize;
|
||||
}
|
||||
unsafe { out.write(size) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetTensorTypeAndShape(value: *const OrtValue, out: *mut *mut OrtTensorTypeAndShapeInfo) -> OrtStatusPtr {
|
||||
let tensor = unsafe { &*value.cast::<Tensor>() };
|
||||
unsafe { out.write(TypeInfo::new_sys_from_tensor(tensor).cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetTypeInfo(value: *const OrtValue, out: *mut *mut OrtTypeInfo) -> OrtStatusPtr {
|
||||
let tensor = unsafe { &*value.cast::<Tensor>() };
|
||||
unsafe { out.write(TypeInfo::new_sys_from_tensor(tensor)) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetValueType(_value: *const OrtValue, out: *mut ONNXType) -> OrtStatusPtr {
|
||||
unsafe { out.write(ONNXType::ONNX_TYPE_TENSOR) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CreateMemoryInfo(
|
||||
name: *const ffi::c_char,
|
||||
_type: OrtAllocatorType,
|
||||
_id: ffi::c_int,
|
||||
_mem_type: OrtMemType,
|
||||
out: *mut *mut OrtMemoryInfo
|
||||
) -> OrtStatusPtr {
|
||||
let device_name = unsafe { CStr::from_ptr(name) };
|
||||
match MemoryInfo::from_location(&*device_name.to_string_lossy()) {
|
||||
Some(inf) => {
|
||||
unsafe { *out = (Box::leak(Box::new(inf)) as *mut MemoryInfo).cast() };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
None => Error::new_sys(
|
||||
OrtErrorCode::ORT_FAIL,
|
||||
"Unsupported MemoryInfo type - only CPU tensors can be created this way. Tensors must be created from existing non-CPU buffers using `ort_web::TensorExt::from_*`."
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CreateCpuMemoryInfo(_type: OrtAllocatorType, _mem_type: OrtMemType, out: *mut *mut OrtMemoryInfo) -> OrtStatusPtr {
|
||||
unsafe { *out = (Box::leak(Box::new(MemoryInfo { location: binding::DataLocation::Cpu })) as *mut MemoryInfo).cast() };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CompareMemoryInfo(info1: *const OrtMemoryInfo, info2: *const OrtMemoryInfo, out: *mut ffi::c_int) -> OrtStatusPtr {
|
||||
let info1 = unsafe { &*info1.cast::<MemoryInfo>() };
|
||||
let info2 = unsafe { &*info2.cast::<MemoryInfo>() };
|
||||
unsafe { out.write(if info1 == info2 { 0 } else { -1 }) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn MemoryInfoGetName(ptr: *const OrtMemoryInfo, out: *mut *const ffi::c_char) -> OrtStatusPtr {
|
||||
let info = unsafe { &*ptr.cast::<MemoryInfo>() };
|
||||
unsafe { out.write(info.location_exposed().unwrap_or(c"").as_ptr().cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn MemoryInfoGetId(_ptr: *const OrtMemoryInfo, out: *mut ffi::c_int) -> OrtStatusPtr {
|
||||
unsafe { out.write(0) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn MemoryInfoGetMemType(_ptr: *const OrtMemoryInfo, out: *mut OrtMemType) -> OrtStatusPtr {
|
||||
unsafe { out.write(OrtMemType::OrtMemTypeDefault) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn MemoryInfoGetType(_ptr: *const OrtMemoryInfo, out: *mut OrtAllocatorType) -> OrtStatusPtr {
|
||||
unsafe { out.write(OrtAllocatorType::OrtDeviceAllocator) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetAllocatorWithDefaultOptions(out: *mut *mut OrtAllocator) -> OrtStatusPtr {
|
||||
unsafe { out.write((&crate::memory::DEFAULT_CPU_ALLOCATOR as *const Allocator).cast_mut().cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseEnv(input: *mut OrtEnv) {
|
||||
drop(unsafe { Environment::consume_sys(input) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseStatus(input: *mut OrtStatus) {
|
||||
drop(unsafe { Error::consume_sys(input) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseMemoryInfo(input: *mut OrtMemoryInfo) {
|
||||
drop(unsafe { Box::<MemoryInfo>::from_raw(input.cast()) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseSession(input: *mut OrtSession) {
|
||||
drop(unsafe { Box::<Session>::from_raw(input.cast()) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseValue(input: *mut OrtValue) {
|
||||
drop(unsafe { Box::<Tensor>::from_raw(input.cast()) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseRunOptions(input: *mut OrtRunOptions) {
|
||||
drop(unsafe { Box::<RunOptions>::from_raw(input.cast()) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseTypeInfo(input: *mut OrtTypeInfo) {
|
||||
drop(unsafe { TypeInfo::consume_sys(input) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseTensorTypeAndShapeInfo(input: *mut OrtTensorTypeAndShapeInfo) {
|
||||
drop(unsafe { TypeInfo::consume_sys(input.cast()) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseSessionOptions(input: *mut OrtSessionOptions) {
|
||||
drop(unsafe { Box::from_raw(input.cast::<SessionOptions>()) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn CreateAllocator(_session: *const OrtSession, mem_info: *const OrtMemoryInfo, out: *mut *mut OrtAllocator) -> OrtStatusPtr {
|
||||
let mem_info = unsafe { &*mem_info.cast::<MemoryInfo>() };
|
||||
if mem_info.location != binding::DataLocation::Cpu {
|
||||
return Error::new_sys(OrtErrorCode::ORT_INVALID_ARGUMENT, "Only CPU allocators are supported.");
|
||||
}
|
||||
|
||||
unsafe { out.write((Box::leak(Box::new(Allocator::new())) as *mut Allocator).cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ReleaseAllocator(input: *mut OrtAllocator) {
|
||||
drop(unsafe { Box::from_raw(input.cast::<Allocator>()) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetTensorMemoryInfo(value: *const OrtValue, mem_info: *mut *const OrtMemoryInfo) -> OrtStatusPtr {
|
||||
let tensor = unsafe { &*value.cast::<Tensor>() };
|
||||
unsafe { mem_info.write((&tensor.memory_info as *const MemoryInfo).cast()) };
|
||||
OrtStatusPtr::default()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn MemoryInfoGetDeviceType(ptr: *const OrtMemoryInfo, out: *mut OrtMemoryInfoDeviceType) {
|
||||
let memory_info = unsafe { &*ptr.cast::<MemoryInfo>() };
|
||||
unsafe {
|
||||
out.write(match memory_info.location {
|
||||
binding::DataLocation::Cpu | binding::DataLocation::CpuPinned => OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU,
|
||||
_ => OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU
|
||||
})
|
||||
};
|
||||
}
|
||||
|
||||
unsafe extern "system" fn GetBuildInfoString() -> *const ffi::c_char {
|
||||
concat!("ORT Build Info: backend=ort-web, version=", env!("CARGO_PKG_VERSION"), ", with <3\0")
|
||||
.as_ptr()
|
||||
.cast()
|
||||
}
|
||||
|
||||
pub const fn api() -> OrtApi {
|
||||
OrtApi {
|
||||
CreateEnv,
|
||||
CreateEnvWithCustomLogger,
|
||||
EnableTelemetryEvents,
|
||||
DisableTelemetryEvents,
|
||||
CreateSession,
|
||||
CreateSessionFromArray,
|
||||
Run,
|
||||
RunAsync,
|
||||
CreateSessionOptions,
|
||||
CloneSessionOptions,
|
||||
SessionGetInputCount,
|
||||
SessionGetOutputCount,
|
||||
SessionGetOverridableInitializerCount,
|
||||
SessionGetInputTypeInfo,
|
||||
SessionGetOutputTypeInfo,
|
||||
SessionGetInputName,
|
||||
SessionGetOutputName,
|
||||
CreateTensorAsOrtValue,
|
||||
CreateTensorWithDataAsOrtValue,
|
||||
IsTensor,
|
||||
GetTensorMutableData,
|
||||
CastTypeInfoToTensorInfo,
|
||||
GetOnnxTypeFromTypeInfo,
|
||||
CreateTensorTypeAndShapeInfo,
|
||||
SetTensorElementType,
|
||||
SetDimensions,
|
||||
GetTensorElementType,
|
||||
GetDimensionsCount,
|
||||
GetDimensions,
|
||||
GetSymbolicDimensions,
|
||||
GetTensorShapeElementCount,
|
||||
GetTensorTypeAndShape,
|
||||
GetTypeInfo,
|
||||
GetValueType,
|
||||
CreateMemoryInfo,
|
||||
CreateCpuMemoryInfo,
|
||||
CompareMemoryInfo,
|
||||
MemoryInfoGetName,
|
||||
MemoryInfoGetId,
|
||||
MemoryInfoGetMemType,
|
||||
MemoryInfoGetType,
|
||||
GetAllocatorWithDefaultOptions,
|
||||
ReleaseEnv,
|
||||
ReleaseStatus,
|
||||
ReleaseMemoryInfo,
|
||||
ReleaseSession,
|
||||
ReleaseValue,
|
||||
ReleaseTypeInfo,
|
||||
ReleaseTensorTypeAndShapeInfo,
|
||||
ReleaseSessionOptions,
|
||||
CreateAllocator,
|
||||
ReleaseAllocator,
|
||||
GetTensorMemoryInfo,
|
||||
MemoryInfoGetDeviceType,
|
||||
GetBuildInfoString,
|
||||
CreateRunOptions,
|
||||
ReleaseRunOptions,
|
||||
SessionOptionsAppendExecutionProvider,
|
||||
..ort_sys::stub::api()
|
||||
}
|
||||
}
|
||||
41
backends/web/binding/mod.rs
Normal file
41
backends/web/binding/mod.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use js_sys::Boolean;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
mod session;
|
||||
pub use self::session::*;
|
||||
mod tensor;
|
||||
pub use self::tensor::*;
|
||||
|
||||
#[wasm_bindgen]
|
||||
#[derive(Deserialize, Serialize, Debug, Clone, Copy)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum DataType {
|
||||
Bool = "bool",
|
||||
Float16 = "float16",
|
||||
Float32 = "float32",
|
||||
Float64 = "float64",
|
||||
Int4 = "int4",
|
||||
Int8 = "int8",
|
||||
Int16 = "int16",
|
||||
Int32 = "int32",
|
||||
Int64 = "int64",
|
||||
Uint4 = "uint4",
|
||||
Uint8 = "uint8",
|
||||
Uint16 = "uint16",
|
||||
Uint32 = "uint32",
|
||||
Uint64 = "uint64",
|
||||
String = "string"
|
||||
}
|
||||
|
||||
#[wasm_bindgen(module = "/_loader.js")]
|
||||
extern "C" {
|
||||
#[wasm_bindgen(catch, js_name = "initRuntime")]
|
||||
pub async fn init_runtime(features: u8, dist: JsValue) -> Result<Boolean, JsValue>;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(module = "/_telemetry.js")]
|
||||
extern "C" {
|
||||
#[wasm_bindgen(catch, js_name = "trackSessionInit")]
|
||||
pub fn track_session_init() -> Result<Boolean, JsValue>;
|
||||
}
|
||||
224
backends/web/binding/session.rs
Normal file
224
backends/web/binding/session.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
use alloc::{string::String, vec::Vec};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use js_sys::{JsString, Object, Reflect, Uint8Array};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
use crate::{binding::DataType, tensor::Tensor};
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Copy)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ExecutionMode {
|
||||
Sequential,
|
||||
Parallel
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Copy)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum GraphOptimizationLevel {
|
||||
Disabled,
|
||||
Basic,
|
||||
Layout,
|
||||
Extended,
|
||||
All
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum WebNNDeviceType {
|
||||
CPU,
|
||||
GPU,
|
||||
NPU
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum WebNNPowerPreference {
|
||||
Default,
|
||||
HighPerformance,
|
||||
LowPower
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum WebGPUPreferredLayout {
|
||||
NHWC,
|
||||
NCHW
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
#[serde(tag = "name", rename_all = "lowercase")]
|
||||
pub enum ExecutionProvider {
|
||||
WASM,
|
||||
WebGL,
|
||||
#[serde(rename_all = "camelCase")]
|
||||
WebNN {
|
||||
device_type: Option<WebNNDeviceType>,
|
||||
num_threads: Option<u32>,
|
||||
power_preference: Option<WebNNPowerPreference>
|
||||
},
|
||||
#[serde(rename_all = "camelCase")]
|
||||
WebGPU {
|
||||
preferred_layout: Option<WebGPUPreferredLayout>
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Default, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionOptions {
|
||||
pub enable_cpu_mem_arena: Option<bool>,
|
||||
pub enable_graph_capture: Option<bool>,
|
||||
pub enable_mem_pattern: Option<bool>,
|
||||
pub enable_profiling: Option<bool>,
|
||||
pub execution_mode: Option<ExecutionMode>,
|
||||
pub execution_providers: Option<Vec<ExecutionProvider>>,
|
||||
pub extra: Option<HashMap<String, String>>,
|
||||
pub free_dimension_override: Option<HashMap<String, i32>>,
|
||||
pub graph_optimization_level: Option<GraphOptimizationLevel>,
|
||||
pub inter_op_num_threads: Option<u32>,
|
||||
pub intra_op_num_threads: Option<u32>,
|
||||
pub log_id: Option<String>,
|
||||
pub log_severity_level: Option<u8>,
|
||||
pub log_verbosity_level: Option<u16>
|
||||
}
|
||||
|
||||
impl SessionOptions {
|
||||
pub(crate) fn to_value(&self) -> Result<JsValue, serde_wasm_bindgen::Error> {
|
||||
self.serialize(&serde_wasm_bindgen::Serializer::new().serialize_maps_as_objects(true))
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
extern "C" {
|
||||
#[wasm_bindgen(js_namespace = ort)]
|
||||
pub type InferenceSession;
|
||||
|
||||
#[wasm_bindgen(catch, js_namespace = ort, static_method_of = InferenceSession, js_name = create)]
|
||||
async fn create_from_uri_raw(uri: &str, options: JsValue) -> Result<InferenceSession, JsValue>;
|
||||
#[wasm_bindgen(catch, js_namespace = ort, static_method_of = InferenceSession, js_name = create)]
|
||||
async fn create_from_bytes_raw(buffer: &Uint8Array, options: JsValue) -> Result<InferenceSession, JsValue>;
|
||||
|
||||
#[wasm_bindgen(catch, structural, method, js_name = startProfiling)]
|
||||
pub fn start_profiling(this: &InferenceSession) -> Result<(), JsValue>;
|
||||
#[wasm_bindgen(catch, structural, method, js_name = endProfiling)]
|
||||
pub fn end_profiling(this: &InferenceSession) -> Result<(), JsValue>;
|
||||
#[wasm_bindgen(catch, structural, method, js_name = release)]
|
||||
pub async fn release(this: &InferenceSession) -> Result<(), JsValue>;
|
||||
|
||||
#[wasm_bindgen(structural, method, getter, js_name = inputMetadata)]
|
||||
fn input_metadata_raw(this: &InferenceSession) -> Vec<JsValue>;
|
||||
#[wasm_bindgen(structural, method, getter, js_name = outputMetadata)]
|
||||
fn output_metadata_raw(this: &InferenceSession) -> Vec<JsValue>;
|
||||
#[wasm_bindgen(structural, method, getter, js_name = inputNames)]
|
||||
fn input_names_raw(this: &InferenceSession) -> Vec<JsString>;
|
||||
#[wasm_bindgen(structural, method, getter, js_name = outputNames)]
|
||||
fn output_names_raw(this: &InferenceSession) -> Vec<JsString>;
|
||||
|
||||
#[wasm_bindgen(catch, structural, method, js_name = run)]
|
||||
async fn run_raw(this: &InferenceSession, feeds: JsValue) -> Result<JsValue, JsValue>;
|
||||
#[wasm_bindgen(catch, structural, method, js_name = run)]
|
||||
async fn run_with_fetches_raw(this: &InferenceSession, feeds: JsValue, fetches: JsValue) -> Result<JsValue, JsValue>;
|
||||
}
|
||||
|
||||
impl InferenceSession {
|
||||
pub async fn create_from_uri(uri: &str, options: &SessionOptions) -> Result<InferenceSession, JsValue> {
|
||||
InferenceSession::create_from_uri_raw(uri, options.to_value()?).await
|
||||
}
|
||||
pub async fn create_from_bytes(buffer: &Uint8Array, options: &SessionOptions) -> Result<InferenceSession, JsValue> {
|
||||
InferenceSession::create_from_bytes_raw(buffer, options.to_value()?).await
|
||||
}
|
||||
|
||||
pub fn input_names(&self) -> Vec<String> {
|
||||
self.input_names_raw().into_iter().map(String::from).collect()
|
||||
}
|
||||
pub fn output_names(&self) -> Vec<String> {
|
||||
self.output_names_raw().into_iter().map(String::from).collect()
|
||||
}
|
||||
|
||||
pub fn input_len(&self) -> usize {
|
||||
self.input_names_raw().len()
|
||||
}
|
||||
pub fn output_len(&self) -> usize {
|
||||
self.output_names_raw().len()
|
||||
}
|
||||
|
||||
pub fn input_metadata(&self) -> Vec<ValueMetadata> {
|
||||
self.input_metadata_raw()
|
||||
.into_iter()
|
||||
.map(|x| serde_wasm_bindgen::from_value(x))
|
||||
.collect::<Result<Vec<_>, serde_wasm_bindgen::Error>>()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn output_metadata(&self) -> Vec<ValueMetadata> {
|
||||
self.output_metadata_raw()
|
||||
.into_iter()
|
||||
.map(|x| serde_wasm_bindgen::from_value(x))
|
||||
.collect::<Result<Vec<_>, serde_wasm_bindgen::Error>>()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn run(&self, feeds: impl Iterator<Item = (&str, &Tensor)>) -> Result<Vec<(String, Tensor)>, JsValue> {
|
||||
let feeds_value = Object::new();
|
||||
for (name, tensor) in feeds {
|
||||
Reflect::set(&feeds_value, &JsValue::from_str(name), &tensor.js)?;
|
||||
}
|
||||
Self::to_outputs(self.run_raw(feeds_value.into()).await?)
|
||||
}
|
||||
|
||||
pub async fn run_with_fetches(
|
||||
&self,
|
||||
feeds: impl Iterator<Item = (&str, &Tensor)>,
|
||||
fetches: impl Iterator<Item = (&str, Option<&Tensor>)>
|
||||
) -> Result<Vec<(String, Tensor)>, JsValue> {
|
||||
let feeds_value = Object::new();
|
||||
for (name, tensor) in feeds {
|
||||
Reflect::set(&feeds_value, &JsValue::from_str(name), &tensor.js)?;
|
||||
}
|
||||
let fetches_value = Object::new();
|
||||
for (name, tensor) in fetches {
|
||||
let null = JsValue::null();
|
||||
Reflect::set(
|
||||
&fetches_value,
|
||||
&JsValue::from_str(name),
|
||||
match tensor {
|
||||
Some(tensor) => &tensor.js,
|
||||
None => &null
|
||||
}
|
||||
)?;
|
||||
}
|
||||
Self::to_outputs(self.run_with_fetches_raw(feeds_value.into(), fetches_value.into()).await?)
|
||||
}
|
||||
|
||||
fn to_outputs(value: JsValue) -> Result<Vec<(String, Tensor)>, JsValue> {
|
||||
Ok(Reflect::own_keys(&value)?
|
||||
.to_vec()
|
||||
.into_iter()
|
||||
.filter_map(|c| {
|
||||
c.dyn_ref::<JsString>().map(String::from).and_then(|k| {
|
||||
Reflect::get(&value, &c)
|
||||
.map(super::Tensor::unchecked_from_js)
|
||||
.ok()
|
||||
.map(|v| (k, Tensor::from_tensor(v)))
|
||||
})
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum ShapeElement {
|
||||
Named(String),
|
||||
Value(i32)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ValueMetadata {
|
||||
pub is_tensor: bool,
|
||||
pub name: String,
|
||||
pub shape: Option<Vec<ShapeElement>>,
|
||||
pub r#type: Option<DataType>
|
||||
}
|
||||
235
backends/web/binding/tensor.rs
Normal file
235
backends/web/binding/tensor.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
use alloc::{string::ToString, vec::Vec};
|
||||
|
||||
use js_sys::JsString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
use web_sys::{HtmlImageElement, ImageBitmap, ImageData, WebGlTexture};
|
||||
|
||||
use crate::binding::DataType;
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Copy)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum ImageFormat {
|
||||
Rgb,
|
||||
Rgba,
|
||||
Bgr,
|
||||
Rbg
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Copy)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum ImageTensorLayout {
|
||||
Nhwc,
|
||||
Nchw
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Copy)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ImageDataType {
|
||||
Float32,
|
||||
Uint8
|
||||
}
|
||||
|
||||
impl Into<DataType> for ImageDataType {
|
||||
fn into(self) -> DataType {
|
||||
match self {
|
||||
Self::Float32 => DataType::Float32,
|
||||
Self::Uint8 => DataType::Uint8
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ImageNormOption {
|
||||
Splat(f32),
|
||||
PerChannel([f32; 3]),
|
||||
PerChannelWithAlpha([f32; 4])
|
||||
}
|
||||
|
||||
#[derive(Serialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ImageNorm {
|
||||
pub bias: Option<ImageNormOption>,
|
||||
pub mean: Option<ImageNormOption>
|
||||
}
|
||||
|
||||
impl ImageNorm {
|
||||
pub const fn imagenet(format: ImageFormat) -> ImageNorm {
|
||||
const RGB_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
|
||||
const RGB_STD: [f32; 3] = [0.229, 0.224, 0.225];
|
||||
ImageNorm {
|
||||
mean: Some(match format {
|
||||
ImageFormat::Rgb => ImageNormOption::PerChannel(RGB_MEAN),
|
||||
ImageFormat::Rgba => ImageNormOption::PerChannelWithAlpha([RGB_MEAN[0], RGB_MEAN[1], RGB_MEAN[2], 0.5]),
|
||||
ImageFormat::Bgr => ImageNormOption::PerChannel([RGB_MEAN[2], RGB_MEAN[1], RGB_MEAN[0]]),
|
||||
ImageFormat::Rbg => ImageNormOption::PerChannel([RGB_MEAN[0], RGB_MEAN[2], RGB_MEAN[1]])
|
||||
}),
|
||||
bias: Some(match format {
|
||||
ImageFormat::Rgb => ImageNormOption::PerChannel(RGB_STD),
|
||||
ImageFormat::Rgba => ImageNormOption::PerChannelWithAlpha([RGB_STD[0], RGB_STD[1], RGB_STD[2], 0.5]),
|
||||
ImageFormat::Bgr => ImageNormOption::PerChannel([RGB_STD[2], RGB_STD[1], RGB_STD[0]]),
|
||||
ImageFormat::Rbg => ImageNormOption::PerChannel([RGB_STD[0], RGB_STD[2], RGB_STD[1]])
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TensorFromImageOptions {
|
||||
pub data_type: Option<ImageDataType>,
|
||||
pub norm: Option<ImageNorm>,
|
||||
pub resized_height: Option<u32>,
|
||||
pub resized_width: Option<u32>,
|
||||
pub tensor_format: Option<ImageFormat>,
|
||||
pub tensor_layout: Option<ImageTensorLayout>
|
||||
}
|
||||
|
||||
impl TensorFromImageOptions {
|
||||
pub(crate) fn to_value(&self) -> Result<JsValue, serde_wasm_bindgen::Error> {
|
||||
serde_wasm_bindgen::to_value(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TensorFromUrlOptions {
|
||||
#[serde(flatten)]
|
||||
base: TensorFromImageOptions,
|
||||
pub width: Option<u32>,
|
||||
pub height: Option<u32>
|
||||
}
|
||||
|
||||
impl TensorFromUrlOptions {
|
||||
pub(crate) fn to_value(&self) -> Result<JsValue, serde_wasm_bindgen::Error> {
|
||||
serde_wasm_bindgen::to_value(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct DisposeFunction(#[serde(with = "serde_wasm_bindgen::preserve")] JsValue);
|
||||
|
||||
impl<T> From<T> for DisposeFunction
|
||||
where
|
||||
T: FnOnce() + 'static
|
||||
{
|
||||
fn from(value: T) -> Self {
|
||||
DisposeFunction(Closure::once_into_js(value))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct DownloadFunction(#[serde(with = "serde_wasm_bindgen::preserve")] JsValue);
|
||||
|
||||
impl<T, F, E> From<T> for DownloadFunction
|
||||
where
|
||||
T: FnOnce() -> F + 'static,
|
||||
F: Future<Output = Result<JsValue, E>> + 'static,
|
||||
E: core::error::Error
|
||||
{
|
||||
fn from(value: T) -> Self {
|
||||
DownloadFunction(Closure::once_into_js(move || {
|
||||
wasm_bindgen_futures::future_to_promise(async move {
|
||||
match value().await {
|
||||
Ok(value) => Ok(value),
|
||||
Err(e) => Err(JsString::from(e.to_string()).into())
|
||||
}
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TensorFromTextureOptions {
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
pub format: Option<ImageFormat>,
|
||||
pub dispose: Option<DisposeFunction>,
|
||||
pub download: Option<DownloadFunction>
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
#[derive(Deserialize, Serialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum DataLocation {
|
||||
None = "none", // indicates tensor is disposed
|
||||
Cpu = "cpu",
|
||||
CpuPinned = "cpu-pinned", // what is *pinned* in WASM?
|
||||
Texture = "texture",
|
||||
GpuBuffer = "gpu-buffer",
|
||||
MlTensor = "ml-tensor"
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
extern "C" {
|
||||
#[wasm_bindgen(js_namespace = ort)]
|
||||
pub type Tensor;
|
||||
|
||||
#[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromImage)]
|
||||
async fn from_image_data_raw(image_data: &ImageData, options: JsValue) -> Result<Tensor, JsValue>;
|
||||
#[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromImage)]
|
||||
async fn from_image_element_raw(element: &HtmlImageElement, options: JsValue) -> Result<Tensor, JsValue>;
|
||||
#[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromImage)]
|
||||
async fn from_image_bitmap_raw(bitmap: &ImageBitmap, options: JsValue) -> Result<Tensor, JsValue>;
|
||||
#[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromImage)]
|
||||
async fn from_image_url_raw(url: &str, options: JsValue) -> Result<Tensor, JsValue>;
|
||||
#[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromTexture)]
|
||||
fn from_texture(texture: &WebGlTexture, options: JsValue) -> Result<Tensor, JsValue>;
|
||||
#[cfg(web_sys_unstable_apis)]
|
||||
#[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromGpuBuffer)]
|
||||
fn from_gpu_buffer(buffer: &web_sys::GpuBuffer, options: JsValue) -> Result<Tensor, JsValue>;
|
||||
#[wasm_bindgen(catch, js_namespace = ort, static_method_of = Tensor, js_name = fromPinnedBuffer)]
|
||||
fn from_pinned_buffer(dtype: DataType, buffer: JsValue, dims: JsValue) -> Result<Tensor, JsValue>;
|
||||
|
||||
#[wasm_bindgen(constructor, catch, js_namespace = ort, js_class = Tensor)]
|
||||
fn new_from_buffer_raw(dtype: DataType, buffer: JsValue, dims: JsValue) -> Result<Tensor, JsValue>;
|
||||
|
||||
#[wasm_bindgen(structural, catch, method, getter, js_name = data)]
|
||||
pub fn data(this: &Tensor) -> Result<JsValue, JsValue>;
|
||||
#[wasm_bindgen(structural, method, getter, js_name = location)]
|
||||
pub fn location(this: &Tensor) -> DataLocation;
|
||||
#[wasm_bindgen(structural, method, getter, js_name = type)]
|
||||
pub fn dtype(this: &Tensor) -> DataType;
|
||||
#[wasm_bindgen(structural, method, getter, js_name = size)]
|
||||
pub fn size(this: &Tensor) -> usize;
|
||||
#[wasm_bindgen(structural, method, getter, js_name = dims)]
|
||||
pub fn dims(this: &Tensor) -> Vec<i32>;
|
||||
|
||||
#[wasm_bindgen(structural, catch, method, js_name = getData)]
|
||||
pub async fn get_data(this: &Tensor) -> Result<JsValue, JsValue>;
|
||||
|
||||
#[wasm_bindgen(structural, catch, method, js_name = dispose)]
|
||||
pub fn dispose(this: &Tensor) -> Result<(), JsValue>;
|
||||
#[wasm_bindgen(structural, catch, method, js_name = reshape)]
|
||||
fn reshape(this: &Tensor, dims: JsValue) -> Result<Tensor, JsValue>;
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
pub async fn from_image_data(image_data: &ImageData, options: &TensorFromImageOptions) -> Result<Tensor, JsValue> {
|
||||
Self::from_image_data_raw(image_data, options.to_value()?).await
|
||||
}
|
||||
|
||||
pub async fn from_image_element(element: &HtmlImageElement, options: &TensorFromImageOptions) -> Result<Tensor, JsValue> {
|
||||
Self::from_image_element_raw(element, options.to_value()?).await
|
||||
}
|
||||
|
||||
pub async fn from_image_bitmap(bitmap: &ImageBitmap, options: &TensorFromImageOptions) -> Result<Tensor, JsValue> {
|
||||
Self::from_image_bitmap_raw(bitmap, options.to_value()?).await
|
||||
}
|
||||
|
||||
pub async fn from_image_url(url: &str, options: &TensorFromUrlOptions) -> Result<Tensor, JsValue> {
|
||||
Self::from_image_url_raw(url, options.to_value()?).await
|
||||
}
|
||||
|
||||
pub fn new_from_buffer(dtype: DataType, buffer: JsValue, dims: &[i32]) -> Result<Tensor, JsValue> {
|
||||
Self::new_from_buffer_raw(dtype, buffer, convert_dims(dims))
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_dims(dims: &[i32]) -> JsValue {
|
||||
dims.iter().map(|d| js_sys::Number::from(*d)).collect::<js_sys::Array>().into()
|
||||
}
|
||||
40
backends/web/env.rs
Normal file
40
backends/web/env.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use alloc::boxed::Box;
|
||||
|
||||
use crate::binding;
|
||||
|
||||
pub(crate) struct Environment {
|
||||
pub with_telemetry: bool
|
||||
}
|
||||
|
||||
impl Environment {
|
||||
pub fn new_sys() -> *mut ort_sys::OrtEnv {
|
||||
(Box::leak(Box::new(Self { with_telemetry: true })) as *mut Environment).cast()
|
||||
}
|
||||
|
||||
pub unsafe fn cast_from_sys<'e>(ptr: *const ort_sys::OrtEnv) -> &'e Environment {
|
||||
unsafe { &*ptr.cast::<Environment>() }
|
||||
}
|
||||
|
||||
pub unsafe fn cast_from_sys_mut<'e>(ptr: *mut ort_sys::OrtEnv) -> &'e mut Environment {
|
||||
unsafe { &mut *ptr.cast::<Environment>() }
|
||||
}
|
||||
|
||||
pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtEnv) -> Box<Environment> {
|
||||
unsafe { Box::from_raw(ptr.cast::<Environment>()) }
|
||||
}
|
||||
|
||||
pub fn send_telemetry_event(&self, event: TelemetryEvent) {
|
||||
if !self.with_telemetry {
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = match event {
|
||||
TelemetryEvent::SessionInit => binding::track_session_init()
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TelemetryEvent {
|
||||
SessionInit
|
||||
}
|
||||
279
backends/web/lib.rs
Normal file
279
backends/web/lib.rs
Normal file
@@ -0,0 +1,279 @@
|
||||
//! `ort-web` is an [`ort`] backend that enables the usage of ONNX Runtime in the web.
|
||||
//!
|
||||
//! # Usage
|
||||
//! ## CORS
|
||||
//! `ort-web` dynamically fetches the required scripts & WASM binary at runtime. By default, it will fetch the build
|
||||
//! from the `cdn.pyke.io` domain, so make sure it is accessible via CORS if you have that configured.
|
||||
//!
|
||||
//! You can also use a self-hosted build with [`Dist`]; see the [`api`](fn@api) function for an example. The scripts &
|
||||
//! binary can be acquired from the `dist` folder of the [`onnxruntime-web` npm package](https://npmjs.com/package/onnxruntime-web).
|
||||
//!
|
||||
//! ### Telemetry
|
||||
//! `ort-web` collects telemetry data by default and sends it to `signal.pyke.io`. This telemetry data helps us
|
||||
//! understand how `ort-web` is being used so we can improve it. Zero PII is collected; you can see what is sent in
|
||||
//! `_telemetry.js`. If you wish to contribute telemetry data, please allowlist `signal.pyke.io`; otherwise, it can be
|
||||
//! disabled via [`EnvironmentBuilder::with_telemetry`](ort::environment::EnvironmentBuilder::with_telemetry).
|
||||
//!
|
||||
//! ## Initialization
|
||||
//! `ort` must have the `alternative-backend` feature enabled, as this enables the usage of [`ort::set_api`].
|
||||
//!
|
||||
//! You can choose which build of ONNX Runtime to fetch by choosing any combination of these 3 feature flags:
|
||||
//! [`FEATURE_WEBGL`], [`FEATURE_WEBGPU`], [`FEATURE_WEBNN`]. These enable the usage of the [WebGL][ort::ep::WebGL],
|
||||
//! [WebGPU][ort::ep::WebGPU], and [WebNN][ort::ep::WebNN] EPs respectively. You can `|` features together to enable
|
||||
//! multiple at once:
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use ort_web::{FEATURE_WEBGL, FEATURE_WEBGPU};
|
||||
//! ort::set_api(ort_web::api(FEATURE_WEBGL | FEATURE_WEBGPU).await?);
|
||||
//! ```
|
||||
//!
|
||||
//! You'll still need to configure the EPs on a per-session basis later like you would normally, but this allows you to
|
||||
//! e.g. only fetch the CPU build if the user doesn't have hardware acceleration.
|
||||
//!
|
||||
//! ## Session creation
|
||||
//! Sessions can only be created from a URL, or indirectly from memory - that means no
|
||||
//! `SessionBuilder::commit_from_memory_directly` for `.ort` format models, and no `SessionBuilder::commit_from_file`.
|
||||
//!
|
||||
//! The remaining commit functions - `SessionBuilder::commit_from_url` and `SessionBuilder::commit_from_memory` are
|
||||
//! marked `async` and need to be `await`ed. `commit_from_url` is always available when targeting WASM and does not
|
||||
//! require the `fetch-models` feature flag to be enabled for `ort`.
|
||||
//!
|
||||
//! ## Inference
|
||||
//! Only `Session::run_async` is supported; `Session::run` will always throw an error.
|
||||
//!
|
||||
//! Inference outputs are not synchronized by default (see the next section). If you need access to the data of all
|
||||
//! session outputs from Rust, the [`sync_outputs`] function can be used to sync them all at once.
|
||||
//!
|
||||
//! ## Synchronization
|
||||
//! ONNX Runtime is loaded as a separate WASM module, and `ort-web` acts as an intermediary between the two. There is no
|
||||
//! mechanism in WASM for two modules to share memory, so tensors often need to be 'synchronized' when one side needs to
|
||||
//! see data from the other.
|
||||
//!
|
||||
//! [`Tensor::new`](ort::value::Tensor::new) should never be used for creating inputs, as they start out allocated on
|
||||
//! the ONNX Runtime side, thus requiring a sync (of empty data) to Rust before it can be written to. Prefer instead
|
||||
//! [`Tensor::from_array`](ort::value::Tensor::from_array)/
|
||||
//! [`TensorRef::from_array_view`](ort::value::TensorRef::from_array_view), as tensors created this way never require
|
||||
//! synchronization.
|
||||
//!
|
||||
//! As previously stated, session outputs are **not** synchronized. If you wish to use their data in Rust, you must
|
||||
//! either sync all outputs at once with [`sync_outputs`], or sync each tensor at a time (if you only use a few
|
||||
//! outputs):
|
||||
//! ```ignore
|
||||
//! use ort_web::{TensorExt, SyncDirection};
|
||||
//!
|
||||
//! let mut outputs = session.run_async(ort::inputs![...]).await?;
|
||||
//!
|
||||
//! let mut bounding_boxes = outputs.remove("bounding_boxes").unwrap();
|
||||
//! bounding_boxes.sync(SyncDirection::Rust).await?;
|
||||
//!
|
||||
//! // now we can use the data
|
||||
//! let data = bounding_boxes.try_extract_tensor::<f32>()?;
|
||||
//! ```
|
||||
//!
|
||||
//! Once a session output is `sync`ed, that tensor becomes backed by a Rust buffer. Updates to the tensor's data from
|
||||
//! the Rust side will not reflect in ONNX Runtime until the tensor is `sync`ed with `SyncDirection::Runtime`. Likewise,
|
||||
//! updates to the tensor's data from ONNX Runtime won't reflect in Rust until Rust syncs that tensor with
|
||||
//! `SyncDirection::Rust`. You don't have to worry about this behavior if you only ever *read* from session outputs,
|
||||
//! though.
|
||||
//!
|
||||
//! ## Limitations
|
||||
//! - [`OutputSelector`](ort::session::run_options::OutputSelector) is not currently implemented.
|
||||
//! - [`IoBinding`](ort::io_binding) is not supported by ONNX Runtime on the web.
|
||||
|
||||
#![deny(clippy::panic, clippy::panicking_unwrap)]
|
||||
#![warn(clippy::std_instead_of_alloc, clippy::std_instead_of_core)]
|
||||
|
||||
extern crate alloc;
|
||||
extern crate core;
|
||||
|
||||
use alloc::string::String;
|
||||
use core::fmt;
|
||||
|
||||
use serde::Serialize;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
use crate::util::value_to_string;
|
||||
|
||||
mod api;
|
||||
mod binding;
|
||||
mod env;
|
||||
mod memory;
|
||||
mod session;
|
||||
mod tensor;
|
||||
mod util;
|
||||
#[macro_use]
|
||||
pub(crate) mod private;
|
||||
|
||||
pub use self::{
|
||||
session::sync_outputs,
|
||||
tensor::{SyncDirection, ValueExt}
|
||||
};
|
||||
|
||||
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Error {
|
||||
msg: String
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub(crate) fn new(msg: impl Into<String>) -> Self {
|
||||
Self { msg: msg.into() }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<JsValue> for Error {
|
||||
fn from(value: JsValue) -> Self {
|
||||
Self::new(value_to_string(&value))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.msg.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl core::error::Error for Error {}
|
||||
|
||||
/// Do not enable any execution provider features (CPU-only).
|
||||
pub const FEATURE_NONE: u8 = 0;
|
||||
/// Enable the WebGL execution provider for hardware acceleration.
|
||||
///
|
||||
/// See: <https://caniuse.com/webgl2>
|
||||
pub const FEATURE_WEBGL: u8 = 1 << 0;
|
||||
/// Enable the WebGPU execution provider for hardware acceleration.
|
||||
///
|
||||
/// See: <https://caniuse.com/webgpu>
|
||||
pub const FEATURE_WEBGPU: u8 = 1 << 1;
|
||||
/// Enable the WebNN execution provider for hardware acceleration.
|
||||
///
|
||||
/// See: <https://webmachinelearning.github.io/webnn-status/>
|
||||
pub const FEATURE_WEBNN: u8 = FEATURE_WEBGPU;
|
||||
|
||||
/// Loads an `ort`-compatible ONNX Runtime API from `config`.
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - The requested feature set is not supported by `ort-web`.
|
||||
/// - The JavaScript/WASM modules fail to load.
|
||||
///
|
||||
/// `config` can be a feature set, in which case the default pyke-hosted builds will be used:
|
||||
/// ```no_run
|
||||
/// use ort::session::Session;
|
||||
/// use ort_web::{FEATURE_WEBGL, FEATURE_WEBGPU};
|
||||
///
|
||||
/// async fn init_model() -> anyhow::Result<Session> {
|
||||
/// // This must be called at least once before using any `ort` API.
|
||||
/// ort::set_api(ort_web::api(FEATURE_WEBGL | FEATURE_WEBGPU).await?);
|
||||
///
|
||||
/// let session = Session::builder()?.commit_from_url("https://...").await?;
|
||||
/// Ok(session)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// You can also use [`Dist`] to self-host the build:
|
||||
/// ```no_run
|
||||
/// use ort::session::Session;
|
||||
/// use ort_web::Dist;
|
||||
///
|
||||
/// async fn init_model() -> anyhow::Result<Session> {
|
||||
/// let dist = Dist::new("https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.0/dist/")
|
||||
/// // load the WebGPU build
|
||||
/// .with_script_name("ort.webgpu.min.js");
|
||||
/// ort::set_api(ort_web::api(dist).await?);
|
||||
/// }
|
||||
/// ```
|
||||
pub async fn api<L: Loadable>(config: L) -> Result<ort_sys::OrtApi> {
|
||||
let (features, dist) = config.into_features_and_dist()?;
|
||||
binding::init_runtime(features, dist).await?;
|
||||
|
||||
Ok(self::api::api())
|
||||
}
|
||||
|
||||
pub trait Loadable {
|
||||
#[doc(hidden)]
|
||||
fn into_features_and_dist(self) -> Result<(u8, JsValue)>;
|
||||
}
|
||||
|
||||
impl Loadable for u8 {
|
||||
fn into_features_and_dist(self) -> Result<(u8, JsValue)> {
|
||||
Ok((self, JsValue::null()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Loadable for Dist {
|
||||
fn into_features_and_dist(self) -> Result<(u8, JsValue)> {
|
||||
Ok((0, serde_wasm_bindgen::to_value(&self).map_err(|e| Error::new(e.to_string()))?))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Serialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Integrities {
|
||||
main: Option<String>,
|
||||
wrapper: Option<String>,
|
||||
binary: Option<String>
|
||||
}
|
||||
|
||||
impl Integrities {
|
||||
/// Set the SHA-384 SRI hash for the main (entrypoint) script.
|
||||
pub fn set_main(&mut self, hash: impl Into<String>) {
|
||||
self.main = Some(hash.into());
|
||||
}
|
||||
|
||||
/// Set the SHA-384 SRI hash for the Emscripten wrapper script.
|
||||
pub fn set_wrapper(&mut self, hash: impl Into<String>) {
|
||||
self.wrapper = Some(hash.into());
|
||||
}
|
||||
|
||||
/// Set the SHA-384 SRI hash for the WASM binary.
|
||||
pub fn set_binary(&mut self, hash: impl Into<String>) {
|
||||
self.binary = Some(hash.into());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Dist {
|
||||
base_url: String,
|
||||
script_name: String,
|
||||
binary_name: Option<String>,
|
||||
wrapper_name: Option<String>,
|
||||
integrities: Integrities
|
||||
}
|
||||
|
||||
impl Dist {
|
||||
pub fn new(base_url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
base_url: base_url.into(),
|
||||
script_name: "ort.wasm.min.js".to_string(),
|
||||
binary_name: None,
|
||||
wrapper_name: None,
|
||||
integrities: Integrities::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Configures the name of the entrypoint script file; defaults to `"ort.wasm.min.js"`.
|
||||
pub fn with_script_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.script_name = name.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Enables preloading the WASM binary loaded by the entrypoint script.
|
||||
pub fn with_binary_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.binary_name = Some(name.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures the name of the Emscripten wrapper script preloaded along with the WASM binary, if preloading is
|
||||
/// enabled. Defaults to the binary name with the `.wasm` extension replaced with `.mjs`.
|
||||
pub fn with_wrapper_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.wrapper_name = Some(name.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Modify Subresource Integrity (SRI) hashes.
|
||||
pub fn integrities(&mut self) -> &mut Integrities {
|
||||
&mut self.integrities
|
||||
}
|
||||
}
|
||||
71
backends/web/memory.rs
Normal file
71
backends/web/memory.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use alloc::ffi::CString;
|
||||
use core::{
|
||||
ffi::{CStr, c_void},
|
||||
ptr
|
||||
};
|
||||
|
||||
use crate::binding;
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Allocator {
|
||||
_sys_api: ort_sys::OrtAllocator
|
||||
}
|
||||
|
||||
impl Allocator {
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
_sys_api: ort_sys::OrtAllocator {
|
||||
version: ort_sys::ORT_API_VERSION,
|
||||
Alloc: Some(sys_allocator_alloc),
|
||||
Free: Some(sys_allocator_free),
|
||||
Info: Some(sys_allocator_info),
|
||||
Reserve: Some(sys_allocator_reserve)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub static DEFAULT_CPU_ALLOCATOR: Allocator = Allocator::new();
|
||||
|
||||
unsafe extern "system" fn sys_allocator_alloc(_this: *mut ort_sys::OrtAllocator, _size: usize) -> *mut c_void {
|
||||
ptr::null_mut()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn sys_allocator_free(_this: *mut ort_sys::OrtAllocator, p: *mut c_void) {
|
||||
drop(unsafe { CString::from_raw(p.cast()) });
|
||||
}
|
||||
|
||||
unsafe extern "system" fn sys_allocator_info(this_: *const ort_sys::OrtAllocator) -> *const ort_sys::OrtMemoryInfo {
|
||||
let _allocator = unsafe { &*this_.cast::<Allocator>() };
|
||||
ptr::dangling()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn sys_allocator_reserve(_this: *const ort_sys::OrtAllocator, _size: usize) -> *mut c_void {
|
||||
ptr::null_mut()
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct MemoryInfo {
|
||||
pub location: binding::DataLocation
|
||||
}
|
||||
|
||||
impl MemoryInfo {
|
||||
pub fn location_exposed(&self) -> Option<&'static CStr> {
|
||||
match self.location {
|
||||
binding::DataLocation::Cpu | binding::DataLocation::CpuPinned => Some(c"Cpu"),
|
||||
binding::DataLocation::Texture => Some(c"WebGL"),
|
||||
binding::DataLocation::GpuBuffer => Some(c"WebGPU_Buffer"),
|
||||
binding::DataLocation::MlTensor => Some(c"WebNN"),
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_location(location: &str) -> Option<Self> {
|
||||
match location {
|
||||
"Cpu" => Some(Self {
|
||||
location: binding::DataLocation::CpuPinned
|
||||
}),
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
}
|
||||
17
backends/web/private.rs
Normal file
17
backends/web/private.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
pub struct PrivateTraitMarker;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! private_trait {
|
||||
() => {
|
||||
#[doc(hidden)]
|
||||
fn _private() -> crate::private::PrivateTraitMarker;
|
||||
};
|
||||
}
|
||||
#[macro_export]
|
||||
macro_rules! private_impl {
|
||||
() => {
|
||||
fn _private() -> crate::private::PrivateTraitMarker {
|
||||
crate::private::PrivateTraitMarker
|
||||
}
|
||||
};
|
||||
}
|
||||
85
backends/web/session.rs
Normal file
85
backends/web/session.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use js_sys::Uint8Array;
|
||||
use ort::session::SessionOutputs;
|
||||
use ort_sys::{OrtErrorCode, stub::Error};
|
||||
|
||||
use crate::{
|
||||
binding,
|
||||
tensor::{SyncDirection, ValueExt},
|
||||
util::value_to_string
|
||||
};
|
||||
|
||||
pub const SESSION_SENTINEL: [u8; 4] = [0xFC, 0x86, 0xA5, 0x01];
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Session {
|
||||
sentinel: [u8; 4],
|
||||
pub js: binding::InferenceSession,
|
||||
pub disable_sync: bool
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub async fn from_url(uri: &str, options: &SessionOptions) -> Result<Self, Error> {
|
||||
Ok(Session {
|
||||
sentinel: SESSION_SENTINEL,
|
||||
js: binding::InferenceSession::create_from_uri(uri, &options.js)
|
||||
.await
|
||||
.map_err(|e| Error::new(OrtErrorCode::ORT_FAIL, value_to_string(&e)))?,
|
||||
disable_sync: options.disable_sync
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn from_bytes(bytes: &[u8], options: &SessionOptions) -> Result<Self, Error> {
|
||||
Ok(Session {
|
||||
sentinel: SESSION_SENTINEL,
|
||||
js: binding::InferenceSession::create_from_bytes(
|
||||
// i'm fairly confident that the bytes are copied, at least when we're not using ONNX.js
|
||||
&unsafe { Uint8Array::view(bytes) },
|
||||
&options.js
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Error::new(OrtErrorCode::ORT_FAIL, value_to_string(&e)))?,
|
||||
disable_sync: options.disable_sync
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RunOptions {}
|
||||
|
||||
impl RunOptions {
|
||||
pub const fn new() -> Self {
|
||||
RunOptions {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Synchronize all outputs in `outputs` so that their data is available to Rust code.
|
||||
///
|
||||
/// See the [top-level documentation][crate] for more information on synchronization.
|
||||
///
|
||||
/// ```ignore
|
||||
/// let mut outputs = session.run_async(ort::inputs![...]).await?;
|
||||
/// ort_web::sync_outputs(&mut outputs).await?;
|
||||
///
|
||||
/// let bounding_boxes = outputs.remove("bounding_boxes").unwrap();
|
||||
/// ...
|
||||
/// ```
|
||||
pub async fn sync_outputs(outputs: &mut SessionOutputs<'_>) -> crate::Result<()> {
|
||||
for (_, mut value) in outputs.iter_mut() {
|
||||
value.sync(SyncDirection::Rust).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionOptions {
|
||||
pub js: binding::SessionOptions,
|
||||
pub disable_sync: bool
|
||||
}
|
||||
|
||||
impl SessionOptions {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
js: binding::SessionOptions::default(),
|
||||
disable_sync: true
|
||||
}
|
||||
}
|
||||
}
|
||||
253
backends/web/tensor.rs
Normal file
253
backends/web/tensor.rs
Normal file
@@ -0,0 +1,253 @@
|
||||
use alloc::{boxed::Box, vec::Vec};
|
||||
use core::{ffi::c_void, slice};
|
||||
|
||||
use js_sys::Uint8Array;
|
||||
use ort::{AsPointer, value::ValueTypeMarker};
|
||||
use wasm_bindgen::{JsCast, JsValue};
|
||||
|
||||
use crate::{
|
||||
Error,
|
||||
binding::{self, DataType},
|
||||
memory::MemoryInfo,
|
||||
util::num_elements
|
||||
};
|
||||
|
||||
pub const TENSOR_SENTINEL: [u8; 4] = [0xFC, 0x86, 0xA5, 0x39];
|
||||
|
||||
pub enum TensorData {
|
||||
/// Data is stored in WASM linear memory and can be immediately accessed.
|
||||
RustView { ptr: *mut c_void, byte_len: usize },
|
||||
/// Data is stored outside of WASM linear memory (i.e. session output, or a tensor created from anything other than
|
||||
/// a Rust slice) and would need to be retrieved if we try to extract this tensor.
|
||||
External { buffer: Option<Box<[u8]>> }
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Tensor {
|
||||
sentinel: [u8; 4],
|
||||
pub js: binding::Tensor,
|
||||
pub data: TensorData,
|
||||
pub memory_info: MemoryInfo
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
pub unsafe fn from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize, dims: &[i32]) -> Result<Self, JsValue> {
|
||||
let tensor = binding::Tensor::new_from_buffer(dtype, unsafe { buffer_from_ptr(dtype, ptr, byte_len) }, dims)?;
|
||||
Ok(Self {
|
||||
sentinel: TENSOR_SENTINEL,
|
||||
memory_info: MemoryInfo { location: tensor.location() },
|
||||
js: tensor,
|
||||
data: TensorData::RustView { ptr, byte_len }
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_tensor(tensor: binding::Tensor) -> Self {
|
||||
Self {
|
||||
sentinel: TENSOR_SENTINEL,
|
||||
memory_info: MemoryInfo { location: tensor.location() },
|
||||
js: tensor,
|
||||
data: TensorData::External { buffer: None }
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> {
|
||||
match direction {
|
||||
SyncDirection::Rust => {
|
||||
let data = self.js.get_data().await?;
|
||||
|
||||
// cast to some kind of typed array first, then convert to uint8array so we can properly copy
|
||||
let generic_typed_array = Uint8Array::unchecked_from_js(data);
|
||||
let bytes = Uint8Array::new_with_byte_offset_and_length(
|
||||
&generic_typed_array.buffer(),
|
||||
generic_typed_array.byte_offset(),
|
||||
generic_typed_array.byte_length()
|
||||
);
|
||||
match &mut self.data {
|
||||
TensorData::RustView { ptr, byte_len } => {
|
||||
bytes.copy_to(unsafe { core::slice::from_raw_parts_mut(ptr.cast(), *byte_len) });
|
||||
}
|
||||
TensorData::External { buffer } => {
|
||||
let buffer = match buffer {
|
||||
Some(buffer) => buffer,
|
||||
None => {
|
||||
*buffer = Some(vec![0; generic_typed_array.byte_length() as usize].into_boxed_slice());
|
||||
unsafe { buffer.as_mut().unwrap_unchecked() }
|
||||
}
|
||||
};
|
||||
bytes.copy_to(buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
SyncDirection::Runtime => {
|
||||
let Ok(generic_typed_array) = self.js.data().map(Uint8Array::unchecked_from_js) else {
|
||||
// we have a download function, but no upload...
|
||||
return Err(Error::new(
|
||||
"Cannot synchronize Rust data to a runtime tensor that is not on the CPU; modify the WebGPU/WebGL buffer directly."
|
||||
));
|
||||
};
|
||||
let bytes = Uint8Array::new_with_byte_offset_and_length(
|
||||
&generic_typed_array.buffer(),
|
||||
generic_typed_array.byte_offset(),
|
||||
generic_typed_array.byte_length()
|
||||
);
|
||||
bytes.copy_from(match &self.data {
|
||||
TensorData::RustView { ptr, byte_len } => unsafe { core::slice::from_raw_parts(ptr.cast(), *byte_len) },
|
||||
TensorData::External { buffer } => {
|
||||
let Some(buffer) = buffer else {
|
||||
return Ok(());
|
||||
};
|
||||
&*buffer
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_buffer(dtype: binding::DataType, shape: &[i32]) -> JsValue {
|
||||
let numel = num_elements(shape) as u32;
|
||||
match dtype {
|
||||
binding::DataType::Bool | binding::DataType::Uint8 => js_sys::Uint8Array::new_with_length(numel).into(),
|
||||
binding::DataType::Int8 => js_sys::Int8Array::new_with_length(numel).into(),
|
||||
binding::DataType::Uint16 => js_sys::Uint16Array::new_with_length(numel).into(),
|
||||
binding::DataType::Int16 => js_sys::Int16Array::new_with_length(numel).into(),
|
||||
binding::DataType::Uint32 => js_sys::Uint32Array::new_with_length(numel).into(),
|
||||
binding::DataType::Int32 => js_sys::Int32Array::new_with_length(numel).into(),
|
||||
binding::DataType::Uint64 => js_sys::BigUint64Array::new_with_length(numel).into(),
|
||||
binding::DataType::Int64 => js_sys::BigInt64Array::new_with_length(numel).into(),
|
||||
binding::DataType::Float32 => js_sys::Float32Array::new_with_length(numel).into(),
|
||||
binding::DataType::Float64 => js_sys::Float64Array::new_with_length(numel).into(),
|
||||
binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(),
|
||||
binding::DataType::__Invalid => unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn buffer_from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize) -> JsValue {
|
||||
match dtype {
|
||||
binding::DataType::Bool | binding::DataType::Uint8 => unsafe { js_sys::Uint8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(),
|
||||
binding::DataType::Int8 => unsafe { js_sys::Int8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(),
|
||||
binding::DataType::Uint16 => unsafe { js_sys::Uint16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(),
|
||||
binding::DataType::Int16 => unsafe { js_sys::Int16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(),
|
||||
binding::DataType::Uint32 => unsafe { js_sys::Uint32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
|
||||
binding::DataType::Int32 => unsafe { js_sys::Int32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
|
||||
binding::DataType::Uint64 => unsafe { js_sys::BigUint64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
|
||||
binding::DataType::Int64 => unsafe { js_sys::BigInt64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
|
||||
binding::DataType::Float32 => unsafe { js_sys::Float32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
|
||||
binding::DataType::Float64 => unsafe { js_sys::Float64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
|
||||
binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(),
|
||||
binding::DataType::__Invalid => unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dtype_to_onnx(dtype: binding::DataType) -> ort_sys::ONNXTensorElementDataType {
|
||||
match dtype {
|
||||
binding::DataType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
|
||||
binding::DataType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
|
||||
binding::DataType::Uint8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
|
||||
binding::DataType::Int8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
|
||||
binding::DataType::Uint16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
|
||||
binding::DataType::Int16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
|
||||
binding::DataType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
|
||||
binding::DataType::Int32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
|
||||
binding::DataType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
|
||||
binding::DataType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
|
||||
binding::DataType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
|
||||
binding::DataType::Float32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
|
||||
binding::DataType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
|
||||
binding::DataType::Int4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4,
|
||||
binding::DataType::Uint4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4,
|
||||
binding::DataType::__Invalid => unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn onnx_to_dtype(dtype: ort_sys::ONNXTensorElementDataType) -> Option<binding::DataType> {
|
||||
match dtype {
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => Some(binding::DataType::String),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => Some(binding::DataType::Bool),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Some(binding::DataType::Uint8),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => Some(binding::DataType::Int8),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => Some(binding::DataType::Uint16),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => Some(binding::DataType::Int16),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Some(binding::DataType::Uint32),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => Some(binding::DataType::Int32),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => Some(binding::DataType::Uint64),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Some(binding::DataType::Int64),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => Some(binding::DataType::Float16),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Some(binding::DataType::Float32),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Some(binding::DataType::Float64),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 => Some(binding::DataType::Int4),
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4 => Some(binding::DataType::Uint4),
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TypeInfo {
|
||||
pub dtype: ort_sys::ONNXTensorElementDataType,
|
||||
pub shape: Vec<i32>
|
||||
}
|
||||
|
||||
impl TypeInfo {
|
||||
pub fn new_sys_from_tensor(tensor: &Tensor) -> *mut ort_sys::OrtTypeInfo {
|
||||
Self::new_sys(tensor.js.dtype(), tensor.js.dims())
|
||||
}
|
||||
|
||||
pub fn new_sys_from_value_metadata(metadata: &binding::ValueMetadata) -> *mut ort_sys::OrtTypeInfo {
|
||||
Self::new_sys(
|
||||
metadata.r#type.unwrap(),
|
||||
metadata
|
||||
.shape
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|el| match el {
|
||||
binding::ShapeElement::Value(v) => *v as i32,
|
||||
binding::ShapeElement::Named(_) => -1
|
||||
})
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_sys(dtype: DataType, shape: Vec<i32>) -> *mut ort_sys::OrtTypeInfo {
|
||||
(Box::leak(Box::new(Self { dtype: dtype_to_onnx(dtype), shape })) as *mut TypeInfo).cast()
|
||||
}
|
||||
|
||||
pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtTypeInfo) -> Box<TypeInfo> {
|
||||
unsafe { Box::from_raw(ptr.cast::<TypeInfo>()) }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SyncDirection {
|
||||
/// Synchronize tensor data from the device/runtime so that it is accessible to Rust code.
|
||||
Rust,
|
||||
/// Synchronize tensor data from Rust code so that it is accessible to the runtime.
|
||||
Runtime
|
||||
}
|
||||
|
||||
pub trait ValueExt {
|
||||
crate::private_trait!();
|
||||
|
||||
/// Synchronize data between Rust & the runtime.
|
||||
///
|
||||
/// See the [top-level documentation][crate] for more information on synchronization.
|
||||
#[allow(async_fn_in_trait)]
|
||||
async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()>;
|
||||
}
|
||||
|
||||
impl<T: ValueTypeMarker> ValueExt for ort::value::Value<T> {
|
||||
crate::private_impl!();
|
||||
|
||||
async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> {
|
||||
let ptr = self.ptr_mut();
|
||||
// definitely safe regardless of what backend is used since it's highly improbable that a backend's tensor would be
|
||||
// smaller than 4 bytes (which is pointer size on wasm32)
|
||||
let sentinel: [u8; 4] = unsafe { core::ptr::read(ptr.cast()) };
|
||||
if sentinel != TENSOR_SENTINEL {
|
||||
return Err(Error::new("Cannot synchronize Value that was not created by ort-web"));
|
||||
}
|
||||
|
||||
let tensor: &mut Tensor = unsafe { &mut *ptr.cast() };
|
||||
tensor.sync(direction).await
|
||||
}
|
||||
}
|
||||
18
backends/web/util.rs
Normal file
18
backends/web/util.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use alloc::string::String;
|
||||
|
||||
use wasm_bindgen::{JsCast, JsValue};
|
||||
|
||||
pub fn value_to_string(value: &JsValue) -> String {
|
||||
js_sys::Object::unchecked_from_js_ref(value).to_string().into()
|
||||
}
|
||||
|
||||
pub fn num_elements(dims: &[i32]) -> usize {
|
||||
let mut size = 1usize;
|
||||
for dim in dims {
|
||||
if *dim < 0 {
|
||||
return 0;
|
||||
}
|
||||
size *= *dim as usize;
|
||||
}
|
||||
size
|
||||
}
|
||||
Reference in New Issue
Block a user