fix(tract): support external data

closes #543
This commit is contained in:
Carson M.
2026-03-03 13:10:23 -06:00
parent e9666c7f00
commit 00231247a2
3 changed files with 15 additions and 8 deletions

View File

@@ -2,7 +2,9 @@
use std::{
ffi::{CStr, CString, OsString},
fs, ptr
fs,
path::Path,
ptr
};
use ort_sys::*;
@@ -71,12 +73,12 @@ unsafe extern "system" fn CreateSession(
#[cfg(not(target_os = "windows"))]
let path = OsString::from_encoded_bytes_unchecked(path.iter().map(|c| *c as u8).collect::<Vec<_>>());
let buf = match fs::read(path) {
let buf = match fs::read(&path) {
Ok(buf) => buf,
Err(e) => return Error::new_sys(OrtErrorCode::ORT_NO_SUCHFILE, format!("Failed to read model file: {e}"))
};
match Session::from_buffer(env, options, &buf) {
match Session::from_buffer(env, options, &buf, Some(Path::new(&path))) {
Ok(session) => {
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
OrtStatusPtr::default()
@@ -97,7 +99,7 @@ unsafe extern "system" fn CreateSessionFromArray(
let buf = std::slice::from_raw_parts(model_data.cast::<u8>(), model_data_length);
match Session::from_buffer(env, options, buf) {
match Session::from_buffer(env, options, buf, None) {
Ok(session) => {
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
OrtStatusPtr::default()

View File

@@ -1,13 +1,15 @@
use std::{
collections::{HashMap, hash_map::Entry},
hash::{BuildHasher, DefaultHasher, Hasher},
path::Path,
sync::Arc
};
use parking_lot::Mutex;
use tract_onnx::{
model::ParseResult,
pb::ValueInfoProto,
prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractResult, TypedFact, TypedOp}
prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractError, TractResult, TypedFact, TypedOp}
};
use crate::Environment;
@@ -54,12 +56,15 @@ pub struct Session {
}
impl Session {
pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8]) -> TractResult<Session> {
pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8], path: Option<&Path>) -> TractResult<Session> {
let proto_model = env.onnx.proto_model_for_read(&mut data)?;
let inputs = proto_model.graph.as_ref().map(|graph| graph.input.clone()).unwrap_or_default();
let outputs = proto_model.graph.as_ref().map(|graph| graph.output.clone()).unwrap_or_default();
let model = env.onnx.model_for_proto_model(&proto_model)?;
let ParseResult { model, unresolved_inputs, .. } = env.onnx.parse(&proto_model, path.and_then(|p| p.parent()).and_then(|p| p.to_str()))?;
if unresolved_inputs.len() > 0 {
return Err(TractError::msg("failed to resolve some inputs"));
}
let graph = Arc::new(if options.perform_optimizations { model.into_optimized()? } else { model.into_typed()? });
Ok(Session {
inputs,

View File

@@ -12,7 +12,7 @@ unsafe extern "system" fn get_api(version: u32) -> *const ort_sys::OrtApi {
if version <= ort_sys::ORT_API_VERSION { &API as *const _ } else { core::ptr::null() }
}
#[no_mangle]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn OrtGetApiBase() -> *const ort_sys::OrtApiBase {
&API_BASE as *const _
}