feat: zero-copy deserialization for in-memory session

This commit is contained in:
Carson M
2023-02-27 14:45:27 -06:00
parent 957b1aa1dc
commit 7e2e2b5f1b
4 changed files with 40 additions and 17 deletions

1
.gitignore vendored
View File

@@ -185,4 +185,5 @@ WixTools/
# ONNX Runtime downloaded models
**/*.onnx
**/*.ort
!tests/data/*.onnx

View File

@@ -17,12 +17,12 @@ use std::{
sync::{atomic::AtomicPtr, Arc, Mutex}
};
pub use environment::Environment;
pub use error::{OrtApiError, OrtError, OrtResult};
pub use execution_providers::ExecutionProvider;
use lazy_static::lazy_static;
pub use session::{Session, SessionBuilder};
pub use self::environment::Environment;
pub use self::error::{OrtApiError, OrtError, OrtResult};
pub use self::execution_providers::ExecutionProvider;
pub use self::session::{InMemorySession, Session, SessionBuilder};
use self::sys::OnnxEnumInt;
macro_rules! extern_system_fn {

View File

@@ -11,6 +11,8 @@ use std::{env, path::PathBuf, time::Duration};
use std::{
ffi::CString,
fmt::{self, Debug},
marker::PhantomData,
ops::Deref,
os::raw::c_char,
path::Path,
sync::Arc
@@ -392,15 +394,8 @@ impl SessionBuilder {
})
}
/// Builds and commits a session from an ONNX graph held in memory.
///
/// Accepts any value that can be viewed as a byte slice through `AsRef<[u8]>`. The borrowed
/// slice is handed to `with_model_from_memory_monomorphized`, and that call's result is returned
/// unchanged.
pub fn with_model_from_memory<B: AsRef<[u8]>>(self, model_bytes: B) -> OrtResult<Session> {
    // Borrow once up front so the non-generic worker receives a plain `&[u8]`.
    let bytes: &[u8] = model_bytes.as_ref();
    self.with_model_from_memory_monomorphized(bytes)
}
fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> OrtResult<Session> {
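/// Load an ONNX graph from memory and commit the session.
///
/// The returned [`InMemorySession`] borrows `model_bytes` for `'s`. The session is configured for
/// zero-copy deserialization of `.ort` models (`session.use_ort_model_bytes_directly`), so ONNX
/// Runtime may keep reading from this buffer after the call returns; the borrow keeps the buffer
/// alive for as long as the session exists.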
pub fn with_model_from_memory<'s>(self, model_bytes: &'s [u8]) -> OrtResult<InMemorySession<'s>> {
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
@@ -414,6 +409,17 @@ impl SessionBuilder {
.collect::<Vec<_>>()
);
let str_to_char = |s: &str| {
s.as_bytes()
.iter()
.chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
.map(|b| *b as std::os::raw::c_char)
.collect::<Vec<std::os::raw::c_char>>()
};
// Enable zero-copy deserialization for models in `.ort` format.
ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr, str_to_char("session.use_ort_model_bytes_directly").as_ptr(), str_to_char("1").as_ptr())];
ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr, str_to_char("session.use_ort_model_bytes_for_initializers").as_ptr(), str_to_char("1").as_ptr())];
let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
let model_data_length = model_bytes.len();
ortsys![
@@ -436,14 +442,15 @@ impl SessionBuilder {
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
.collect::<OrtResult<Vec<Output>>>()?;
Ok(Session {
let session = Session {
env: Arc::clone(&self.env),
session_ptr,
allocator_ptr,
memory_info,
inputs,
outputs
})
};
Ok(InMemorySession { session, phantom: PhantomData })
}
}
@@ -461,6 +468,19 @@ pub struct Session {
pub outputs: Vec<Output>
}
/// A [`Session`] with data stored in-memory.
///
/// Returned by `SessionBuilder::with_model_from_memory`, which produces an `InMemorySession<'s>`
/// from a `&'s [u8]` model buffer. The `'s` parameter ties the session to that borrow, so the
/// compiler rejects code that would let the session outlive the bytes it was created from.
/// The wrapper dereferences to the inner [`Session`].
pub struct InMemorySession<'s> {
/// The wrapped session; callers reach it through `Deref`.
session: Session,
/// Zero-sized marker that carries the `'s` borrow without storing a reference.
phantom: PhantomData<&'s ()>
}
/// Lets an [`InMemorySession`] be used wherever a `&Session` is expected.
impl Deref for InMemorySession<'_> {
    type Target = Session;

    /// Borrows the wrapped [`Session`].
    fn deref(&self) -> &Session {
        let Self { session, .. } = self;
        session
    }
}
/// Information about an ONNX's input as stored in loaded file
#[derive(Debug)]
pub struct Input {

View File

@@ -186,11 +186,13 @@ mod download {
.build()?
.into_arc();
let session_data =
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
let session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx"))
.expect("Could not open model from file");
.with_model_from_memory(&session_data)
.expect("Could not read model from memory");
let metadata = session.metadata()?;
assert_eq!(metadata.name()?, "tf2onnx");