fix(tract): support external data

closes #543
This commit is contained in:
Carson M.
2026-03-03 13:10:23 -06:00
parent e9666c7f00
commit 00231247a2
3 changed files with 15 additions and 8 deletions

View File

@@ -2,7 +2,9 @@
use std::{ use std::{
ffi::{CStr, CString, OsString}, ffi::{CStr, CString, OsString},
fs, ptr fs,
path::Path,
ptr
}; };
use ort_sys::*; use ort_sys::*;
@@ -71,12 +73,12 @@ unsafe extern "system" fn CreateSession(
#[cfg(not(target_os = "windows"))] #[cfg(not(target_os = "windows"))]
let path = OsString::from_encoded_bytes_unchecked(path.iter().map(|c| *c as u8).collect::<Vec<_>>()); let path = OsString::from_encoded_bytes_unchecked(path.iter().map(|c| *c as u8).collect::<Vec<_>>());
let buf = match fs::read(path) { let buf = match fs::read(&path) {
Ok(buf) => buf, Ok(buf) => buf,
Err(e) => return Error::new_sys(OrtErrorCode::ORT_NO_SUCHFILE, format!("Failed to read model file: {e}")) Err(e) => return Error::new_sys(OrtErrorCode::ORT_NO_SUCHFILE, format!("Failed to read model file: {e}"))
}; };
match Session::from_buffer(env, options, &buf) { match Session::from_buffer(env, options, &buf, Some(Path::new(&path))) {
Ok(session) => { Ok(session) => {
*out = (Box::leak(Box::new(session)) as *mut Session).cast(); *out = (Box::leak(Box::new(session)) as *mut Session).cast();
OrtStatusPtr::default() OrtStatusPtr::default()
@@ -97,7 +99,7 @@ unsafe extern "system" fn CreateSessionFromArray(
let buf = std::slice::from_raw_parts(model_data.cast::<u8>(), model_data_length); let buf = std::slice::from_raw_parts(model_data.cast::<u8>(), model_data_length);
match Session::from_buffer(env, options, buf) { match Session::from_buffer(env, options, buf, None) {
Ok(session) => { Ok(session) => {
*out = (Box::leak(Box::new(session)) as *mut Session).cast(); *out = (Box::leak(Box::new(session)) as *mut Session).cast();
OrtStatusPtr::default() OrtStatusPtr::default()

View File

@@ -1,13 +1,15 @@
use std::{ use std::{
collections::{HashMap, hash_map::Entry}, collections::{HashMap, hash_map::Entry},
hash::{BuildHasher, DefaultHasher, Hasher}, hash::{BuildHasher, DefaultHasher, Hasher},
path::Path,
sync::Arc sync::Arc
}; };
use parking_lot::Mutex; use parking_lot::Mutex;
use tract_onnx::{ use tract_onnx::{
model::ParseResult,
pb::ValueInfoProto, pb::ValueInfoProto,
prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractResult, TypedFact, TypedOp} prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractError, TractResult, TypedFact, TypedOp}
}; };
use crate::Environment; use crate::Environment;
@@ -54,12 +56,15 @@ pub struct Session {
} }
impl Session { impl Session {
pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8]) -> TractResult<Session> { pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8], path: Option<&Path>) -> TractResult<Session> {
let proto_model = env.onnx.proto_model_for_read(&mut data)?; let proto_model = env.onnx.proto_model_for_read(&mut data)?;
let inputs = proto_model.graph.as_ref().map(|graph| graph.input.clone()).unwrap_or_default(); let inputs = proto_model.graph.as_ref().map(|graph| graph.input.clone()).unwrap_or_default();
let outputs = proto_model.graph.as_ref().map(|graph| graph.output.clone()).unwrap_or_default(); let outputs = proto_model.graph.as_ref().map(|graph| graph.output.clone()).unwrap_or_default();
let model = env.onnx.model_for_proto_model(&proto_model)?; let ParseResult { model, unresolved_inputs, .. } = env.onnx.parse(&proto_model, path.and_then(|p| p.parent()).and_then(|p| p.to_str()))?;
if unresolved_inputs.len() > 0 {
return Err(TractError::msg("failed to resolve some inputs"));
}
let graph = Arc::new(if options.perform_optimizations { model.into_optimized()? } else { model.into_typed()? }); let graph = Arc::new(if options.perform_optimizations { model.into_optimized()? } else { model.into_typed()? });
Ok(Session { Ok(Session {
inputs, inputs,

View File

@@ -12,7 +12,7 @@ unsafe extern "system" fn get_api(version: u32) -> *const ort_sys::OrtApi {
if version <= ort_sys::ORT_API_VERSION { &API as *const _ } else { core::ptr::null() } if version <= ort_sys::ORT_API_VERSION { &API as *const _ } else { core::ptr::null() }
} }
#[no_mangle] #[unsafe(no_mangle)]
pub unsafe extern "C" fn OrtGetApiBase() -> *const ort_sys::OrtApiBase { pub unsafe extern "C" fn OrtGetApiBase() -> *const ort_sys::OrtApiBase {
&API_BASE as *const _ &API_BASE as *const _
} }