From 017f7aa604657acc79de74ebfdeef3b39cdccf43 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 6 Apr 2026 17:44:24 -0500 Subject: [PATCH] test: add doctest for `overridable_initializers`, `opset_for_domain` --- src/session/mod.rs | 34 +++++++++++++++++++++--- tests/data/overridable_initializer.onnx | Bin 0 -> 303 bytes 2 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 tests/data/overridable_initializer.onnx diff --git a/src/session/mod.rs b/src/session/mod.rs index 056fe20..ef41227 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -153,7 +153,6 @@ impl Session { } /// Returns this session's [`Allocator`]. - #[must_use] pub fn allocator(&self) -> &Allocator { &self.inner.allocator } @@ -171,6 +170,21 @@ impl Session { } /// Returns a list of initializers which are overridable (i.e. also graph inputs). + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{session::{RunOptions, Session}, value::{Value, ValueType, TensorRef, TensorElementType}}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()?.commit_from_file("tests/data/overridable_initializer.onnx")?; + /// + /// let mut overridable_initializers = session.overridable_initializers(); + /// assert_eq!(overridable_initializers.len(), 1); + /// let f1 = overridable_initializers.pop().unwrap(); + /// assert_eq!(f1.name(), "F1"); + /// assert!(f1.dtype().is_tensor()); + /// # Ok(()) + /// # } + /// ``` #[must_use] pub fn overridable_initializers(&self) -> Vec { // can only fail if: @@ -608,7 +622,7 @@ impl Session { Ok(out) } - /// Ends profiling for this session. + /// Ends profiling for this session. Returns the file name of the finalized profile. /// /// Note that this must be explicitly called at the end of profiling, otherwise the profiling file will be empty. pub fn end_profiling(&mut self) -> Result { @@ -648,14 +662,28 @@ impl Session { Ok(()) } + /// Returns the version of the opset domain used by the model. + /// + /// Requires the Model Editor API to be supported by the backend. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{session::{RunOptions, Session}, value::{Value, ValueType, TensorRef, TensorElementType}}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?; + /// assert_eq!(session.opset_for_domain(ort::editor::ONNX_DOMAIN), Some(21)); + /// # Ok(()) + /// # } + /// ``` #[cfg(feature = "api-22")] #[cfg_attr(docsrs, doc(cfg(feature = "api-22")))] - pub fn opset_for_domain(&self, domain: impl AsRef) -> Result { + pub fn opset_for_domain(&self, domain: impl AsRef) -> Option { with_cstr(domain.as_ref().as_bytes(), &|domain| { let mut opset = 0; ortsys![@editor: unsafe SessionGetOpsetForDomain(self.inner.session_ptr.as_ptr(), domain.as_ptr(), &mut opset)?]; Ok(opset as u32) }) + .ok() } } diff --git a/tests/data/overridable_initializer.onnx b/tests/data/overridable_initializer.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5c80850469ab3f3c504b17375ac0cc9915d26ee4 GIT binary patch literal 303 zcmd;J6XHn8%`7R(EY7u>#l)q|#p;uol$s;N2Br+8I6PBQ^GY&HDwW_YE;%kHHzOft zHzNZnPMEYIn!F)U-p~*sZ-9^&63H(p$;{77%!$v;D=00APcKR=$j}ntVB}yFU{qpZ z00Ji_H^V4#m}B_3csMwPIJlTN7(qBn1ndBiC