mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
@@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
ffi::{CStr, CString, OsString},
|
ffi::{CStr, CString, OsString},
|
||||||
fs, ptr
|
fs,
|
||||||
|
path::Path,
|
||||||
|
ptr
|
||||||
};
|
};
|
||||||
|
|
||||||
use ort_sys::*;
|
use ort_sys::*;
|
||||||
@@ -71,12 +73,12 @@ unsafe extern "system" fn CreateSession(
|
|||||||
#[cfg(not(target_os = "windows"))]
|
#[cfg(not(target_os = "windows"))]
|
||||||
let path = OsString::from_encoded_bytes_unchecked(path.iter().map(|c| *c as u8).collect::<Vec<_>>());
|
let path = OsString::from_encoded_bytes_unchecked(path.iter().map(|c| *c as u8).collect::<Vec<_>>());
|
||||||
|
|
||||||
let buf = match fs::read(path) {
|
let buf = match fs::read(&path) {
|
||||||
Ok(buf) => buf,
|
Ok(buf) => buf,
|
||||||
Err(e) => return Error::new_sys(OrtErrorCode::ORT_NO_SUCHFILE, format!("Failed to read model file: {e}"))
|
Err(e) => return Error::new_sys(OrtErrorCode::ORT_NO_SUCHFILE, format!("Failed to read model file: {e}"))
|
||||||
};
|
};
|
||||||
|
|
||||||
match Session::from_buffer(env, options, &buf) {
|
match Session::from_buffer(env, options, &buf, Some(Path::new(&path))) {
|
||||||
Ok(session) => {
|
Ok(session) => {
|
||||||
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
|
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
|
||||||
OrtStatusPtr::default()
|
OrtStatusPtr::default()
|
||||||
@@ -97,7 +99,7 @@ unsafe extern "system" fn CreateSessionFromArray(
|
|||||||
|
|
||||||
let buf = std::slice::from_raw_parts(model_data.cast::<u8>(), model_data_length);
|
let buf = std::slice::from_raw_parts(model_data.cast::<u8>(), model_data_length);
|
||||||
|
|
||||||
match Session::from_buffer(env, options, buf) {
|
match Session::from_buffer(env, options, buf, None) {
|
||||||
Ok(session) => {
|
Ok(session) => {
|
||||||
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
|
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
|
||||||
OrtStatusPtr::default()
|
OrtStatusPtr::default()
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, hash_map::Entry},
|
collections::{HashMap, hash_map::Entry},
|
||||||
hash::{BuildHasher, DefaultHasher, Hasher},
|
hash::{BuildHasher, DefaultHasher, Hasher},
|
||||||
|
path::Path,
|
||||||
sync::Arc
|
sync::Arc
|
||||||
};
|
};
|
||||||
|
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use tract_onnx::{
|
use tract_onnx::{
|
||||||
|
model::ParseResult,
|
||||||
pb::ValueInfoProto,
|
pb::ValueInfoProto,
|
||||||
prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractResult, TypedFact, TypedOp}
|
prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractError, TractResult, TypedFact, TypedOp}
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::Environment;
|
use crate::Environment;
|
||||||
@@ -54,12 +56,15 @@ pub struct Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8]) -> TractResult<Session> {
|
pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8], path: Option<&Path>) -> TractResult<Session> {
|
||||||
let proto_model = env.onnx.proto_model_for_read(&mut data)?;
|
let proto_model = env.onnx.proto_model_for_read(&mut data)?;
|
||||||
let inputs = proto_model.graph.as_ref().map(|graph| graph.input.clone()).unwrap_or_default();
|
let inputs = proto_model.graph.as_ref().map(|graph| graph.input.clone()).unwrap_or_default();
|
||||||
let outputs = proto_model.graph.as_ref().map(|graph| graph.output.clone()).unwrap_or_default();
|
let outputs = proto_model.graph.as_ref().map(|graph| graph.output.clone()).unwrap_or_default();
|
||||||
|
|
||||||
let model = env.onnx.model_for_proto_model(&proto_model)?;
|
let ParseResult { model, unresolved_inputs, .. } = env.onnx.parse(&proto_model, path.and_then(|p| p.parent()).and_then(|p| p.to_str()))?;
|
||||||
|
if unresolved_inputs.len() > 0 {
|
||||||
|
return Err(TractError::msg("failed to resolve some inputs"));
|
||||||
|
}
|
||||||
let graph = Arc::new(if options.perform_optimizations { model.into_optimized()? } else { model.into_typed()? });
|
let graph = Arc::new(if options.perform_optimizations { model.into_optimized()? } else { model.into_typed()? });
|
||||||
Ok(Session {
|
Ok(Session {
|
||||||
inputs,
|
inputs,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ unsafe extern "system" fn get_api(version: u32) -> *const ort_sys::OrtApi {
|
|||||||
if version <= ort_sys::ORT_API_VERSION { &API as *const _ } else { core::ptr::null() }
|
if version <= ort_sys::ORT_API_VERSION { &API as *const _ } else { core::ptr::null() }
|
||||||
}
|
}
|
||||||
|
|
||||||
#[no_mangle]
|
#[unsafe(no_mangle)]
|
||||||
pub unsafe extern "C" fn OrtGetApiBase() -> *const ort_sys::OrtApiBase {
|
pub unsafe extern "C" fn OrtGetApiBase() -> *const ort_sys::OrtApiBase {
|
||||||
&API_BASE as *const _
|
&API_BASE as *const _
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user