mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
test: add doctest for overridable_initializers, opset_for_domain
This commit is contained in:
@@ -153,7 +153,6 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns this session's [`Allocator`].
|
/// Returns this session's [`Allocator`].
|
||||||
#[must_use]
|
|
||||||
pub fn allocator(&self) -> &Allocator {
|
pub fn allocator(&self) -> &Allocator {
|
||||||
&self.inner.allocator
|
&self.inner.allocator
|
||||||
}
|
}
|
||||||
@@ -171,6 +170,21 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a list of initializers which are overridable (i.e. also graph inputs).
|
/// Returns a list of initializers which are overridable (i.e. also graph inputs).
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use std::sync::Arc;
|
||||||
|
/// # use ort::{session::{RunOptions, Session}, value::{Value, ValueType, TensorRef, TensorElementType}};
|
||||||
|
/// # fn main() -> ort::Result<()> {
|
||||||
|
/// let session = Session::builder()?.commit_from_file("tests/data/overridable_initializer.onnx")?;
|
||||||
|
///
|
||||||
|
/// let mut overridable_initializers = session.overridable_initializers();
|
||||||
|
/// assert_eq!(overridable_initializers.len(), 1);
|
||||||
|
/// let f1 = overridable_initializers.pop().unwrap();
|
||||||
|
/// assert_eq!(f1.name(), "F1");
|
||||||
|
/// assert!(f1.dtype().is_tensor());
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn overridable_initializers(&self) -> Vec<OverridableInitializer> {
|
pub fn overridable_initializers(&self) -> Vec<OverridableInitializer> {
|
||||||
// can only fail if:
|
// can only fail if:
|
||||||
@@ -608,7 +622,7 @@ impl Session {
|
|||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Ends profiling for this session.
|
/// Ends profiling for this session. Returns the file name of the finalized profile.
|
||||||
///
|
///
|
||||||
/// Note that this must be explicitly called at the end of profiling, otherwise the profiling file will be empty.
|
/// Note that this must be explicitly called at the end of profiling, otherwise the profiling file will be empty.
|
||||||
pub fn end_profiling(&mut self) -> Result<String> {
|
pub fn end_profiling(&mut self) -> Result<String> {
|
||||||
@@ -648,14 +662,28 @@ impl Session {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the version of the opset domain used by the model.
|
||||||
|
///
|
||||||
|
/// Requires the Model Editor API to be supported by the backend.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use std::sync::Arc;
|
||||||
|
/// # use ort::{session::{RunOptions, Session}, value::{Value, ValueType, TensorRef, TensorElementType}};
|
||||||
|
/// # fn main() -> ort::Result<()> {
|
||||||
|
/// let session = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
|
||||||
|
/// assert_eq!(session.opset_for_domain(ort::editor::ONNX_DOMAIN), Some(21));
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
#[cfg(feature = "api-22")]
|
#[cfg(feature = "api-22")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
|
||||||
pub fn opset_for_domain(&self, domain: impl AsRef<str>) -> Result<u32> {
|
pub fn opset_for_domain(&self, domain: impl AsRef<str>) -> Option<u32> {
|
||||||
with_cstr(domain.as_ref().as_bytes(), &|domain| {
|
with_cstr(domain.as_ref().as_bytes(), &|domain| {
|
||||||
let mut opset = 0;
|
let mut opset = 0;
|
||||||
ortsys![@editor: unsafe SessionGetOpsetForDomain(self.inner.session_ptr.as_ptr(), domain.as_ptr(), &mut opset)?];
|
ortsys![@editor: unsafe SessionGetOpsetForDomain(self.inner.session_ptr.as_ptr(), domain.as_ptr(), &mut opset)?];
|
||||||
Ok(opset as u32)
|
Ok(opset as u32)
|
||||||
})
|
})
|
||||||
|
.ok()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
BIN
tests/data/overridable_initializer.onnx
Normal file
BIN
tests/data/overridable_initializer.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user