mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: zero-copy deserialization for in-memory session
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -185,4 +185,5 @@ WixTools/
|
||||
|
||||
# ONNX Runtime downloaded models
|
||||
**/*.onnx
|
||||
**/*.ort
|
||||
!tests/data/*.onnx
|
||||
|
||||
@@ -17,12 +17,12 @@ use std::{
|
||||
sync::{atomic::AtomicPtr, Arc, Mutex}
|
||||
};
|
||||
|
||||
pub use environment::Environment;
|
||||
pub use error::{OrtApiError, OrtError, OrtResult};
|
||||
pub use execution_providers::ExecutionProvider;
|
||||
use lazy_static::lazy_static;
|
||||
pub use session::{Session, SessionBuilder};
|
||||
|
||||
pub use self::environment::Environment;
|
||||
pub use self::error::{OrtApiError, OrtError, OrtResult};
|
||||
pub use self::execution_providers::ExecutionProvider;
|
||||
pub use self::session::{InMemorySession, Session, SessionBuilder};
|
||||
use self::sys::OnnxEnumInt;
|
||||
|
||||
macro_rules! extern_system_fn {
|
||||
|
||||
@@ -11,6 +11,8 @@ use std::{env, path::PathBuf, time::Duration};
|
||||
use std::{
|
||||
ffi::CString,
|
||||
fmt::{self, Debug},
|
||||
marker::PhantomData,
|
||||
ops::Deref,
|
||||
os::raw::c_char,
|
||||
path::Path,
|
||||
sync::Arc
|
||||
@@ -392,15 +394,8 @@ impl SessionBuilder {
|
||||
})
|
||||
}
|
||||
|
||||
/// Load an ONNX graph from memory and commit the session
|
||||
pub fn with_model_from_memory<B>(self, model_bytes: B) -> OrtResult<Session>
|
||||
where
|
||||
B: AsRef<[u8]>
|
||||
{
|
||||
self.with_model_from_memory_monomorphized(model_bytes.as_ref())
|
||||
}
|
||||
|
||||
fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> OrtResult<Session> {
|
||||
/// Load an ONNX graph from memory and commit the session.
|
||||
pub fn with_model_from_memory<'s>(self, model_bytes: &'s [u8]) -> OrtResult<InMemorySession<'s>> {
|
||||
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
|
||||
|
||||
let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
|
||||
@@ -414,6 +409,17 @@ impl SessionBuilder {
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let str_to_char = |s: &str| {
|
||||
s.as_bytes()
|
||||
.iter()
|
||||
.chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
|
||||
.map(|b| *b as std::os::raw::c_char)
|
||||
.collect::<Vec<std::os::raw::c_char>>()
|
||||
};
|
||||
// Enable zero-copy deserialization for models in `.ort` format.
|
||||
ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr, str_to_char("session.use_ort_model_bytes_directly").as_ptr(), str_to_char("1").as_ptr())];
|
||||
ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr, str_to_char("session.use_ort_model_bytes_for_initializers").as_ptr(), str_to_char("1").as_ptr())];
|
||||
|
||||
let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
|
||||
let model_data_length = model_bytes.len();
|
||||
ortsys![
|
||||
@@ -436,14 +442,15 @@ impl SessionBuilder {
|
||||
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
|
||||
.collect::<OrtResult<Vec<Output>>>()?;
|
||||
|
||||
Ok(Session {
|
||||
let session = Session {
|
||||
env: Arc::clone(&self.env),
|
||||
session_ptr,
|
||||
allocator_ptr,
|
||||
memory_info,
|
||||
inputs,
|
||||
outputs
|
||||
})
|
||||
};
|
||||
Ok(InMemorySession { session, phantom: PhantomData })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -461,6 +468,19 @@ pub struct Session {
|
||||
pub outputs: Vec<Output>
|
||||
}
|
||||
|
||||
/// A [`Session`] built from model bytes held in memory.
///
/// The lifetime `'s` is the lifetime of the byte slice handed to
/// `SessionBuilder::with_model_from_memory`, so the borrow checker keeps that
/// buffer alive for as long as this session exists.
|
||||
pub struct InMemorySession<'s> {
|
||||
// The wrapped session; reachable read-only through the `Deref` impl.
session: Session,
|
||||
// Zero-sized marker that carries the `'s` borrow; it stores no data.
phantom: PhantomData<&'s ()>
|
||||
}
|
||||
|
||||
/// Lets an [`InMemorySession`] be used wherever a shared reference to a
/// [`Session`] is expected.
impl<'s> Deref for InMemorySession<'s> {
|
||||
type Target = Session;
|
||||
/// Returns a shared reference to the wrapped [`Session`].
fn deref(&self) -> &Self::Target {
|
||||
&self.session
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about an ONNX model's input as stored in the loaded file
|
||||
#[derive(Debug)]
|
||||
pub struct Input {
|
||||
|
||||
@@ -186,11 +186,13 @@ mod download {
|
||||
.build()?
|
||||
.into_arc();
|
||||
|
||||
let session_data =
|
||||
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
|
||||
let session = SessionBuilder::new(&environment)?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx"))
|
||||
.expect("Could not open model from file");
|
||||
.with_model_from_memory(&session_data)
|
||||
.expect("Could not read model from memory");
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
assert_eq!(metadata.name()?, "tf2onnx");
|
||||
|
||||
Reference in New Issue
Block a user