mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
@@ -2,7 +2,9 @@
|
||||
|
||||
use std::{
|
||||
ffi::{CStr, CString, OsString},
|
||||
fs, ptr
|
||||
fs,
|
||||
path::Path,
|
||||
ptr
|
||||
};
|
||||
|
||||
use ort_sys::*;
|
||||
@@ -71,12 +73,12 @@ unsafe extern "system" fn CreateSession(
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
let path = OsString::from_encoded_bytes_unchecked(path.iter().map(|c| *c as u8).collect::<Vec<_>>());
|
||||
|
||||
let buf = match fs::read(path) {
|
||||
let buf = match fs::read(&path) {
|
||||
Ok(buf) => buf,
|
||||
Err(e) => return Error::new_sys(OrtErrorCode::ORT_NO_SUCHFILE, format!("Failed to read model file: {e}"))
|
||||
};
|
||||
|
||||
match Session::from_buffer(env, options, &buf) {
|
||||
match Session::from_buffer(env, options, &buf, Some(Path::new(&path))) {
|
||||
Ok(session) => {
|
||||
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
|
||||
OrtStatusPtr::default()
|
||||
@@ -97,7 +99,7 @@ unsafe extern "system" fn CreateSessionFromArray(
|
||||
|
||||
let buf = std::slice::from_raw_parts(model_data.cast::<u8>(), model_data_length);
|
||||
|
||||
match Session::from_buffer(env, options, buf) {
|
||||
match Session::from_buffer(env, options, buf, None) {
|
||||
Ok(session) => {
|
||||
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
|
||||
OrtStatusPtr::default()
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
use std::{
|
||||
collections::{HashMap, hash_map::Entry},
|
||||
hash::{BuildHasher, DefaultHasher, Hasher},
|
||||
path::Path,
|
||||
sync::Arc
|
||||
};
|
||||
|
||||
use parking_lot::Mutex;
|
||||
use tract_onnx::{
|
||||
model::ParseResult,
|
||||
pb::ValueInfoProto,
|
||||
prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractResult, TypedFact, TypedOp}
|
||||
prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractError, TractResult, TypedFact, TypedOp}
|
||||
};
|
||||
|
||||
use crate::Environment;
|
||||
@@ -54,12 +56,15 @@ pub struct Session {
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8]) -> TractResult<Session> {
|
||||
pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8], path: Option<&Path>) -> TractResult<Session> {
|
||||
let proto_model = env.onnx.proto_model_for_read(&mut data)?;
|
||||
let inputs = proto_model.graph.as_ref().map(|graph| graph.input.clone()).unwrap_or_default();
|
||||
let outputs = proto_model.graph.as_ref().map(|graph| graph.output.clone()).unwrap_or_default();
|
||||
|
||||
let model = env.onnx.model_for_proto_model(&proto_model)?;
|
||||
let ParseResult { model, unresolved_inputs, .. } = env.onnx.parse(&proto_model, path.and_then(|p| p.parent()).and_then(|p| p.to_str()))?;
|
||||
if unresolved_inputs.len() > 0 {
|
||||
return Err(TractError::msg("failed to resolve some inputs"));
|
||||
}
|
||||
let graph = Arc::new(if options.perform_optimizations { model.into_optimized()? } else { model.into_typed()? });
|
||||
Ok(Session {
|
||||
inputs,
|
||||
|
||||
@@ -12,7 +12,7 @@ unsafe extern "system" fn get_api(version: u32) -> *const ort_sys::OrtApi {
|
||||
if version <= ort_sys::ORT_API_VERSION { &API as *const _ } else { core::ptr::null() }
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn OrtGetApiBase() -> *const ort_sys::OrtApiBase {
|
||||
&API_BASE as *const _
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user