mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
test: add doctest for overridable_initializers, opset_for_domain
This commit is contained in:
@@ -153,7 +153,6 @@ impl Session {
|
||||
}
|
||||
|
||||
/// Returns this session's [`Allocator`].
|
||||
#[must_use]
|
||||
pub fn allocator(&self) -> &Allocator {
|
||||
&self.inner.allocator
|
||||
}
|
||||
@@ -171,6 +170,21 @@ impl Session {
|
||||
}
|
||||
|
||||
/// Returns a list of initializers which are overridable (i.e. also graph inputs).
|
||||
///
|
||||
/// ```
|
||||
/// # use std::sync::Arc;
|
||||
/// # use ort::{session::{RunOptions, Session}, value::{Value, ValueType, TensorRef, TensorElementType}};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let session = Session::builder()?.commit_from_file("tests/data/overridable_initializer.onnx")?;
|
||||
///
|
||||
/// let mut overridable_initializers = session.overridable_initializers();
|
||||
/// assert_eq!(overridable_initializers.len(), 1);
|
||||
/// let f1 = overridable_initializers.pop().unwrap();
|
||||
/// assert_eq!(f1.name(), "F1");
|
||||
/// assert!(f1.dtype().is_tensor());
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn overridable_initializers(&self) -> Vec<OverridableInitializer> {
|
||||
// can only fail if:
|
||||
@@ -608,7 +622,7 @@ impl Session {
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Ends profiling for this session.
|
||||
/// Ends profiling for this session. Returns the file name of the finalized profile.
|
||||
///
|
||||
/// Note that this must be explicitly called at the end of profiling, otherwise the profiling file will be empty.
|
||||
pub fn end_profiling(&mut self) -> Result<String> {
|
||||
@@ -648,14 +662,28 @@ impl Session {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the version of the opset domain used by the model.
|
||||
///
|
||||
/// Requires the Model Editor API to be supported by the backend.
|
||||
///
|
||||
/// ```
|
||||
/// # use std::sync::Arc;
|
||||
/// # use ort::{session::{RunOptions, Session}, value::{Value, ValueType, TensorRef, TensorElementType}};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let session = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
|
||||
/// assert_eq!(session.opset_for_domain(ort::editor::ONNX_DOMAIN), Some(21));
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[cfg(feature = "api-22")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
||||
pub fn opset_for_domain(&self, domain: impl AsRef<str>) -> Result<u32> {
|
||||
pub fn opset_for_domain(&self, domain: impl AsRef<str>) -> Option<u32> {
|
||||
with_cstr(domain.as_ref().as_bytes(), &|domain| {
|
||||
let mut opset = 0;
|
||||
ortsys![@editor: unsafe SessionGetOpsetForDomain(self.inner.session_ptr.as_ptr(), domain.as_ptr(), &mut opset)?];
|
||||
Ok(opset as u32)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
BIN
tests/data/overridable_initializer.onnx
Normal file
BIN
tests/data/overridable_initializer.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user