mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor: rename SessionBuilder commit methods
This commit is contained in:
@@ -13,7 +13,7 @@ fn load_squeezenet_data() -> ort::Result<(Session, Array4<f32>)> {
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/squeezenet.onnx")
|
||||
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/squeezenet.onnx")
|
||||
.expect("Could not download model from file");
|
||||
|
||||
let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
|
||||
|
||||
@@ -65,7 +65,7 @@ Users of `ort` appreciate its ease of use and ergonomic API. `ort` is also battl
|
||||
let model = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_intra_threads(4)?
|
||||
.with_model_from_file("yolov8m.onnx")?;
|
||||
.commit_from_file("yolov8m.onnx")?;
|
||||
```
|
||||
</Step>
|
||||
<Step title="Perform inference">
|
||||
|
||||
@@ -23,9 +23,17 @@ ort::init()
|
||||
-// v1.x
|
||||
-let session = SessionBuilder::new(&environment)?.with_model_from_file("model.onnx")?;
|
||||
+// v2
|
||||
+let session = Session::builder()?.with_model_from_file("model.onnx")?;
|
||||
+let session = Session::builder()?.commit_from_file("model.onnx")?;
|
||||
```
|
||||
|
||||
### `SessionBuilder::with_model_*` -> `SessionBuilder::commit_*`
|
||||
The final `SessionBuilder` methods have been renamed for clarity.
|
||||
|
||||
- `SessionBuilder::with_model_from_file` -> `SessionBuilder::commit_from_file`
|
||||
- `SessionBuilder::with_model_from_memory` -> `SessionBuilder::commit_from_memory`
|
||||
- `SessionBuilder::with_model_from_memory_directly` -> `SessionBuilder::commit_from_memory_directly`
|
||||
- `SessionBuilder::with_model_downloaded` -> `SessionBuilder::commit_from_url`
|
||||
|
||||
## Session inputs
|
||||
|
||||
### `CowArray`/`IxDyn`/`ndarray` no longer required
|
||||
@@ -147,7 +155,7 @@ The dependency on `ndarray` is now declared optional. If you use `ort` with `def
|
||||
## Model Zoo structs have been removed
|
||||
ONNX pushed a new Model Zoo structure that adds hundreds of different models. This is impractical to maintain, so the built-in structs have been removed.
|
||||
|
||||
You can still use `Session::with_model_downloaded`, it just now takes a URL string instead of a struct.
|
||||
You can still use `Session::commit_from_url`, it just now takes a URL string instead of a struct.
|
||||
|
||||
## Changes to logging
|
||||
Environment-level logging configuration (i.e. `EnvironmentBuilder::with_log_level`) has been removed because it could cause unnecessary confusion with our `tracing` integration.
|
||||
|
||||
@@ -50,7 +50,7 @@ use ort::{CUDAExecutionProvider, Session};
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
.with_model_from_file("model.onnx")?;
|
||||
.commit_from_file("model.onnx")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -72,7 +72,7 @@ fn main() -> anyhow::Result<()> {
|
||||
// Or use ANE on Apple platforms
|
||||
CoreMLExecutionProvider::default().build()
|
||||
])?
|
||||
.with_model_from_file("model.onnx")?;
|
||||
.commit_from_file("model.onnx")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -94,7 +94,7 @@ fn main() -> anyhow::Result<()> {
|
||||
.with_ane_only()
|
||||
.build()
|
||||
])?
|
||||
.with_model_from_file("model.onnx")?;
|
||||
.commit_from_file("model.onnx")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -117,7 +117,7 @@ fn main() -> anyhow::Result<()> {
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let session = builder.with_model_from_file("model.onnx")?;
|
||||
let session = builder.commit_from_file("model.onnx")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -137,7 +137,7 @@ fn main() -> anyhow::Result<()> {
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let session = builder.with_model_from_file("model.onnx")?;
|
||||
let session = builder.commit_from_file("model.onnx")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -154,7 +154,7 @@ fn main() -> anyhow::Result<()> {
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
.commit()?;
|
||||
|
||||
let session = Session::builder()?.with_model_from_file("model.onnx")?;
|
||||
let session = Session::builder()?.commit_from_file("model.onnx")?;
|
||||
// The session will attempt to register the CUDA EP
|
||||
// since we configured the environment default.
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ impl Kernel for CustomOpTwoKernel {
|
||||
fn main() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?
|
||||
.with_model_from_file("tests/data/custom_op_test.onnx")?;
|
||||
.commit_from_file("tests/data/custom_op_test.onnx")?;
|
||||
|
||||
let values = session.run(ort::inputs![Array2::<f32>::zeros((3, 5)), Array2::<f32>::ones((3, 5))]?)?;
|
||||
println!("{:?}", values[0].extract_tensor::<i32>()?);
|
||||
|
||||
@@ -36,7 +36,7 @@ fn main() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?;
|
||||
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?;
|
||||
|
||||
// Load the tokenizer and encode the prompt into a sequence of tokens.
|
||||
let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap();
|
||||
|
||||
@@ -36,7 +36,7 @@ fn main() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?;
|
||||
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?;
|
||||
|
||||
// Load the tokenizer and encode the prompt into a sequence of tokens.
|
||||
let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap();
|
||||
|
||||
@@ -45,7 +45,7 @@ fn main() -> ort::Result<()> {
|
||||
process::exit(0);
|
||||
};
|
||||
|
||||
let session = Session::builder()?.with_model_from_file(path)?;
|
||||
let session = Session::builder()?.commit_from_file(path)?;
|
||||
|
||||
let meta = session.metadata()?;
|
||||
if let Ok(x) = meta.name() {
|
||||
|
||||
@@ -16,7 +16,7 @@ fn main() -> ort::Result<()> {
|
||||
.commit()?;
|
||||
|
||||
let model =
|
||||
Session::builder()?.with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?;
|
||||
Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?;
|
||||
|
||||
let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("photo.jpg")).unwrap();
|
||||
let (img_width, img_height) = (original_img.width(), original_img.height());
|
||||
|
||||
@@ -59,7 +59,7 @@ fn main() -> ort::Result<()> {
|
||||
input[[0, 2, y, x]] = (b as f32) / 255.;
|
||||
}
|
||||
|
||||
let model = Session::builder()?.with_model_downloaded(YOLOV8M_URL)?;
|
||||
let model = Session::builder()?.commit_from_url(YOLOV8M_URL)?;
|
||||
|
||||
// Run YOLOv8 inference
|
||||
let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?)?;
|
||||
|
||||
32
src/lib.rs
32
src/lib.rs
@@ -106,23 +106,21 @@ pub(crate) fn dylib_path() -> &'static String {
|
||||
pub(crate) fn lib_handle() -> MutexGuard<'static, libloading::Library> {
|
||||
G_ORT_LIB
|
||||
.get_or_init(|| {
|
||||
unsafe {
|
||||
// resolve path relative to executable
|
||||
let path: std::path::PathBuf = dylib_path().into();
|
||||
let absolute_path = if path.is_absolute() {
|
||||
path
|
||||
} else {
|
||||
let relative = std::env::current_exe()
|
||||
.expect("could not get current executable path")
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join(&path);
|
||||
if relative.exists() { relative } else { path }
|
||||
};
|
||||
let lib = libloading::Library::new(&absolute_path)
|
||||
.unwrap_or_else(|e| panic!("An error occurred while attempting to load the ONNX Runtime binary at `{}`: {e}", absolute_path.display()));
|
||||
Arc::new(Mutex::new(lib))
|
||||
}
|
||||
// resolve path relative to executable
|
||||
let path: std::path::PathBuf = dylib_path().into();
|
||||
let absolute_path = if path.is_absolute() {
|
||||
path
|
||||
} else {
|
||||
let relative = std::env::current_exe()
|
||||
.expect("could not get current executable path")
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join(&path);
|
||||
if relative.exists() { relative } else { path }
|
||||
};
|
||||
let lib = unsafe { libloading::Library::new(&absolute_path) }
|
||||
.unwrap_or_else(|e| panic!("An error occurred while attempting to load the ONNX Runtime binary at `{}`: {e}", absolute_path.display()));
|
||||
Arc::new(Mutex::new(lib))
|
||||
})
|
||||
.lock()
|
||||
.expect("failed to acquire ONNX Runtime dylib lock; another thread panicked?")
|
||||
|
||||
@@ -49,7 +49,7 @@ impl<'i, const N: usize> From<[Value; N]> for SessionInputs<'i, N> {
|
||||
/// # use ndarray::Array1;
|
||||
/// # use ort::{GraphOptimizationLevel, Session};
|
||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||
/// # let mut session = Session::builder()?.with_model_from_file("model.onnx")?;
|
||||
/// # let mut session = Session::builder()?.commit_from_file("model.onnx")?;
|
||||
/// let _ = session.run(ort::inputs![Array1::from_vec(vec![1, 2, 3, 4, 5])]?);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
@@ -62,7 +62,7 @@ impl<'i, const N: usize> From<[Value; N]> for SessionInputs<'i, N> {
|
||||
/// # use ndarray::Array1;
|
||||
/// # use ort::{GraphOptimizationLevel, Session, Value};
|
||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||
/// # let mut session = Session::builder()?.with_model_from_file("model.onnx")?;
|
||||
/// # let mut session = Session::builder()?.commit_from_file("model.onnx")?;
|
||||
/// let _ = session
|
||||
/// .run(ort::inputs![Value::from_string_array(session.allocator(), Array1::from_vec(vec!["hello", "world"]))?]?);
|
||||
/// # Ok(())
|
||||
@@ -76,7 +76,7 @@ impl<'i, const N: usize> From<[Value; N]> for SessionInputs<'i, N> {
|
||||
/// # use ndarray::Array1;
|
||||
/// # use ort::{GraphOptimizationLevel, Session};
|
||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||
/// # let mut session = Session::builder()?.with_model_from_file("model.onnx")?;
|
||||
/// # let mut session = Session::builder()?.commit_from_file("model.onnx")?;
|
||||
/// let _ = session.run(ort::inputs! {
|
||||
/// "tokens" => Array1::from_vec(vec![1, 2, 3, 4, 5])
|
||||
/// }?);
|
||||
|
||||
@@ -40,7 +40,7 @@ pub use self::{input::SessionInputs, output::SessionOutputs};
|
||||
|
||||
/// Creates a session using the builder pattern.
|
||||
///
|
||||
/// Once configured, use the [`SessionBuilder::with_model_from_file`](crate::SessionBuilder::with_model_from_file)
|
||||
/// Once configured, use the [`SessionBuilder::commit_from_file`](crate::SessionBuilder::commit_from_file)
|
||||
/// method to 'commit' the builder configuration into a [`Session`].
|
||||
///
|
||||
/// ```
|
||||
@@ -49,7 +49,7 @@ pub use self::{input::SessionInputs, output::SessionOutputs};
|
||||
/// let session = Session::builder()?
|
||||
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
/// .with_intra_threads(1)?
|
||||
/// .with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// .commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -99,7 +99,7 @@ impl SessionBuilder {
|
||||
/// let session = Session::builder()?
|
||||
/// .with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
/// .with_intra_threads(1)?
|
||||
/// .with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// .commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -254,7 +254,7 @@ impl SessionBuilder {
|
||||
/// Downloads a pre-trained ONNX model from the given URL and builds the session.
|
||||
#[cfg(feature = "fetch-models")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
|
||||
pub fn with_model_downloaded(self, model_url: impl AsRef<str>) -> Result<Session> {
|
||||
pub fn commit_from_url(self, model_url: impl AsRef<str>) -> Result<Session> {
|
||||
let mut download_dir = ort_sys::internal::dirs::cache_dir()
|
||||
.expect("could not determine cache directory")
|
||||
.join("models");
|
||||
@@ -294,11 +294,11 @@ impl SessionBuilder {
|
||||
}
|
||||
};
|
||||
|
||||
self.with_model_from_file(downloaded_path)
|
||||
self.commit_from_file(downloaded_path)
|
||||
}
|
||||
|
||||
/// Loads an ONNX model from a file and builds the session.
|
||||
pub fn with_model_from_file<P>(self, model_filepath_ref: P) -> Result<Session>
|
||||
pub fn commit_from_file<P>(self, model_filepath_ref: P) -> Result<Session>
|
||||
where
|
||||
P: AsRef<Path>
|
||||
{
|
||||
@@ -374,7 +374,7 @@ impl SessionBuilder {
|
||||
///
|
||||
/// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that
|
||||
/// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros).
|
||||
pub fn with_model_from_memory_directly(self, model_bytes: &[u8]) -> Result<InMemorySession<'_>> {
|
||||
pub fn commit_from_memory_directly(self, model_bytes: &[u8]) -> Result<InMemorySession<'_>> {
|
||||
let str_to_char = |s: &str| {
|
||||
s.as_bytes()
|
||||
.iter()
|
||||
@@ -386,13 +386,13 @@ impl SessionBuilder {
|
||||
ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr.as_ptr(), str_to_char("session.use_ort_model_bytes_directly").as_ptr(), str_to_char("1").as_ptr())];
|
||||
ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr.as_ptr(), str_to_char("session.use_ort_model_bytes_for_initializers").as_ptr(), str_to_char("1").as_ptr())];
|
||||
|
||||
let session = self.with_model_from_memory(model_bytes)?;
|
||||
let session = self.commit_from_memory(model_bytes)?;
|
||||
|
||||
Ok(InMemorySession { session, phantom: PhantomData })
|
||||
}
|
||||
|
||||
/// Load an ONNX graph from memory and commit the session.
|
||||
pub fn with_model_from_memory(self, model_bytes: &[u8]) -> Result<Session> {
|
||||
pub fn commit_from_memory(self, model_bytes: &[u8]) -> Result<Session> {
|
||||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
|
||||
|
||||
let env = get_environment()?;
|
||||
@@ -479,7 +479,7 @@ impl Drop for SharedSessionInner {
|
||||
/// ```
|
||||
/// # use ort::{GraphOptimizationLevel, Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let session = Session::builder()?.with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
|
||||
/// let outputs = session.run(ort::inputs![input]?)?;
|
||||
/// # Ok(())
|
||||
@@ -566,7 +566,7 @@ impl RunOptions {
|
||||
/// # use std::sync::Arc;
|
||||
/// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// # let input = Value::from_array(ndarray::Array4::<f32>::zeros((1, 64, 64, 3)))?;
|
||||
/// let run_options = Arc::new(RunOptions::new()?);
|
||||
///
|
||||
@@ -595,7 +595,7 @@ impl RunOptions {
|
||||
/// # use std::sync::Arc;
|
||||
/// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// # let input = Value::from_array(ndarray::Array4::<f32>::zeros((1, 64, 64, 3)))?;
|
||||
/// let run_options = Arc::new(RunOptions::new()?);
|
||||
///
|
||||
@@ -662,7 +662,7 @@ impl Session {
|
||||
/// # use std::sync::Arc;
|
||||
/// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let session = Session::builder()?.with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
|
||||
/// let outputs = session.run(ort::inputs![input]?)?;
|
||||
/// # Ok(())
|
||||
@@ -690,7 +690,7 @@ impl Session {
|
||||
/// # use std::sync::Arc;
|
||||
/// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// # let input = Value::from_array(ndarray::Array4::<f32>::zeros((1, 64, 64, 3)))?;
|
||||
/// let run_options = Arc::new(RunOptions::new()?);
|
||||
///
|
||||
@@ -829,7 +829,7 @@ fn close_lib_handle(handle: *mut std::os::raw::c_void) {
|
||||
|
||||
/// This module contains dangerous functions working on raw pointers.
|
||||
/// Those functions are only to be used from inside the
|
||||
/// `SessionBuilder::with_model_from_file()` method.
|
||||
/// `SessionBuilder::commit_from_file()` method.
|
||||
mod dangerous {
|
||||
use super::*;
|
||||
use crate::value::{extract_data_type_from_map_info, extract_data_type_from_sequence_info, extract_data_type_from_tensor_info};
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::{Allocator, Value};
|
||||
/// ```
|
||||
/// # use ort::{GraphOptimizationLevel, Session};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let session = Session::builder()?.with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
|
||||
/// let outputs = session.run(ort::inputs![input]?)?;
|
||||
///
|
||||
|
||||
@@ -35,11 +35,6 @@ pub enum TensorElementType {
|
||||
Uint32,
|
||||
/// Unsigned 64-bit integer, equivalent to Rust's `u64`.
|
||||
Uint64,
|
||||
// /// Complex 64-bit floating point number, equivalent to Rust's `num_complex::Complex<f64>`.
|
||||
// Complex64,
|
||||
// TODO: `num_complex` crate doesn't support i128 provided by the `decimal` crate.
|
||||
// /// Complex 128-bit floating point number, equivalent to Rust's `num_complex::Complex<f128>`.
|
||||
// Complex128,
|
||||
/// Brain 16-bit floating point number, equivalent to [`half::bf16`] (requires the `half` feature).
|
||||
#[cfg(feature = "half")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
|
||||
@@ -63,8 +58,6 @@ impl From<TensorElementType> for ort_sys::ONNXTensorElementDataType {
|
||||
TensorElementType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
|
||||
TensorElementType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
|
||||
TensorElementType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
|
||||
// TensorElementDataType::Complex64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64,
|
||||
// TensorElementDataType::Complex128 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128,
|
||||
#[cfg(feature = "half")]
|
||||
TensorElementType::Bfloat16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
|
||||
}
|
||||
@@ -87,8 +80,6 @@ impl From<ort_sys::ONNXTensorElementDataType> for TensorElementType {
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => TensorElementType::Float64,
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => TensorElementType::Uint32,
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => TensorElementType::Uint64,
|
||||
// ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 => TensorElementDataType::Complex64,
|
||||
// ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 => TensorElementDataType::Complex128,
|
||||
#[cfg(feature = "half")]
|
||||
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => TensorElementType::Bfloat16,
|
||||
_ => panic!("Invalid ONNXTensorElementDataType value")
|
||||
@@ -126,18 +117,11 @@ impl_type_trait!(half::f16, Float16);
|
||||
impl_type_trait!(f64, Float64);
|
||||
impl_type_trait!(u32, Uint32);
|
||||
impl_type_trait!(u64, Uint64);
|
||||
// impl_type_trait!(num_complex::Complex<f64>, Complex64);
|
||||
// impl_type_trait!(num_complex::Complex<f128>, Complex128);
|
||||
#[cfg(feature = "half")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
|
||||
impl_type_trait!(half::bf16, Bfloat16);
|
||||
|
||||
/// Adapter for common Rust string types to ONNX strings.
|
||||
///
|
||||
/// It should be easy to use both [`String`] and `&str` as [`TensorElementDataType::String`] data, but
|
||||
/// we can't define an automatic implementation for anything that implements [`AsRef<str>`] as it
|
||||
/// would conflict with the implementations of [`IntoTensorElementDataType`] for primitive numeric
|
||||
/// types (which might implement [`AsRef<str>`] at some point in the future).
|
||||
pub trait Utf8Data {
|
||||
/// Returns the contents of this value as a slice of UTF-8 bytes.
|
||||
fn as_utf8_bytes(&self) -> &[u8];
|
||||
|
||||
@@ -155,7 +155,7 @@ impl Value {
|
||||
/// ```
|
||||
/// # use ort::{Session, Value};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.with_model_from_file("tests/data/vectorizer.onnx")?;
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/vectorizer.onnx")?;
|
||||
/// // You'll need to obtain an `Allocator` from a session in order to create string tensors.
|
||||
/// let allocator = session.allocator();
|
||||
///
|
||||
@@ -173,7 +173,7 @@ impl Value {
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Note that string data will always be copied, no matter what data is provided.
|
||||
/// Note that string data will *always* be copied, no matter what form the data is provided in.
|
||||
pub fn from_string_array<T: Utf8Data>(allocator: &Allocator, input: impl IntoValueTensor<Item = T>) -> Result<Value> {
|
||||
let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemoryType::Default)?;
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ use self::impl_tensor::ToDimensions;
|
||||
/// # use std::sync::Arc;
|
||||
/// # use ort::{Session, Value, ValueType, TensorElementType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// // `ValueType`s can be obtained from session inputs/outputs:
|
||||
/// let input = &session.inputs[0];
|
||||
/// assert_eq!(
|
||||
@@ -227,7 +227,7 @@ impl<'v> DerefMut for ValueRefMut<'v> {
|
||||
/// ```
|
||||
/// # use ort::{Session, Value, ValueType, TensorElementType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let upsample = Session::builder()?.with_model_from_file("tests/data/upsample.onnx")?;
|
||||
/// # let upsample = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// // Create a value from a raw data vector
|
||||
/// let value = Value::from_array(([1usize, 1, 1, 3], vec![1.0_f32, 2.0, 3.0].into_boxed_slice()))?;
|
||||
///
|
||||
|
||||
@@ -13,7 +13,7 @@ fn mnist_5() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")
|
||||
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")
|
||||
.expect("Could not download model from file");
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
|
||||
@@ -19,7 +19,7 @@ fn squeezenet_mushroom() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/squeezenet.onnx")
|
||||
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/squeezenet.onnx")
|
||||
.expect("Could not download model from file");
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
|
||||
@@ -51,7 +51,7 @@ fn upsample() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_memory(&session_data)
|
||||
.commit_from_memory(&session_data)
|
||||
.expect("Could not read model from memory");
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
@@ -92,7 +92,7 @@ fn upsample_with_ort_model() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_memory_directly(&session_data) // Zero-copy.
|
||||
.commit_from_memory_directly(&session_data) // Zero-copy.
|
||||
.expect("Could not read model from memory");
|
||||
|
||||
assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
|
||||
|
||||
@@ -11,7 +11,7 @@ fn vectorizer() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
.with_intra_threads(1)?
|
||||
.with_model_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("vectorizer.onnx"))
|
||||
.commit_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("vectorizer.onnx"))
|
||||
.expect("Could not load model");
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
|
||||
Reference in New Issue
Block a user