diff --git a/backends/tract/api.rs b/backends/tract/api.rs index fd1649d..e9c0cd7 100644 --- a/backends/tract/api.rs +++ b/backends/tract/api.rs @@ -2,7 +2,9 @@ use std::{ ffi::{CStr, CString, OsString}, - fs, ptr + fs, + path::Path, + ptr }; use ort_sys::*; @@ -71,12 +73,12 @@ unsafe extern "system" fn CreateSession( #[cfg(not(target_os = "windows"))] let path = OsString::from_encoded_bytes_unchecked(path.iter().map(|c| *c as u8).collect::>()); - let buf = match fs::read(path) { + let buf = match fs::read(&path) { Ok(buf) => buf, Err(e) => return Error::new_sys(OrtErrorCode::ORT_NO_SUCHFILE, format!("Failed to read model file: {e}")) }; - match Session::from_buffer(env, options, &buf) { + match Session::from_buffer(env, options, &buf, Some(Path::new(&path))) { Ok(session) => { *out = (Box::leak(Box::new(session)) as *mut Session).cast(); OrtStatusPtr::default() @@ -97,7 +99,7 @@ unsafe extern "system" fn CreateSessionFromArray( let buf = std::slice::from_raw_parts(model_data.cast::(), model_data_length); - match Session::from_buffer(env, options, buf) { + match Session::from_buffer(env, options, buf, None) { Ok(session) => { *out = (Box::leak(Box::new(session)) as *mut Session).cast(); OrtStatusPtr::default() diff --git a/backends/tract/session.rs b/backends/tract/session.rs index 0741e41..1cf074e 100644 --- a/backends/tract/session.rs +++ b/backends/tract/session.rs @@ -1,13 +1,15 @@ use std::{ collections::{HashMap, hash_map::Entry}, hash::{BuildHasher, DefaultHasher, Hasher}, + path::Path, sync::Arc }; use parking_lot::Mutex; use tract_onnx::{ + model::ParseResult, pb::ValueInfoProto, - prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractResult, TypedFact, TypedOp} + prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractError, TractResult, TypedFact, TypedOp} }; use crate::Environment; @@ -54,12 +56,15 @@ pub struct Session { } impl Session { - pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8]) -> TractResult { + pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8], path: Option<&Path>) -> TractResult { let proto_model = env.onnx.proto_model_for_read(&mut data)?; let inputs = proto_model.graph.as_ref().map(|graph| graph.input.clone()).unwrap_or_default(); let outputs = proto_model.graph.as_ref().map(|graph| graph.output.clone()).unwrap_or_default(); - let model = env.onnx.model_for_proto_model(&proto_model)?; + let ParseResult { model, unresolved_inputs, .. } = env.onnx.parse(&proto_model, path.and_then(|p| p.parent()).and_then(|p| p.to_str()))?; + if unresolved_inputs.len() > 0 { + return Err(TractError::msg("failed to resolve some inputs")); + } let graph = Arc::new(if options.perform_optimizations { model.into_optimized()? } else { model.into_typed()? }); Ok(Session { inputs, diff --git a/backends/tract/standalone/lib.rs b/backends/tract/standalone/lib.rs index 00f3b70..9114d00 100644 --- a/backends/tract/standalone/lib.rs +++ b/backends/tract/standalone/lib.rs @@ -12,7 +12,7 @@ unsafe extern "system" fn get_api(version: u32) -> *const ort_sys::OrtApi { if version <= ort_sys::ORT_API_VERSION { &API as *const _ } else { core::ptr::null() } } -#[no_mangle] +#[unsafe(no_mangle)] pub unsafe extern "C" fn OrtGetApiBase() -> *const ort_sys::OrtApiBase { &API_BASE as *const _ }