mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor: usability
aka The Cleanening, part 2
- Add clearer documentation and examples for more things.
- Rework string tensors by introducing `PrimitiveTensorElementType` for primitive (e.g. f32) types, and again re-implementing `IntoTensorElementType` for `String`. This allows string tensors to be used via `Tensor<String>` instead of exclusively via `DynTensor`. Additionally, string tensors no longer require an `Allocator` to be created (which didn't make sense, since string data in Rust can only ever be stored on the CPU anyway). This also now applies to `Map`s, since their data also needed to be on the CPU anyway. (`Sequence`s are currently unaffected because I think a custom allocator could be useful for them?)
- Rework the `IoBinding` interface, and add an example clarifying the intended usage of it (ref #209). Thanks to AAce from the pyke Discord for pointing out the mutability issue in the old interface, which should be addressed now.
- Refactor `OperatorDomain::add` from the slightly-nicer-looking-but-more-confusing `fn<T>(t: T)` to just `fn<T>()` to further enforce the fact that `Operator`s are zero-sized.
- Maps can now have `String` keys.
- Remove some unused errors.
This commit is contained in:
@@ -82,7 +82,7 @@ impl Kernel for CustomOpTwoKernel {
|
||||
|
||||
fn main() -> ort::Result<()> {
|
||||
let session = Session::builder()?
|
||||
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?
|
||||
.with_operators(OperatorDomain::new("test.customop")?.add::<CustomOpOne>()?.add::<CustomOpTwo>()?)?
|
||||
.commit_from_file("tests/data/custom_op_test.onnx")?;
|
||||
|
||||
let values = session.run(ort::inputs![Array2::<f32>::zeros((3, 5)), Array2::<f32>::ones((3, 5))]?)?;
|
||||
|
||||
@@ -34,7 +34,7 @@ pub struct Environment {
|
||||
}
|
||||
|
||||
impl Environment {
|
||||
/// Loads the underlying [`ort_sys::OrtEnv`] pointer.
|
||||
/// Returns the underlying [`ort_sys::OrtEnv`] pointer.
|
||||
pub fn ptr(&self) -> *mut ort_sys::OrtEnv {
|
||||
self.env_ptr.load(Ordering::Relaxed)
|
||||
}
|
||||
@@ -52,13 +52,14 @@ impl Drop for Environment {
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets a reference to the global environment, creating one if an environment has been committed yet.
|
||||
/// Gets a reference to the global environment, creating one if an environment has not been
|
||||
/// [`commit`](EnvironmentBuilder::commit)ted yet.
|
||||
pub fn get_environment() -> Result<&'static Arc<Environment>> {
|
||||
if let Some(c) = unsafe { &*G_ENV.cell.get() } {
|
||||
Ok(c)
|
||||
} else {
|
||||
debug!("Environment not yet initialized, creating a new one");
|
||||
EnvironmentBuilder::default().commit()?;
|
||||
EnvironmentBuilder::new().commit()?;
|
||||
|
||||
Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() })
|
||||
}
|
||||
@@ -72,7 +73,7 @@ pub struct EnvironmentGlobalThreadPoolOptions {
|
||||
pub intra_op_thread_affinity: Option<String>
|
||||
}
|
||||
|
||||
/// Struct used to build an `Environment`.
|
||||
/// Struct used to build an [`Environment`]; see [`crate::init`].
|
||||
pub struct EnvironmentBuilder {
|
||||
name: String,
|
||||
telemetry: bool,
|
||||
@@ -80,8 +81,8 @@ pub struct EnvironmentBuilder {
|
||||
global_thread_pool_options: Option<EnvironmentGlobalThreadPoolOptions>
|
||||
}
|
||||
|
||||
impl Default for EnvironmentBuilder {
|
||||
fn default() -> Self {
|
||||
impl EnvironmentBuilder {
|
||||
pub(crate) fn new() -> Self {
|
||||
EnvironmentBuilder {
|
||||
name: "default".to_string(),
|
||||
telemetry: true,
|
||||
@@ -89,11 +90,9 @@ impl Default for EnvironmentBuilder {
|
||||
global_thread_pool_options: None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EnvironmentBuilder {
|
||||
/// Configure the environment with a given name for logging purposes.
|
||||
#[must_use]
|
||||
#[must_use = "commit() must be called in order for the environment to take effect"]
|
||||
pub fn with_name<S>(mut self, name: S) -> Self
|
||||
where
|
||||
S: Into<String>
|
||||
@@ -102,7 +101,17 @@ impl EnvironmentBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
/// Enable or disable sending telemetry events to Microsoft.
|
||||
///
|
||||
/// Typically, only Windows builds of ONNX Runtime provided by Microsoft will have telemetry enabled.
|
||||
/// Pre-built binaries provided by pyke, or binaries compiled from source, won't have telemetry enabled.
|
||||
///
|
||||
/// The exact kind of telemetry data sent can be found [here](https://github.com/microsoft/onnxruntime/blob/v1.18.0/onnxruntime/core/platform/windows/telemetry.cc).
|
||||
/// Currently, this includes (but is not limited to): ONNX graph version, model producer name & version, whether or
|
||||
/// not FP16 is used, operator domains & versions, model graph name & custom metadata, execution provider names,
|
||||
/// error messages, and the total number & time of session inference runs. The ONNX Runtime team uses this data to
|
||||
/// better understand how customers use ONNX Runtime and where performance can be improved.
|
||||
#[must_use = "commit() must be called in order for the environment to take effect"]
|
||||
pub fn with_telemetry(mut self, enable: bool) -> Self {
|
||||
self.telemetry = enable;
|
||||
self
|
||||
@@ -116,14 +125,14 @@ impl EnvironmentBuilder {
|
||||
/// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built
|
||||
/// with support for the corresponding execution provider. Execution providers that do not have their corresponding
|
||||
/// feature enabled will emit a warning.
|
||||
#[must_use]
|
||||
#[must_use = "commit() must be called in order for the environment to take effect"]
|
||||
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Self {
|
||||
self.execution_providers = execution_providers.as_ref().to_vec();
|
||||
self
|
||||
}
|
||||
|
||||
/// Enables the global thread pool for this environment.
|
||||
#[must_use]
|
||||
#[must_use = "commit() must be called in order for the environment to take effect"]
|
||||
pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> Self {
|
||||
self.global_thread_pool_options = Some(options);
|
||||
self
|
||||
@@ -158,14 +167,17 @@ impl EnvironmentBuilder {
|
||||
ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment];
|
||||
}
|
||||
|
||||
ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
|
||||
ortsys![
|
||||
unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
|
||||
logging_function,
|
||||
logger_param,
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
|
||||
cname.as_ptr(),
|
||||
thread_options,
|
||||
&mut env_ptr
|
||||
) -> Error::CreateEnvironment; nonNull(env_ptr)];
|
||||
) -> Error::CreateEnvironment;
|
||||
nonNull(env_ptr)
|
||||
];
|
||||
ortsys![unsafe ReleaseThreadingOptions(thread_options)];
|
||||
(env_ptr, true)
|
||||
} else {
|
||||
@@ -174,13 +186,16 @@ impl EnvironmentBuilder {
|
||||
// FIXME: What should go here?
|
||||
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
|
||||
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());
|
||||
ortsys![unsafe CreateEnvWithCustomLogger(
|
||||
ortsys![
|
||||
unsafe CreateEnvWithCustomLogger(
|
||||
logging_function,
|
||||
logger_param,
|
||||
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
|
||||
cname.as_ptr(),
|
||||
&mut env_ptr
|
||||
) -> Error::CreateEnvironment; nonNull(env_ptr)];
|
||||
) -> Error::CreateEnvironment;
|
||||
nonNull(env_ptr)
|
||||
];
|
||||
(env_ptr, false)
|
||||
};
|
||||
debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");
|
||||
@@ -205,15 +220,25 @@ impl EnvironmentBuilder {
|
||||
|
||||
/// Creates an ONNX Runtime environment.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::CUDAExecutionProvider;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// ort::init()
|
||||
/// .with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
/// .commit()?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// # Notes
|
||||
/// - It is not required to call this function. If this is not called by the time any other `ort` APIs are used, a
|
||||
/// default environment will be created.
|
||||
/// - Library crates that use `ort` shouldn't create their own environment. Let downstream applications create it.
|
||||
/// - **Library crates that use `ort` shouldn't create their own environment.** Let downstream applications create it.
|
||||
/// - In order for environment settings to apply, this must be called **before** you use other APIs like
|
||||
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
|
||||
#[must_use]
|
||||
#[must_use = "commit() must be called in order for the environment to take effect"]
|
||||
pub fn init() -> EnvironmentBuilder {
|
||||
EnvironmentBuilder::default()
|
||||
EnvironmentBuilder::new()
|
||||
}
|
||||
|
||||
/// Creates an ONNX Runtime environment, dynamically loading ONNX Runtime from the library file (`.dll`/`.so`/`.dylib`)
|
||||
@@ -221,15 +246,26 @@ pub fn init() -> EnvironmentBuilder {
|
||||
///
|
||||
/// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use ort::CUDAExecutionProvider;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib");
|
||||
/// ort::init_from(lib_path.join("onnxruntime.dll"))
|
||||
/// .with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
/// .commit()?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// # Notes
|
||||
/// - In order for environment settings to apply, this must be called **before** you use other APIs like
|
||||
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
|
||||
#[cfg(feature = "load-dynamic")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))]
|
||||
#[must_use]
|
||||
#[must_use = "commit() must be called in order for the environment to take effect"]
|
||||
pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
|
||||
let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string()));
|
||||
EnvironmentBuilder::default()
|
||||
EnvironmentBuilder::new()
|
||||
}
|
||||
|
||||
/// ONNX Runtime's logger sends the code location where the log occurred, which will be parsed into this struct.
|
||||
@@ -325,7 +361,7 @@ mod tests {
|
||||
assert!(!is_env_initialized());
|
||||
assert_eq!(env_ptr(), None);
|
||||
|
||||
EnvironmentBuilder::default().with_name("env_is_initialized").commit()?;
|
||||
EnvironmentBuilder::new().with_name("env_is_initialized").commit()?;
|
||||
assert!(is_env_initialized());
|
||||
assert_ne!(env_ptr(), None);
|
||||
Ok(())
|
||||
|
||||
61
src/error.rs
61
src/error.rs
@@ -121,9 +121,6 @@ pub enum Error {
|
||||
/// Error occurred when filling a tensor with string data
|
||||
#[error("Failed to fill string tensor: {0}")]
|
||||
FillStringTensor(ErrorInternal),
|
||||
/// Error occurred when checking if a value is a tensor
|
||||
#[error("Failed to check if value is a tensor: {0}")]
|
||||
FailedTensorCheck(ErrorInternal),
|
||||
/// Error occurred when getting tensor type and shape
|
||||
#[error("Failed to get tensor type and shape: {0}")]
|
||||
GetTensorTypeAndShape(ErrorInternal),
|
||||
@@ -159,12 +156,6 @@ pub enum Error {
|
||||
/// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models).
|
||||
#[error("Failed to download ONNX model: {0}")]
|
||||
DownloadError(#[from] FetchModelError),
|
||||
/// Type of input data and the ONNX model do not match.
|
||||
#[error("Data types do not match: expected {model:?}, got {input:?}")]
|
||||
NonMatchingDataTypes { input: TensorElementType, model: TensorElementType },
|
||||
/// Dimensions of input data and the ONNX model do not match.
|
||||
#[error("Dimensions do not match: {0:?}")]
|
||||
NonMatchingDimensions(NonMatchingDimensionsError),
|
||||
/// File does not exist
|
||||
#[error("File `{filename:?}` does not exist")]
|
||||
FileDoesNotExist {
|
||||
@@ -186,9 +177,6 @@ pub enum Error {
|
||||
/// ORT pointer should not have been null
|
||||
#[error("`{0}` should not be a null pointer")]
|
||||
PointerShouldNotBeNull(&'static str),
|
||||
/// The runtime type was undefined.
|
||||
#[error("Undefined tensor element type")]
|
||||
UndefinedTensorElementType,
|
||||
/// Could not retrieve model metadata.
|
||||
#[error("Failed to retrieve model metadata: {0}")]
|
||||
GetModelMetadata(ErrorInternal),
|
||||
@@ -208,8 +196,8 @@ pub enum Error {
|
||||
ExecutionProviderNotRegistered(&'static str),
|
||||
#[error("Expected tensor to be on CPU in order to get data, but had allocation device `{0}`.")]
|
||||
TensorNotOnCpu(&'static str),
|
||||
#[error("String tensors require the session's allocator to be provided through `Value::from_array`.")]
|
||||
StringTensorRequiresAllocator,
|
||||
#[error("Cannot extract scalar value from a {0}-dimensional tensor")]
|
||||
TensorNot0Dimensional(usize),
|
||||
#[error("Failed to create memory info: {0}")]
|
||||
CreateMemoryInfo(ErrorInternal),
|
||||
#[error("Could not get allocation device from `MemoryInfo`: {0}")]
|
||||
@@ -222,10 +210,10 @@ pub enum Error {
|
||||
BindInput(ErrorInternal),
|
||||
#[error("Error when binding output: {0}")]
|
||||
BindOutput(ErrorInternal),
|
||||
#[error("Failed to clear IO binding: {0}")]
|
||||
ClearBinding(ErrorInternal),
|
||||
#[error("Error when retrieving session outputs from `IoBinding`: {0}")]
|
||||
GetBoundOutputs(ErrorInternal),
|
||||
#[error("Cannot use `extract_tensor` on a value that is {0:?}")]
|
||||
NotTensor(ValueType),
|
||||
#[error("Cannot use `extract_sequence` on a value that is {0:?}")]
|
||||
NotSequence(ValueType),
|
||||
#[error("Cannot use `extract_map` on a value that is {0:?}")]
|
||||
@@ -252,6 +240,8 @@ pub enum Error {
|
||||
GetOperatorInput(ErrorInternal),
|
||||
#[error("Failed to get operator output: {0}")]
|
||||
GetOperatorOutput(ErrorInternal),
|
||||
#[error("Failed to retrieve GPU compute stream from kernel context: {0}")]
|
||||
GetOperatorGPUComputeStream(ErrorInternal),
|
||||
#[error("{0}")]
|
||||
CustomError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("String tensors cannot be borrowed as mutable")]
|
||||
@@ -266,37 +256,20 @@ pub enum Error {
|
||||
GetDeviceId(ErrorInternal)
|
||||
}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
fn from(_: Infallible) -> Self {
|
||||
Error::Infallible
|
||||
impl Error {
|
||||
/// Wrap a custom, user-provided error in an [`ort::Error`](Error). The resulting error will be the
|
||||
/// [`Error::CustomError`] variant.
|
||||
///
|
||||
/// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort`
|
||||
/// related operation fails.
|
||||
pub fn wrap<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
|
||||
Error::CustomError(Box::new(err) as Box<dyn std::error::Error + Send + Sync + 'static>)
|
||||
}
|
||||
}
|
||||
|
||||
/// Error used when the input dimensions defined in the model and passed from an inference call do not match.
|
||||
#[non_exhaustive]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum NonMatchingDimensionsError {
|
||||
/// Number of inputs from model does not match the number of inputs from inference call.
|
||||
#[error(
|
||||
"Non-matching number of inputs: {inference_input_count:?} provided vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})"
|
||||
)]
|
||||
InputsCount {
|
||||
/// Number of input dimensions used by inference call
|
||||
inference_input_count: usize,
|
||||
/// Number of input dimensions defined in model
|
||||
model_input_count: usize,
|
||||
/// Input dimensions used by inference call
|
||||
inference_input: Vec<Vec<usize>>,
|
||||
/// Input dimensions defined in model
|
||||
model_input: Vec<Vec<Option<u32>>>
|
||||
},
|
||||
/// Inputs length from model does not match the expected input from inference call
|
||||
#[error("Different input lengths; expected input: {model_input:?}, received input: {inference_input:?}")]
|
||||
InputsLength {
|
||||
/// Input dimensions used by inference call
|
||||
inference_input: Vec<Vec<usize>>,
|
||||
/// Input dimensions defined in model
|
||||
model_input: Vec<Vec<Option<u32>>>
|
||||
impl From<Infallible> for Error {
|
||||
fn from(_: Infallible) -> Self {
|
||||
Error::Infallible
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
ffi::CString,
|
||||
fmt::Debug,
|
||||
marker::PhantomData,
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
};
|
||||
@@ -9,24 +11,87 @@ use crate::{
|
||||
memory::MemoryInfo,
|
||||
ortsys,
|
||||
session::{output::SessionOutputs, RunOptions},
|
||||
value::{Value, ValueRefMut},
|
||||
Error, Result, Session, ValueTypeMarker
|
||||
value::{Value, ValueInner},
|
||||
DynValue, Error, Result, Session, ValueTypeMarker
|
||||
};
|
||||
|
||||
/// Enables binding of session inputs and/or outputs to pre-allocated memory.
|
||||
///
|
||||
/// Note that this arrangement is designed to minimize data copies, and to that effect, your memory allocations must
|
||||
/// match what is expected by the model, whether you run on CPU or GPU. Data will still be copied if the
|
||||
/// pre-allocated memory location does not match the one expected by the model. However, copies with `IoBinding`s are
|
||||
/// only done once, at the time of the binding, not at run time. This means, that if your input data required a copy,
|
||||
/// your further input modifications would not be seen by ONNX Runtime unless you rebind it, even if it is the same
|
||||
/// buffer. If your scenario requires that the data is copied, `IoBinding` may not be the best match for your use case.
|
||||
/// The fact that data copy is not made during runtime may also have performance implications.
|
||||
/// [`IoBinding`] minimizes copies between a device (like a GPU) and the host (CPU) by allowing the user to bind a
|
||||
/// certain input/output to a pre-allocated value on a specific device.
|
||||
///
|
||||
/// [`IoBinding`] is most suitable for:
|
||||
/// - An ensemble of models in which the output from one model is the input of another and does not need to pass through
|
||||
/// the CPU to perform additional processing.
|
||||
/// - Situations where the output should stay on a device (e.g. to perform additional processing with CUDA).
|
||||
/// - Diffusion models, for instance, that accept an unchanging embedding for conditioning.
|
||||
///
|
||||
/// [`IoBinding`] will not provide any meaningful benefit for:
|
||||
/// - Models where every input changes with each invocation, such as a causal language model or object recognition
|
||||
/// model.
|
||||
/// - Pipelines that go straight from CPU -> GPU -> CPU.
|
||||
///
|
||||
/// # Example
|
||||
/// A diffusion model which takes a text condition input.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use ort::{Allocator, AllocatorType, AllocationDevice, CUDAExecutionProvider, MemoryInfo, MemoryType, Session, Tensor, IoBinding};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let text_encoder = Session::builder()?
|
||||
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
/// .commit_from_file("text_encoder.onnx")?;
|
||||
/// let unet = Session::builder()?
|
||||
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
/// .commit_from_file("unet.onnx")?;
|
||||
///
|
||||
/// let text_condition = text_encoder
|
||||
/// .run(ort::inputs![Tensor::<i64>::from_array((
|
||||
/// vec![27],
|
||||
/// vec![
|
||||
/// 23763, 15460, 473, 68, 312, 265, 17463, 4098, 304, 1077, 283, 198, 7676, 5976, 272, 285, 3609, 435,
|
||||
/// 21680, 321, 265, 300, 1689, 64, 285, 4763, 64
|
||||
/// ]
|
||||
/// ))?]?)?
|
||||
/// .remove("output0")
|
||||
/// .unwrap();
|
||||
///
|
||||
/// let input_allocator = Allocator::new(
|
||||
/// &unet,
|
||||
/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)?
|
||||
/// )?;
|
||||
/// let mut latents = Tensor::<f32>::new(&input_allocator, [1, 4, 64, 64])?;
|
||||
///
|
||||
/// let mut io_binding = unet.create_binding()?;
|
||||
/// io_binding.bind_input("condition", &text_condition)?;
|
||||
///
|
||||
/// let output_allocator = Allocator::new(
|
||||
/// &unet,
|
||||
/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUOutput)?
|
||||
/// )?;
|
||||
/// io_binding.bind_output("noise_pred", Tensor::<f32>::new(&output_allocator, [1, 4, 64, 64])?)?;
|
||||
///
|
||||
/// for _ in 0..20 {
|
||||
/// io_binding.bind_input("latents", &latents)?;
|
||||
/// let noise_pred = io_binding.run()?.remove("noise_pred").unwrap();
|
||||
///
|
||||
/// let mut latents = latents.extract_tensor_mut();
|
||||
/// latents += &noise_pred.try_extract_tensor::<f32>()?;
|
||||
/// }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// [`IoBinding`] may provide a decent speedup in this example since the `condition` tensor is unchanging between runs.
|
||||
/// If we were to use normal session inference, the `condition` tensor would be needlessly copied with each invocation
|
||||
/// of `unet.run()`, and this copying can come with significant latency & overhead. With [`IoBinding`], the `condition`
|
||||
/// tensor is only copied to the device once instead of 20 times.
|
||||
#[derive(Debug)]
|
||||
pub struct IoBinding<'s> {
|
||||
pub(crate) ptr: NonNull<ort_sys::OrtIoBinding>,
|
||||
session: &'s Session,
|
||||
output_names: Vec<String>
|
||||
held_inputs: HashMap<String, Arc<ValueInner>>,
|
||||
output_names: Vec<String>,
|
||||
output_values: HashMap<String, DynValue>
|
||||
}
|
||||
|
||||
impl<'s> IoBinding<'s> {
|
||||
@@ -36,25 +101,47 @@ impl<'s> IoBinding<'s> {
|
||||
Ok(Self {
|
||||
ptr: unsafe { NonNull::new_unchecked(ptr) },
|
||||
session,
|
||||
output_names: Vec::new()
|
||||
held_inputs: HashMap::new(),
|
||||
output_names: Vec::new(),
|
||||
output_values: HashMap::new()
|
||||
})
|
||||
}
|
||||
|
||||
/// Bind a [`Value`] to a session input.
|
||||
pub fn bind_input<'i: 's, T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: &'i mut Value<T>) -> Result<ValueRefMut<'i, T>> {
|
||||
///
|
||||
/// Upon invocation, the value's data will be copied to the device the session is allocated on. The copied data will
|
||||
/// be used as an input (specified by `name`) in all future invocations of [`IoBinding::run`] until the input is
|
||||
/// overridden (by calling [`IoBinding::bind_input`] again) or until all inputs are cleared (via
|
||||
/// [`IoBinding::clear_inputs`] or [`IoBinding::clear`]).
|
||||
///
|
||||
/// The data is only copied **once**, immediately upon invocation of this function. Any changes to the given
|
||||
/// value afterwards will not affect the data seen by the session until the value is re-bound. Subsequent re-binds
|
||||
/// will still copy data, hence why [`IoBinding`] is really only suitable when one or more inputs do not change
|
||||
/// between runs.
|
||||
pub fn bind_input<T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: &Value<T>) -> Result<()> {
|
||||
let name = name.as_ref();
|
||||
let cname = CString::new(name)?;
|
||||
ortsys![unsafe BindInput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindInput];
|
||||
Ok(ort_value.view_mut())
|
||||
self.held_inputs.insert(name.to_string(), Arc::clone(&ort_value.inner));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Bind a session output to a pre-allocated [`Value`].
|
||||
pub fn bind_output<'o: 's, T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: &'o mut Value<T>) -> Result<ValueRefMut<'o, T>> {
|
||||
///
|
||||
/// This allows for the pre-allocation and reuse of memory in the session output (see [`crate::Tensor::new`]). Any
|
||||
/// subsequent runs via [`IoBinding::run`] will reuse the same tensor to store the output instead of creating a new
|
||||
/// one each time.
|
||||
///
|
||||
/// The output will be accessible in the value returned by [`IoBinding::run`], under the name specified by `name`.
|
||||
pub fn bind_output<T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: Value<T>) -> Result<()> {
|
||||
let name = name.as_ref();
|
||||
let cname = CString::new(name)?;
|
||||
ortsys![unsafe BindOutput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindOutput];
|
||||
self.output_names.push(name.to_string());
|
||||
Ok(ort_value.view_mut())
|
||||
// Clear the old bound output if we have any.
|
||||
drop(self.output_values.remove(name));
|
||||
self.output_values.insert(name.to_string(), ort_value.into_dyn());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Bind a session output to a device which is specified by `mem_info`.
|
||||
@@ -66,15 +153,35 @@ impl<'s> IoBinding<'s> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn run<'i: 's>(&'i self) -> Result<SessionOutputs<'s>> {
|
||||
/// Clears all bound inputs specified by [`IoBinding::bind_input`].
|
||||
pub fn clear_inputs(&mut self) {
|
||||
ortsys![unsafe ClearBoundInputs(self.ptr.as_ptr())];
|
||||
drop(self.held_inputs.drain());
|
||||
}
|
||||
/// Clears all bound outputs specified by [`IoBinding::bind_output`] or [`IoBinding::bind_output_to_device`].
|
||||
pub fn clear_outputs(&mut self) {
|
||||
ortsys![unsafe ClearBoundOutputs(self.ptr.as_ptr())];
|
||||
drop(self.output_names.drain(..));
|
||||
drop(self.output_values.drain());
|
||||
}
|
||||
/// Clears both the bound inputs & outputs; equivalent to [`IoBinding::clear_inputs`] followed by
|
||||
/// [`IoBinding::clear_outputs`].
|
||||
pub fn clear(&mut self) {
|
||||
self.clear_inputs();
|
||||
self.clear_outputs();
|
||||
}
|
||||
|
||||
/// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`].
|
||||
pub fn run(&mut self) -> Result<SessionOutputs<'_>> {
|
||||
self.run_inner(None)
|
||||
}
|
||||
|
||||
pub fn run_with_options<'i: 's>(&'i self, run_options: Arc<RunOptions>) -> Result<SessionOutputs<'s>> {
|
||||
/// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`], configured by the given [`RunOptions`].
|
||||
pub fn run_with_options(&mut self, run_options: Arc<RunOptions>) -> Result<SessionOutputs<'_>> {
|
||||
self.run_inner(Some(run_options))
|
||||
}
|
||||
|
||||
fn run_inner<'i: 's>(&'i self, run_options: Option<Arc<RunOptions>>) -> Result<SessionOutputs<'s>> {
|
||||
fn run_inner(&mut self, run_options: Option<Arc<RunOptions>>) -> Result<SessionOutputs<'_>> {
|
||||
let run_options_ptr = if let Some(run_options) = run_options {
|
||||
run_options.run_options_ptr.as_ptr()
|
||||
} else {
|
||||
@@ -82,6 +189,7 @@ impl<'s> IoBinding<'s> {
|
||||
};
|
||||
ortsys![unsafe RunWithBinding(self.session.inner.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr()) -> Error::SessionRunWithIoBinding];
|
||||
|
||||
let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc<ValueInner>> = self.output_values.values().map(|c| (c.ptr(), &c.inner)).collect();
|
||||
let mut count = self.output_names.len() as ort_sys::size_t;
|
||||
if count > 0 {
|
||||
let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut();
|
||||
@@ -91,10 +199,17 @@ impl<'s> IoBinding<'s> {
|
||||
let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count as _).to_vec() }
|
||||
.into_iter()
|
||||
.map(|v| unsafe {
|
||||
Value::from_ptr(
|
||||
NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"),
|
||||
Some(Arc::clone(&self.session.inner))
|
||||
)
|
||||
if let Some(inner) = owned_ptrs.get(&v) {
|
||||
DynValue {
|
||||
inner: Arc::clone(*inner),
|
||||
_markers: PhantomData
|
||||
}
|
||||
} else {
|
||||
DynValue::from_ptr(
|
||||
NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"),
|
||||
Some(Arc::clone(&self.session.inner))
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
// output values will be freed when the `Value`s in `SessionOutputs` drop
|
||||
|
||||
39
src/lib.rs
39
src/lib.rs
@@ -65,7 +65,7 @@ pub use self::session::{
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
pub use self::tensor::ArrayExtensions;
|
||||
pub use self::tensor::{IntoTensorElementType, TensorElementType};
|
||||
pub use self::tensor::{IntoTensorElementType, Utf8Data, PrimitiveTensorElementType, TensorElementType};
|
||||
pub use self::value::{
|
||||
DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor,
|
||||
DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence,
|
||||
@@ -143,6 +143,23 @@ pub(crate) static G_ORT_API: OnceLock<AtomicPtr<ort_sys::OrtApi>> = OnceLock::ne
|
||||
/// May panic if:
|
||||
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
|
||||
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
|
||||
///
|
||||
/// # Examples
|
||||
/// The primary (public-facing) use case for this function is accessing APIs that do not have a corresponding safe
|
||||
/// implementation in `ort`. For example, [`GetBuildInfoString`](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0a7dba37b0017c0ef3a0ab4e266a967d):
|
||||
///
|
||||
/// ```
|
||||
/// # use std::ffi::CStr;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let api = ort::api().as_ptr();
|
||||
/// let build_info = unsafe { CStr::from_ptr((*api).GetBuildInfoString.unwrap()()) };
|
||||
/// println!("{}", build_info.to_string_lossy());
|
||||
/// // ORT Build Info: git-branch=HEAD, git-commit-id=4573740, build type=Release, cmake cxx flags: /DWIN32 /D_WINDOWS /EHsc /EHsc /wd26812 -DEIGEN_HAS_C99_MATH -DCPUINFO_SUPPORTED
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// For the full list of ONNX Runtime APIs, consult the [`ort_sys::OrtApi`] struct and the [ONNX Runtime C API](https://onnxruntime.ai/docs/api/c/struct_ort_api.html).
|
||||
pub fn api() -> NonNull<ort_sys::OrtApi> {
|
||||
unsafe {
|
||||
NonNull::new_unchecked(
|
||||
@@ -252,6 +269,26 @@ pub(crate) fn char_p_to_string(raw: *const c_char) -> Result<String> {
|
||||
.map_err(Error::FfiStringConversion)
|
||||
}
|
||||
|
||||
pub(crate) struct PrivateTraitMarker;
|
||||
|
||||
macro_rules! private_trait {
|
||||
() => {
|
||||
#[doc(hidden)]
|
||||
#[allow(private_interfaces)]
|
||||
fn _private() -> crate::PrivateTraitMarker;
|
||||
};
|
||||
}
|
||||
macro_rules! private_impl {
|
||||
() => {
|
||||
#[allow(private_interfaces)]
|
||||
fn _private() -> crate::PrivateTraitMarker {
|
||||
crate::PrivateTraitMarker
|
||||
}
|
||||
};
|
||||
}
|
||||
pub(crate) use private_impl;
|
||||
pub(crate) use private_trait;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::ffi::CString;
|
||||
|
||||
206
src/memory.rs
206
src/memory.rs
@@ -1,20 +1,75 @@
|
||||
use std::{
|
||||
ffi::{c_char, c_int, CString},
|
||||
ptr::NonNull
|
||||
ptr::NonNull,
|
||||
sync::Arc
|
||||
};
|
||||
|
||||
use super::{
|
||||
error::{Error, Result},
|
||||
ortsys
|
||||
};
|
||||
use crate::{char_p_to_string, error::status_to_result, Session};
|
||||
use crate::{char_p_to_string, error::status_to_result, Session, SharedSessionInner};
|
||||
|
||||
/// An ONNX Runtime allocator, used to manage the allocation of [`crate::Value`]s.
|
||||
/// A device allocator used to manage the allocation of [`crate::Value`]s.
|
||||
///
|
||||
/// # Direct allocation
|
||||
/// [`Allocator`] can be used to directly allocate device memory. This can be useful if you have a
|
||||
/// postprocessing step that runs on the GPU.
|
||||
/// ```no_run
|
||||
/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// let allocator = Allocator::new(
|
||||
/// &session,
|
||||
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?
|
||||
/// )?;
|
||||
///
|
||||
/// let mut tensor = Tensor::<f32>::new(&allocator, [1, 3, 224, 224])?;
|
||||
/// // Here, `data_ptr` is a pointer to **device memory** inaccessible to the CPU; we'll need another crate, like
|
||||
/// // `cudarc`, to access it.
|
||||
/// let data_ptr = tensor.data_ptr_mut()?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Note that `ort` does not facilitate the transfer of data between host & device outside of session inputs &
|
||||
/// outputs; you'll need to use a separate crate for that, like [`cudarc`](https://crates.io/crates/cudarc) for CUDA.
|
||||
///
|
||||
/// # Pinned allocation
|
||||
/// Memory allocated on the host CPU is often *pageable* and may reside on the disk (swap memory). Transferring
|
||||
/// pageable memory to another device is slow because the device has to go through the CPU to access the
|
||||
/// memory. Many execution providers thus provide a "pinned" allocator type, which allocates *page-locked* (non-pageable) CPU memory
|
||||
/// that the device is able to access directly, bypassing the CPU and allowing for faster host-to-device data
|
||||
/// transfer.
|
||||
///
|
||||
/// If you create a session with an execution provider that supports pinned memory, like CUDA or ROCm, you can create
|
||||
/// an allocator for that session, and use it to allocate tensors with faster pinned memory:
|
||||
/// ```no_run
|
||||
/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// let allocator = Allocator::new(
|
||||
/// &session,
|
||||
/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)?
|
||||
/// )?;
|
||||
///
|
||||
/// // Create a tensor with our pinned allocator.
|
||||
/// let mut tensor = Tensor::<f32>::new(&allocator, [1, 3, 224, 224])?;
|
||||
/// let data = tensor.extract_tensor_mut();
|
||||
/// // ...fill `data` with data...
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct Allocator {
|
||||
pub(crate) ptr: NonNull<ort_sys::OrtAllocator>,
|
||||
/// The 'default' CPU allocator, provided by `GetAllocatorWithDefaultOptions` and implemented by
|
||||
/// [`Allocator::default`], should **not** be released, so this field marks whether or not we should call
|
||||
/// `ReleaseAllocator` on drop.
|
||||
is_default: bool,
|
||||
_info: Option<MemoryInfo>
|
||||
_info: Option<MemoryInfo>,
|
||||
/// Hold a reference to the session if this allocator is tied to one.
|
||||
_session_inner: Option<Arc<SharedSessionInner>>
|
||||
}
|
||||
|
||||
impl Allocator {
|
||||
@@ -22,47 +77,46 @@ impl Allocator {
|
||||
Allocator {
|
||||
ptr: NonNull::new_unchecked(ptr),
|
||||
is_default: false,
|
||||
// currently, this function is only ever used in session creation, where we call `CreateAllocator` manually and store the allocator resulting from
|
||||
// this function in the `SharedSessionInner` - we don't need to hold onto the session, because the session is holding onto us.
|
||||
_session_inner: None,
|
||||
_info: None
|
||||
}
|
||||
}
|
||||
|
||||
/// Frees an object allocated by this allocator, given the object's C pointer.
|
||||
pub(crate) unsafe fn free<T>(&self, ptr: *mut T) {
|
||||
self.ptr.as_ref().Free.unwrap_or_else(|| unreachable!("Allocator method `Free` is null"))(self.ptr.as_ptr(), ptr.cast());
|
||||
}
|
||||
|
||||
/// Creates a new [`Allocator`] for the given session, to allocate memory on the device described in the
|
||||
/// [`MemoryInfo`].
|
||||
///
|
||||
/// For example, to create an allocator to allocate pinned memory for CUDA:
|
||||
/// ```no_run
|
||||
/// # use ort::{Allocator, Session, MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// let allocator = Allocator::new(
|
||||
/// &session,
|
||||
/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)?
|
||||
/// )?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(session: &Session, memory_info: MemoryInfo) -> Result<Self> {
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
|
||||
ortsys![unsafe CreateAllocator(session.ptr(), memory_info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)];
|
||||
Ok(Self {
|
||||
ptr: unsafe { NonNull::new_unchecked(allocator_ptr) },
|
||||
is_default: false,
|
||||
_session_inner: Some(session.inner()),
|
||||
_info: Some(memory_info)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Allocator {
|
||||
/// Returns the default CPU allocator; equivalent to `MemoryInfo::new(AllocationDevice::CPU, 0,
|
||||
/// AllocatorType::Device, MemoryType::Default)`.
|
||||
///
|
||||
/// The allocator returned by this function is actually shared across all invocations (though this behavior is
|
||||
/// transparent to the user).
|
||||
fn default() -> Self {
|
||||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
|
||||
status_to_result(ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr); nonNull(allocator_ptr)]).expect("Failed to get default allocator");
|
||||
Self {
|
||||
ptr: unsafe { NonNull::new_unchecked(allocator_ptr) },
|
||||
is_default: true,
|
||||
// The default allocator isn't tied to a session.
|
||||
_session_inner: None,
|
||||
_info: None
|
||||
}
|
||||
}
|
||||
@@ -70,8 +124,6 @@ impl Default for Allocator {
|
||||
|
||||
impl Drop for Allocator {
|
||||
fn drop(&mut self) {
|
||||
// per GetAllocatorWithDefaultOptions docs: Returned value should NOT be freed
|
||||
// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a8dec797ae52ee1a681e4f88be1fb4bb3
|
||||
if !self.is_default {
|
||||
ortsys![unsafe ReleaseAllocator(self.ptr.as_ptr())];
|
||||
}
|
||||
@@ -81,7 +133,8 @@ impl Drop for Allocator {
|
||||
/// Represents possible devices that have their own device allocator.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum AllocationDevice {
|
||||
// https://github.com/microsoft/onnxruntime/blob/v1.17.0/include/onnxruntime/core/framework/allocator.h#L43-L53
|
||||
// https://github.com/microsoft/onnxruntime/blob/v1.18.0/include/onnxruntime/core/framework/allocator.h#L43-L53
|
||||
// ort will likely never support WebGPU, so I think it's best to leave `WebGPU_Buffer` out entirely to reduce confusion
|
||||
CPU,
|
||||
CUDA,
|
||||
CUDAPinned,
|
||||
@@ -91,12 +144,10 @@ pub enum AllocationDevice {
|
||||
HIP,
|
||||
HIPPinned,
|
||||
OpenVINOCPU,
|
||||
OpenVINOGPU,
|
||||
WebGPUBuffer
|
||||
OpenVINOGPU
|
||||
}
|
||||
|
||||
impl AllocationDevice {
|
||||
#[must_use]
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::CPU => "Cpu",
|
||||
@@ -108,10 +159,15 @@ impl AllocationDevice {
|
||||
Self::HIP => "Hip",
|
||||
Self::HIPPinned => "HipPinned",
|
||||
Self::OpenVINOCPU => "OpenVINO_CPU",
|
||||
Self::OpenVINOGPU => "OpenVINO_GPU",
|
||||
Self::WebGPUBuffer => "WebGPU_Buffer"
|
||||
Self::OpenVINOGPU => "OpenVINO_GPU"
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device,
|
||||
/// it could be extracted to an `ndarray` or slice.
|
||||
pub fn is_cpu_accessible(&self) -> bool {
|
||||
matches!(self, Self::CPU | Self::CUDAPinned | Self::CANNPinned | Self::HIPPinned | Self::OpenVINOCPU)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<String> for AllocationDevice {
|
||||
@@ -129,14 +185,13 @@ impl TryFrom<String> for AllocationDevice {
|
||||
"HipPinned" => Ok(AllocationDevice::HIPPinned),
|
||||
"OpenVINO_CPU" => Ok(AllocationDevice::OpenVINOCPU),
|
||||
"OpenVINO_GPU" => Ok(AllocationDevice::OpenVINOGPU),
|
||||
"WebGPUBuffer" => Ok(AllocationDevice::WebGPUBuffer),
|
||||
_ => Err(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execution provider allocator type.
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum AllocatorType {
|
||||
/// Default device-specific allocator.
|
||||
Device,
|
||||
@@ -154,11 +209,11 @@ impl From<AllocatorType> for ort_sys::OrtAllocatorType {
|
||||
}
|
||||
|
||||
/// Memory types for allocated memory.
|
||||
#[derive(Default, Debug, Copy, Clone)]
|
||||
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum MemoryType {
|
||||
/// Any CPU memory used by non-CPU execution provider.
|
||||
CPUInput,
|
||||
/// CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED.
|
||||
/// CPU-accessible memory output by a non-CPU execution provider, i.e. [`AllocationDevice::CUDAPinned`].
|
||||
CPUOutput,
|
||||
/// The default allocator for an execution provider.
|
||||
#[default]
|
||||
@@ -190,6 +245,12 @@ impl From<ort_sys::OrtMemType> for MemoryType {
|
||||
}
|
||||
}
|
||||
|
||||
/// Structure describing a memory location - the device on which the memory resides, the type of allocator (device
|
||||
/// default, or arena) used, and the type of memory allocated (device-only, or CPU accessible).
|
||||
///
|
||||
/// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which
|
||||
/// device value data should reside, and how that data should be accessible with regard to the CPU (if a non-CPU device
|
||||
/// is requested).
|
||||
#[derive(Debug)]
|
||||
pub struct MemoryInfo {
|
||||
pub(crate) ptr: NonNull<ort_sys::OrtMemoryInfo>,
|
||||
@@ -197,24 +258,24 @@ pub struct MemoryInfo {
|
||||
}
|
||||
|
||||
impl MemoryInfo {
|
||||
pub(crate) fn from_raw(ptr: NonNull<ort_sys::OrtMemoryInfo>, should_release: bool) -> Self {
|
||||
MemoryInfo { ptr, should_release }
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub fn new_cpu(allocator: AllocatorType, memory_type: MemoryType) -> Result<Self> {
|
||||
let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
|
||||
ortsys![
|
||||
unsafe CreateCpuMemoryInfo(allocator.into(), memory_type.into(), &mut memory_info_ptr) -> Error::CreateMemoryInfo;
|
||||
nonNull(memory_info_ptr)
|
||||
];
|
||||
Ok(Self {
|
||||
ptr: unsafe { NonNull::new_unchecked(memory_info_ptr) },
|
||||
should_release: true
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
/// Creates a [`MemoryInfo`], describing a memory location on a device allocator.
|
||||
///
|
||||
/// # Examples
|
||||
/// `MemoryInfo` can be used to specify the device & memory type used by an [`Allocator`] to allocate tensors.
|
||||
/// See [`Allocator`] for more information & potential applications.
|
||||
/// ```no_run
|
||||
/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// let allocator = Allocator::new(
|
||||
/// &session,
|
||||
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?
|
||||
/// )?;
|
||||
///
|
||||
/// let mut tensor = Tensor::<f32>::new(&allocator, [1, 3, 224, 224])?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemoryType) -> Result<Self> {
|
||||
let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
|
||||
let allocator_name = CString::new(allocation_device.as_str()).unwrap_or_else(|_| unreachable!());
|
||||
@@ -229,7 +290,19 @@ impl MemoryInfo {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn from_raw(ptr: NonNull<ort_sys::OrtMemoryInfo>, should_release: bool) -> Self {
|
||||
MemoryInfo { ptr, should_release }
|
||||
}
|
||||
|
||||
/// Returns the [`MemoryType`] described by this struct.
|
||||
/// ```
|
||||
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
|
||||
/// assert_eq!(mem.memory_type()?, MemoryType::Default);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn memory_type(&self) -> Result<MemoryType> {
|
||||
let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault;
|
||||
ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetMemoryType];
|
||||
@@ -237,6 +310,14 @@ impl MemoryInfo {
|
||||
}
|
||||
|
||||
/// Returns the [`AllocatorType`] described by this struct.
|
||||
/// ```
|
||||
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
|
||||
/// assert_eq!(mem.allocator_type()?, AllocatorType::Device);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn allocator_type(&self) -> Result<AllocatorType> {
|
||||
let mut raw_type: ort_sys::OrtAllocatorType = ort_sys::OrtAllocatorType::OrtInvalidAllocator;
|
||||
ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetAllocatorType];
|
||||
@@ -248,6 +329,14 @@ impl MemoryInfo {
|
||||
}
|
||||
|
||||
/// Returns the [`AllocationDevice`] this struct was created with.
|
||||
/// ```
|
||||
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
|
||||
/// assert_eq!(mem.allocation_device()?, AllocationDevice::CPU);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn allocation_device(&self) -> Result<AllocationDevice> {
|
||||
let mut name_ptr: *const c_char = std::ptr::null_mut();
|
||||
ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr) -> Error::GetAllocationDevice; nonNull(name_ptr)];
|
||||
@@ -258,6 +347,14 @@ impl MemoryInfo {
|
||||
}
|
||||
|
||||
/// Returns the ID of the [`AllocationDevice`] described by this struct.
|
||||
/// ```
|
||||
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
|
||||
/// assert_eq!(mem.device_id()?, 0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn device_id(&self) -> Result<i32> {
|
||||
let mut raw: ort_sys::c_int = 0;
|
||||
ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw) -> Error::GetDeviceId];
|
||||
@@ -266,24 +363,9 @@ impl MemoryInfo {
|
||||
}
|
||||
|
||||
impl Drop for MemoryInfo {
|
||||
#[tracing::instrument]
|
||||
fn drop(&mut self) {
|
||||
if self.should_release {
|
||||
ortsys![unsafe ReleaseMemoryInfo(self.ptr.as_ptr())];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use test_log::test;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn create_memory_info() -> crate::Result<()> {
|
||||
let memory_info = MemoryInfo::new_cpu(AllocatorType::Device, MemoryType::Default)?;
|
||||
std::mem::drop(memory_info);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,8 +11,7 @@ use super::{
|
||||
};
|
||||
use crate::error::IntoStatus;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone)]
|
||||
#[repr(C)] // <- important! a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later
|
||||
pub(crate) struct BoundOperator<O: Operator> {
|
||||
implementation: ort_sys::OrtCustomOp,
|
||||
name: CString,
|
||||
@@ -184,7 +183,10 @@ unsafe impl Send for ErasedBoundOperator {}
|
||||
|
||||
impl ErasedBoundOperator {
|
||||
pub(crate) fn new<O: Operator>(bound: BoundOperator<O>) -> Self {
|
||||
ErasedBoundOperator(NonNull::from(unsafe { &mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ()) }))
|
||||
ErasedBoundOperator(NonNull::from(unsafe {
|
||||
// horrible horrible horrible horrible horrible horrible horrible horrible horrible
|
||||
&mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ())
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn op_ptr(&self) -> *mut ort_sys::OrtCustomOp {
|
||||
|
||||
@@ -138,4 +138,13 @@ impl KernelContext {
|
||||
ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx as ort_sys::size_t, shape.as_ptr(), shape.len() as _, &mut value_ptr) -> Error::GetOperatorOutput];
|
||||
Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) })))
|
||||
}
|
||||
|
||||
/// Returns a pointer to the GPU compute stream (i.e. `cudaStream_t`) used by the execution provider, if this
|
||||
/// kernel's operator was configured to use said execution provider (see
|
||||
/// [`super::Operator::execution_provider_type`]).
|
||||
pub fn compute_stream(&self) -> Result<Option<NonNull<ort_sys::c_void>>> {
|
||||
let mut stream_ptr: *mut ort_sys::c_void = ptr::null_mut();
|
||||
ortsys![unsafe KernelContext_GetGPUComputeStream(self.ptr.as_ptr(), &mut stream_ptr) -> Error::GetOperatorGPUComputeStream];
|
||||
Ok(NonNull::new(stream_ptr))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,11 +16,26 @@ use crate::{operator::bound::BoundOperator, ortsys, Error, Result};
|
||||
|
||||
pub type InferShapeFn = dyn FnMut(*mut ort_sys::OrtShapeInferContext) -> crate::Result<()>;
|
||||
|
||||
/// A custom operator descriptor, which describes the expected inputs & outputs of a graph operator.
|
||||
///
|
||||
/// [`Operator`]s are bound to [`OperatorDomain`]s. Multiple operators can have the same name as long as they have
|
||||
/// different input/output types, in which case the exact operator will be picked depending on the input/output
|
||||
/// types. If you want to, for example, define a `Sort` operator that can accept either a single `f32` or `i64` tensor
|
||||
/// input, you'll need to define 2 separate operators (which can be done via a macro); but both of these
|
||||
/// [`Operator`] structs can return the same name in [`Operator::name`] so that they are usable as simply
|
||||
/// `my.domain:Sort` in the graph.
|
||||
pub trait Operator: Send {
|
||||
type Kernel: Kernel;
|
||||
|
||||
/// Returns the name of the operator.
|
||||
fn name() -> &'static str;
|
||||
|
||||
/// Returns the execution provider this operator runs on, e.g. `CUDAExecutionProvider`.
|
||||
///
|
||||
/// If the returned type is not `None`, and the execution provider used by the session matches this operator's
|
||||
/// EP type, the value will not be copied to the CPU and you may use functions like [`crate::Tensor::data_ptr`] to
|
||||
/// access the underlying device memory, and [`super::KernelContext::compute_stream`] to access the GPU compute
|
||||
/// stream.
|
||||
fn execution_provider_type() -> Option<&'static str> {
|
||||
None
|
||||
}
|
||||
@@ -42,6 +57,7 @@ pub trait Operator: Send {
|
||||
}
|
||||
}
|
||||
|
||||
/// Dummy type implementing [`Operator`] used by [`ErasedBoundOperator`] to cheat the type system.
|
||||
struct DummyOperator;
|
||||
|
||||
impl Operator for DummyOperator {
|
||||
@@ -84,7 +100,7 @@ impl OperatorDomain {
|
||||
}
|
||||
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn add<O: Operator>(mut self, _operator: O) -> Result<Self> {
|
||||
pub fn add<O: Operator>(mut self) -> Result<Self> {
|
||||
let name = O::name();
|
||||
|
||||
let bound = BoundOperator::<O>::new(CString::new(name)?, O::execution_provider_type().map(CString::new).transpose()?);
|
||||
|
||||
@@ -92,16 +92,15 @@ impl<'i, 'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs<
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Note that string tensors must be created manually with [`Value::from_string_array`].
|
||||
/// Note that string tensors must be created manually with [`crate::Tensor::from_string_array`].
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use std::{error::Error, sync::Arc};
|
||||
/// # use ndarray::Array1;
|
||||
/// # use ort::{GraphOptimizationLevel, Session, Value};
|
||||
/// # use ort::{GraphOptimizationLevel, Session, Tensor};
|
||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||
/// # let mut session = Session::builder()?.commit_from_file("model.onnx")?;
|
||||
/// let _ = session
|
||||
/// .run(ort::inputs![Value::from_string_array(session.allocator(), Array1::from_vec(vec!["hello", "world"]))?]?);
|
||||
/// let _ = session.run(ort::inputs![Tensor::from_string_array(Array1::from_vec(vec!["hello", "world"]))?]?);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
|
||||
@@ -22,4 +22,4 @@ mod types;
|
||||
pub use self::ndarray::ArrayExtensions;
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub(crate) use self::types::{extract_primitive_array, extract_primitive_array_mut};
|
||||
pub use self::types::{IntoTensorElementType, TensorElementType, Utf8Data};
|
||||
pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data};
|
||||
|
||||
@@ -91,6 +91,12 @@ impl From<ort_sys::ONNXTensorElementDataType> for TensorElementType {
|
||||
pub trait IntoTensorElementType {
|
||||
/// Returns the ONNX tensor element data type corresponding to the given Rust type.
|
||||
fn into_tensor_element_type() -> TensorElementType;
|
||||
|
||||
crate::private_trait!();
|
||||
}
|
||||
|
||||
pub trait PrimitiveTensorElementType: IntoTensorElementType {
|
||||
crate::private_trait!();
|
||||
}
|
||||
|
||||
macro_rules! impl_type_trait {
|
||||
@@ -99,6 +105,12 @@ macro_rules! impl_type_trait {
|
||||
fn into_tensor_element_type() -> TensorElementType {
|
||||
TensorElementType::$variant
|
||||
}
|
||||
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
impl PrimitiveTensorElementType for $type_ {
|
||||
crate::private_impl!();
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -121,6 +133,14 @@ impl_type_trait!(u64, Uint64);
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
|
||||
impl_type_trait!(half::bf16, Bfloat16);
|
||||
|
||||
impl IntoTensorElementType for String {
|
||||
fn into_tensor_element_type() -> TensorElementType {
|
||||
TensorElementType::String
|
||||
}
|
||||
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
/// Adapter for common Rust string types to ONNX strings.
|
||||
pub trait Utf8Data {
|
||||
/// Returns the contents of this value as a slice of UTF-8 bytes.
|
||||
|
||||
@@ -3,25 +3,39 @@ use std::{
|
||||
fmt::Debug,
|
||||
hash::Hash,
|
||||
marker::PhantomData,
|
||||
ptr::{self, NonNull}
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
};
|
||||
|
||||
use super::{ValueInner, ValueTypeMarker};
|
||||
use crate::{
|
||||
memory::Allocator, ortsys, value::impl_tensor::DynTensor, DynValue, Error, IntoTensorElementType, Result, Tensor, Value, ValueRef, ValueRefMut, ValueType
|
||||
memory::Allocator,
|
||||
ortsys,
|
||||
value::impl_tensor::{calculate_tensor_size, DynTensor},
|
||||
DynValue, Error, IntoTensorElementType, PrimitiveTensorElementType, Result, Tensor, TensorElementType, Value, ValueRef, ValueRefMut, ValueType
|
||||
};
|
||||
|
||||
pub trait MapValueTypeMarker: ValueTypeMarker {}
|
||||
pub trait MapValueTypeMarker: ValueTypeMarker {
|
||||
crate::private_trait!();
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DynMapValueType;
|
||||
impl ValueTypeMarker for DynMapValueType {}
|
||||
impl MapValueTypeMarker for DynMapValueType {}
|
||||
impl ValueTypeMarker for DynMapValueType {
|
||||
crate::private_impl!();
|
||||
}
|
||||
impl MapValueTypeMarker for DynMapValueType {
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MapValueType<K: IntoTensorElementType + Clone + Hash + Eq, V: IntoTensorElementType + Debug>(PhantomData<(K, V)>);
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> ValueTypeMarker for MapValueType<K, V> {}
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> MapValueTypeMarker for MapValueType<K, V> {}
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> ValueTypeMarker for MapValueType<K, V> {
|
||||
crate::private_impl!();
|
||||
}
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> MapValueTypeMarker for MapValueType<K, V> {
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
pub type DynMap = Value<DynMapValueType>;
|
||||
pub type Map<K, V> = Value<MapValueType<K, V>>;
|
||||
@@ -32,10 +46,7 @@ pub type MapRef<'v, K, V> = ValueRef<'v, MapValueType<K, V>>;
|
||||
pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType<K, V>>;
|
||||
|
||||
impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
|
||||
pub fn try_extract_map<K: IntoTensorElementType + Clone + Hash + Eq, V: IntoTensorElementType + Clone>(
|
||||
&self,
|
||||
allocator: &Allocator
|
||||
) -> Result<HashMap<K, V>> {
|
||||
pub fn try_extract_map<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<HashMap<K, V>> {
|
||||
match self.dtype()? {
|
||||
ValueType::Map { key, value } => {
|
||||
let k_type = K::into_tensor_element_type();
|
||||
@@ -47,47 +58,95 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
|
||||
return Err(Error::InvalidMapValueType { expected: v_type, actual: value });
|
||||
}
|
||||
|
||||
let allocator = Allocator::default();
|
||||
|
||||
let mut key_tensor_ptr = ptr::null_mut();
|
||||
ortsys![unsafe GetValue(self.ptr(), 0, allocator.ptr.as_ptr(), &mut key_tensor_ptr) -> Error::ExtractMap; nonNull(key_tensor_ptr)];
|
||||
let key_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(key_tensor_ptr), None) };
|
||||
let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_tensor::<K>()?;
|
||||
if K::into_tensor_element_type() != TensorElementType::String {
|
||||
let dtype = key_value.dtype()?;
|
||||
let (key_tensor_shape, key_tensor) = match dtype {
|
||||
ValueType::Tensor { ty, dimensions } => {
|
||||
let device = key_value.memory_info()?.allocation_device()?;
|
||||
if !device.is_cpu_accessible() {
|
||||
return Err(Error::TensorNotOnCpu(device.as_str()));
|
||||
}
|
||||
|
||||
let mut value_tensor_ptr = ptr::null_mut();
|
||||
ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)];
|
||||
let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) };
|
||||
let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::<V>()?;
|
||||
if ty == K::into_tensor_element_type() {
|
||||
let mut output_array_ptr: *mut K = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
|
||||
|
||||
assert_eq!(key_tensor_shape.len(), 1);
|
||||
assert_eq!(value_tensor_shape.len(), 1);
|
||||
assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
|
||||
let len = calculate_tensor_size(&dimensions);
|
||||
(dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) })
|
||||
} else {
|
||||
return Err(Error::DataTypeMismatch {
|
||||
actual: ty,
|
||||
requested: K::into_tensor_element_type()
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => unreachable!()
|
||||
};
|
||||
|
||||
let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
|
||||
for i in 0..key_tensor_shape[0] as usize {
|
||||
vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
|
||||
let mut value_tensor_ptr = ptr::null_mut();
|
||||
ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)];
|
||||
let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) };
|
||||
let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::<V>()?;
|
||||
|
||||
assert_eq!(key_tensor_shape.len(), 1);
|
||||
assert_eq!(value_tensor_shape.len(), 1);
|
||||
assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
|
||||
|
||||
let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
|
||||
for i in 0..key_tensor_shape[0] as usize {
|
||||
vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
|
||||
}
|
||||
Ok(vec.into_iter().collect())
|
||||
} else {
|
||||
let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_string_tensor()?;
|
||||
// SAFETY: `IntoTensorElementType` is a private trait, and we only map the `String` type to `TensorElementType::String`,
|
||||
// so at this point, `K` is **always** the `String` type, and this transmute really does nothing but please the type
|
||||
// checker.
|
||||
let key_tensor: Vec<K> = unsafe { std::mem::transmute(key_tensor) };
|
||||
|
||||
let mut value_tensor_ptr = ptr::null_mut();
|
||||
ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)];
|
||||
let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) };
|
||||
let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::<V>()?;
|
||||
|
||||
assert_eq!(key_tensor_shape.len(), 1);
|
||||
assert_eq!(value_tensor_shape.len(), 1);
|
||||
assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
|
||||
|
||||
let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
|
||||
for i in 0..key_tensor_shape[0] as usize {
|
||||
vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
|
||||
}
|
||||
Ok(vec.into_iter().collect())
|
||||
}
|
||||
Ok(vec.into_iter().collect())
|
||||
}
|
||||
t => Err(Error::NotMap(t))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
|
||||
impl<K: PrimitiveTensorElementType + Debug + Clone + Hash + Eq + 'static, V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
|
||||
/// Creates a [`Map`] from an iterable emitting `K` and `V`.
|
||||
///
|
||||
/// ```
|
||||
/// # use std::collections::HashMap;
|
||||
/// # use ort::{Allocator, Map};
|
||||
/// # use ort::Map;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let allocator = Allocator::default();
|
||||
/// let mut map = HashMap::<i64, f32>::new();
|
||||
/// map.insert(0, 1.0);
|
||||
/// map.insert(1, 2.0);
|
||||
/// map.insert(2, 3.0);
|
||||
///
|
||||
/// let value = Map::new(map)?;
|
||||
/// let value = Map::<i64, f32>::new(map)?;
|
||||
///
|
||||
/// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0);
|
||||
/// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -95,20 +154,45 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTens
|
||||
let (keys, values): (Vec<K>, Vec<V>) = data.into_iter().unzip();
|
||||
Self::new_kv(Tensor::from_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?)
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType<String, V>> {
|
||||
/// Creates a [`Map`] from an iterable emitting `K` and `V`.
|
||||
///
|
||||
/// ```
|
||||
/// # use std::collections::HashMap;
|
||||
/// # use ort::Map;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mut map = HashMap::<i64, f32>::new();
|
||||
/// map.insert(0, 1.0);
|
||||
/// map.insert(1, 2.0);
|
||||
/// map.insert(2, 3.0);
|
||||
///
|
||||
/// let value = Map::<i64, f32>::new(map)?;
|
||||
///
|
||||
/// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(data: impl IntoIterator<Item = (String, V)>) -> Result<Self> {
|
||||
let (keys, values): (Vec<String>, Vec<V>) = data.into_iter().unzip();
|
||||
Self::new_kv(Tensor::from_string_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
|
||||
/// Creates a [`Map`] from two tensors of keys & values respectively.
|
||||
///
|
||||
/// ```
|
||||
/// # use std::collections::HashMap;
|
||||
/// # use ort::{Allocator, Map, Tensor};
|
||||
/// # use ort::{Map, Tensor};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let allocator = Allocator::default();
|
||||
/// let keys = Tensor::<i64>::from_array(([4], vec![0, 1, 2, 3]))?;
|
||||
/// let values = Tensor::<f32>::from_array(([4], vec![1., 2., 3., 4.]))?;
|
||||
///
|
||||
/// let value = Map::new_kv(keys, values)?;
|
||||
///
|
||||
/// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0);
|
||||
/// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
@@ -122,21 +206,23 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTens
|
||||
nonNull(value_ptr)
|
||||
];
|
||||
Ok(Value {
|
||||
inner: ValueInner::RustOwned {
|
||||
inner: Arc::new(ValueInner::RustOwned {
|
||||
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
|
||||
_array: Box::new(values),
|
||||
_memory_info: None
|
||||
},
|
||||
}),
|
||||
_markers: PhantomData
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
|
||||
pub fn extract_map(&self, allocator: &Allocator) -> HashMap<K, V> {
|
||||
self.try_extract_map(allocator).expect("Failed to extract map")
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: PrimitiveTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
|
||||
pub fn extract_map(&self) -> HashMap<K, V> {
|
||||
self.try_extract_map().expect("Failed to extract map")
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
|
||||
/// Converts from a strongly-typed [`Map<K, V>`] to a type-erased [`DynMap`].
|
||||
#[inline]
|
||||
pub fn upcast(self) -> DynMap {
|
||||
@@ -149,7 +235,7 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementT
|
||||
DynMapRef::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -160,7 +246,7 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementT
|
||||
DynMapRefMut::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,23 +1,34 @@
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
marker::PhantomData,
|
||||
ptr::{self, NonNull}
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
};
|
||||
|
||||
use super::{DowncastableTarget, ValueInner, ValueTypeMarker};
|
||||
use crate::{memory::Allocator, ortsys, Error, Result, Value, ValueRef, ValueRefMut, ValueType};
|
||||
|
||||
pub trait SequenceValueTypeMarker: ValueTypeMarker {}
|
||||
pub trait SequenceValueTypeMarker: ValueTypeMarker {
|
||||
crate::private_trait!();
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DynSequenceValueType;
|
||||
impl ValueTypeMarker for DynSequenceValueType {}
|
||||
impl SequenceValueTypeMarker for DynSequenceValueType {}
|
||||
impl ValueTypeMarker for DynSequenceValueType {
|
||||
crate::private_impl!();
|
||||
}
|
||||
impl SequenceValueTypeMarker for DynSequenceValueType {
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SequenceValueType<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(PhantomData<T>);
|
||||
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> ValueTypeMarker for SequenceValueType<T> {}
|
||||
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> SequenceValueTypeMarker for SequenceValueType<T> {}
|
||||
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> ValueTypeMarker for SequenceValueType<T> {
|
||||
crate::private_impl!();
|
||||
}
|
||||
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> SequenceValueTypeMarker for SequenceValueType<T> {
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
pub type DynSequence = Value<DynSequenceValueType>;
|
||||
pub type Sequence<T> = Value<SequenceValueType<T>>;
|
||||
@@ -89,11 +100,11 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized + 'static> Value<Se
|
||||
nonNull(value_ptr)
|
||||
];
|
||||
Ok(Value {
|
||||
inner: ValueInner::RustOwned {
|
||||
inner: Arc::new(ValueInner::RustOwned {
|
||||
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
|
||||
_array: Box::new(values),
|
||||
_memory_info: None
|
||||
},
|
||||
}),
|
||||
_markers: PhantomData
|
||||
})
|
||||
}
|
||||
@@ -116,7 +127,7 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
|
||||
DynSequenceRef::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -127,7 +138,7 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
|
||||
DynSequenceRefMut::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,15 +15,15 @@ use crate::{
|
||||
error::assert_non_null_pointer,
|
||||
memory::{Allocator, MemoryInfo},
|
||||
ortsys,
|
||||
tensor::{IntoTensorElementType, TensorElementType, Utf8Data},
|
||||
value::ValueInner,
|
||||
AllocatorType, DynValue, Error, MemoryType, Result, TensorRefMut, Value
|
||||
tensor::{TensorElementType, Utf8Data},
|
||||
value::{impl_tensor::calculate_tensor_size, ValueInner},
|
||||
AllocationDevice, AllocatorType, DynValue, Error, MemoryType, PrimitiveTensorElementType, Result, TensorRefMut, Value
|
||||
};
|
||||
|
||||
impl DynTensor {
|
||||
/// Construct a [`Value`] from an array of strings.
|
||||
impl Tensor<String> {
|
||||
/// Construct a string [`Tensor`] from an array of strings.
|
||||
///
|
||||
/// Just like numeric tensors, string tensor `Value`s can be created from:
|
||||
/// Just like numeric tensors, string tensors can be created from:
|
||||
/// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`);
|
||||
/// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray<T, D>`);
|
||||
/// - (with feature `ndarray`) an owned [`ndarray::Array`];
|
||||
@@ -36,26 +36,19 @@ impl DynTensor {
|
||||
/// ```
|
||||
/// # use ort::{Session, Value};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/vectorizer.onnx")?;
|
||||
/// // You'll need to obtain an `Allocator` from a session in order to create string tensors.
|
||||
/// let allocator = session.allocator();
|
||||
///
|
||||
/// // Create a string tensor from a raw data vector
|
||||
/// let data = vec!["hello", "world"];
|
||||
/// let value = Value::from_string_array(allocator, ([data.len()], data.into_boxed_slice()))?;
|
||||
/// let value = Value::from_string_array(([data.len()], data.into_boxed_slice()))?;
|
||||
///
|
||||
/// // Create a string tensor from an `ndarray::Array`
|
||||
/// #[cfg(feature = "ndarray")]
|
||||
/// let value = Value::from_string_array(
|
||||
/// allocator,
|
||||
/// ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap()
|
||||
/// )?;
|
||||
/// let value = Value::from_string_array(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap())?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Note that string data will *always* be copied, no matter what form the data is provided in.
|
||||
pub fn from_string_array<T: Utf8Data>(allocator: &Allocator, input: impl IntoValueTensor<Item = T>) -> Result<DynTensor> {
|
||||
pub fn from_string_array<T: Utf8Data>(input: impl IntoValueTensor<Item = T>) -> Result<Tensor<String>> {
|
||||
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
|
||||
|
||||
let (shape, data) = input.ref_parts()?;
|
||||
@@ -64,7 +57,7 @@ impl DynTensor {
|
||||
|
||||
// create tensor without data -- data is filled in later
|
||||
ortsys![
|
||||
unsafe CreateTensorAsOrtValue(allocator.ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr)
|
||||
unsafe CreateTensorAsOrtValue(Allocator::default().ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr)
|
||||
-> Error::CreateTensor;
|
||||
nonNull(value_ptr)
|
||||
];
|
||||
@@ -84,18 +77,18 @@ impl DynTensor {
|
||||
ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> Error::FillStringTensor];
|
||||
|
||||
Ok(Value {
|
||||
inner: ValueInner::RustOwned {
|
||||
inner: Arc::new(ValueInner::RustOwned {
|
||||
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
|
||||
_array: Box::new(()),
|
||||
_memory_info: None
|
||||
},
|
||||
}),
|
||||
_markers: PhantomData
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||
/// Construct a tensor [`Value`] in a given allocator with a given shape and datatype. The data contained in the
|
||||
impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
|
||||
/// Construct a tensor in a given allocator with a given shape and datatype. The data contained in the
|
||||
/// value will be zero-allocated on the allocation device.
|
||||
///
|
||||
/// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned
|
||||
@@ -132,18 +125,18 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||
];
|
||||
|
||||
Ok(Value {
|
||||
inner: ValueInner::RustOwned {
|
||||
inner: Arc::new(ValueInner::RustOwned {
|
||||
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
|
||||
_array: Box::new(()),
|
||||
_memory_info: None
|
||||
},
|
||||
}),
|
||||
_markers: PhantomData
|
||||
})
|
||||
}
|
||||
|
||||
/// Construct a tensor [`Value`] from an array of data.
|
||||
/// Construct a tensor from an array of data.
|
||||
///
|
||||
/// Tensor `Value`s can be created from:
|
||||
/// Tensors can be created from:
|
||||
/// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`);
|
||||
/// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray<T, D>`);
|
||||
/// - (with feature `ndarray`) an owned [`ndarray::Array`];
|
||||
@@ -154,19 +147,19 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||
/// * and `data` is one of `Vec<T>`, `Box<[T]>`, `Arc<Box<[T]>>`, or `&[T]`.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::Value;
|
||||
/// # use ort::Tensor;
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// // Create a tensor from a raw data vector
|
||||
/// let value = Value::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?;
|
||||
/// let tensor = Tensor::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?;
|
||||
///
|
||||
/// // Create a tensor from an `ndarray::Array`
|
||||
/// #[cfg(feature = "ndarray")]
|
||||
/// let value = Value::from_array(ndarray::Array4::<f32>::zeros((1, 16, 16, 3)))?;
|
||||
/// let tensor = Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 16, 16, 3)))?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Creating string tensors requires a separate method; see [`Value::from_string_array`].
|
||||
/// Creating string tensors requires a separate method; see [`Tensor::from_string_array`].
|
||||
///
|
||||
/// Note that data provided in an `ndarray` may be copied in some circumstances:
|
||||
/// - `&CowArray<'_, T, D>` will always be copied regardless of whether it is uniquely owned or borrowed.
|
||||
@@ -177,7 +170,7 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||
/// Raw data provided as a `Arc<Box<[T]>>`, `Box<[T]>`, or `Vec<T>` will never be copied. Raw data is expected to be
|
||||
/// in standard, contiguous layout.
|
||||
pub fn from_array(input: impl IntoValueTensor<Item = T>) -> Result<Tensor<T>> {
|
||||
let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemoryType::Default)?;
|
||||
let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::Default)?;
|
||||
|
||||
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
|
||||
|
||||
@@ -203,17 +196,17 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||
];
|
||||
|
||||
Ok(Value {
|
||||
inner: ValueInner::RustOwned {
|
||||
inner: Arc::new(ValueInner::RustOwned {
|
||||
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
|
||||
_array: guard,
|
||||
_memory_info: Some(memory_info)
|
||||
},
|
||||
}),
|
||||
_markers: PhantomData
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> {
|
||||
impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
|
||||
/// Create a mutable tensor view from a raw pointer and shape.
|
||||
///
|
||||
/// The length of data is determined by `T` and the given shape, so the given buffer must be at least
|
||||
@@ -260,11 +253,11 @@ impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> {
|
||||
];
|
||||
|
||||
Ok(TensorRefMut::new(Value {
|
||||
inner: ValueInner::CppOwned {
|
||||
inner: Arc::new(ValueInner::CppOwned {
|
||||
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
|
||||
drop: true,
|
||||
_session: None
|
||||
},
|
||||
}),
|
||||
_markers: PhantomData
|
||||
}))
|
||||
}
|
||||
@@ -290,7 +283,7 @@ macro_rules! impl_to_dimensions {
|
||||
.enumerate()
|
||||
.map(|(i, c)| if *c >= 1 { Ok(*c as i64) } else { Err(Error::InvalidDimension(i)) })
|
||||
.collect::<Result<_>>()?;
|
||||
let sum = v.iter().product::<i64>() as usize;
|
||||
let sum = calculate_tensor_size(&v);
|
||||
if let Some(expected_size) = expected_size {
|
||||
if sum != expected_size {
|
||||
Err(Error::TensorShapeMismatch {
|
||||
@@ -318,6 +311,14 @@ macro_rules! impl_to_dimensions {
|
||||
};
|
||||
}
|
||||
|
||||
impl ToDimensions for () {
|
||||
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Vec<i64>> {
|
||||
match expected_size {
|
||||
Some(1) | None => Ok(vec![]),
|
||||
Some(x) => Err(Error::TensorShapeMismatch { input: vec![], total: 1, expected: x })
|
||||
}
|
||||
}
|
||||
}
|
||||
impl_to_dimensions!(for &[usize], for &[i32], for &[i64], for Vec<usize>, for Vec<i32>, for Vec<i64>);
|
||||
impl_to_dimensions!(<N> for [usize; N], for [i32; N], for [i64; N]);
|
||||
|
||||
@@ -500,7 +501,7 @@ impl<T: Clone + Debug + 'static, D: ToDimensions> IntoValueTensor for (D, Arc<Bo
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Tensor<T>
|
||||
impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Tensor<T>
|
||||
where
|
||||
'i: 'v
|
||||
{
|
||||
@@ -512,7 +513,7 @@ where
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for Tensor<T> {
|
||||
impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for Tensor<T> {
|
||||
type Error = Error;
|
||||
fn try_from(arr: ArrayView<'v, T, D>) -> Result<Self, Self::Error> {
|
||||
Tensor::from_array(arr)
|
||||
@@ -521,7 +522,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor
|
||||
impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor
|
||||
where
|
||||
'i: 'v
|
||||
{
|
||||
@@ -533,7 +534,7 @@ where
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for DynTensor {
|
||||
impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for DynTensor {
|
||||
type Error = Error;
|
||||
fn try_from(arr: ArrayView<'v, T, D>) -> Result<Self, Self::Error> {
|
||||
Tensor::from_array(arr).map(|c| c.upcast())
|
||||
@@ -542,7 +543,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue
|
||||
impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue
|
||||
where
|
||||
'i: 'v
|
||||
{
|
||||
@@ -554,7 +555,7 @@ where
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for DynValue {
|
||||
impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for DynValue {
|
||||
type Error = Error;
|
||||
fn try_from(arr: ArrayView<'v, T, D>) -> Result<Self, Self::Error> {
|
||||
Tensor::from_array(arr).map(|c| c.into_dyn())
|
||||
@@ -564,19 +565,19 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta
|
||||
macro_rules! impl_try_from {
|
||||
(@T,I $($t:ty),+) => {
|
||||
$(
|
||||
impl<T: IntoTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for Tensor<T> {
|
||||
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for Tensor<T> {
|
||||
type Error = Error;
|
||||
fn try_from(value: $t) -> Result<Self, Self::Error> {
|
||||
Tensor::from_array(value)
|
||||
}
|
||||
}
|
||||
impl<T: IntoTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for DynTensor {
|
||||
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for DynTensor {
|
||||
type Error = Error;
|
||||
fn try_from(value: $t) -> Result<Self, Self::Error> {
|
||||
Tensor::from_array(value).map(|c| c.upcast())
|
||||
}
|
||||
}
|
||||
impl<T: IntoTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for crate::DynValue {
|
||||
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for crate::DynValue {
|
||||
type Error = Error;
|
||||
fn try_from(value: $t) -> Result<Self, Self::Error> {
|
||||
Tensor::from_array(value).map(|c| c.into_dyn())
|
||||
@@ -587,21 +588,21 @@ macro_rules! impl_try_from {
|
||||
(@T,D $($t:ty),+) => {
|
||||
$(
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: IntoTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for Tensor<T> {
|
||||
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for Tensor<T> {
|
||||
type Error = Error;
|
||||
fn try_from(value: $t) -> Result<Self, Self::Error> {
|
||||
Tensor::from_array(value)
|
||||
}
|
||||
}
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: IntoTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for DynTensor {
|
||||
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for DynTensor {
|
||||
type Error = Error;
|
||||
fn try_from(value: $t) -> Result<Self, Self::Error> {
|
||||
Tensor::from_array(value).map(|c| c.upcast())
|
||||
}
|
||||
}
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
impl<T: IntoTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for crate::DynValue {
|
||||
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for crate::DynValue {
|
||||
type Error = Error;
|
||||
fn try_from(value: $t) -> Result<Self, Self::Error> {
|
||||
Tensor::from_array(value).map(|c| c.into_dyn())
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{fmt::Debug, os::raw::c_char, ptr, string::FromUtf8Error};
|
||||
use std::{fmt::Debug, ptr, string::FromUtf8Error};
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
use ndarray::IxDyn;
|
||||
@@ -7,9 +7,7 @@ use super::TensorValueTypeMarker;
|
||||
#[cfg(feature = "ndarray")]
|
||||
use crate::tensor::{extract_primitive_array, extract_primitive_array_mut};
|
||||
use crate::{
|
||||
ortsys,
|
||||
tensor::{IntoTensorElementType, TensorElementType},
|
||||
Error, Result, Tensor, Value
|
||||
ortsys, tensor::TensorElementType, value::impl_tensor::calculate_tensor_size, Error, PrimitiveTensorElementType, Result, Tensor, Value, ValueType
|
||||
};
|
||||
|
||||
impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
@@ -38,38 +36,81 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the
|
||||
/// infallible [`Tensor::extract_tensor`] instead)*
|
||||
/// - The provided type `T` does not match the tensor's element type.
|
||||
/// - The tensor's data is not allocated in CPU memory.
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
pub fn try_extract_tensor<T: IntoTensorElementType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
|
||||
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
|
||||
pub fn try_extract_tensor<T: PrimitiveTensorElementType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
|
||||
let dtype = self.dtype()?;
|
||||
match dtype {
|
||||
ValueType::Tensor { ty, dimensions } => {
|
||||
let device = self.memory_info()?.allocation_device()?;
|
||||
if !device.is_cpu_accessible() {
|
||||
return Err(Error::TensorNotOnCpu(device.as_str()));
|
||||
}
|
||||
|
||||
let res = {
|
||||
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
|
||||
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
let data_type: TensorElementType = type_sys.into();
|
||||
if data_type == T::into_tensor_element_type() {
|
||||
let mut num_dims = 0;
|
||||
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
|
||||
|
||||
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
|
||||
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
|
||||
let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::<Vec<_>>());
|
||||
|
||||
let mut len = 0;
|
||||
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
|
||||
|
||||
Ok(extract_primitive_array(shape, self.ptr())?)
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: data_type,
|
||||
requested: T::into_tensor_element_type()
|
||||
})
|
||||
if ty == T::into_tensor_element_type() {
|
||||
Ok(extract_primitive_array(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>()), self.ptr())?)
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: ty,
|
||||
requested: T::into_tensor_element_type()
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
|
||||
res
|
||||
t => Err(Error::NotTensor(t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to extract the scalar from a tensor of type `T`.
|
||||
///
|
||||
/// ```
|
||||
/// # use std::sync::Arc;
|
||||
/// # use ort::{Session, Value};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let value = Value::from_array(((), vec![3.14_f32]))?;
|
||||
///
|
||||
/// let extracted = value.try_extract_scalar::<f32>()?;
|
||||
/// assert_eq!(extracted, 3.14);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// # Errors
|
||||
/// May return an error if:
|
||||
/// - The tensor is not 0-dimensional.
|
||||
/// - The provided type `T` does not match the tensor's element type.
|
||||
/// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the
|
||||
/// infallible [`Tensor::extract_tensor`] instead)*
|
||||
/// - The tensor's data is not allocated in CPU memory.
|
||||
pub fn try_extract_scalar<T: PrimitiveTensorElementType + Copy>(&self) -> Result<T> {
|
||||
let dtype = self.dtype()?;
|
||||
match dtype {
|
||||
ValueType::Tensor { ty, dimensions } => {
|
||||
let device = self.memory_info()?.allocation_device()?;
|
||||
if !device.is_cpu_accessible() {
|
||||
return Err(Error::TensorNotOnCpu(device.as_str()));
|
||||
}
|
||||
|
||||
if !dimensions.is_empty() {
|
||||
return Err(Error::TensorNot0Dimensional(dimensions.len()));
|
||||
}
|
||||
|
||||
if ty == T::into_tensor_element_type() {
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
|
||||
|
||||
Ok(unsafe { *output_array_ptr })
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: ty,
|
||||
requested: T::into_tensor_element_type()
|
||||
})
|
||||
}
|
||||
}
|
||||
t => Err(Error::NotTensor(t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to extract the underlying data of type `T` into a mutable [`ndarray::ArrayViewMut`].
|
||||
@@ -101,36 +142,26 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// - The provided type `T` does not match the tensor's element type.
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
pub fn try_extract_tensor_mut<T: IntoTensorElementType>(&mut self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
|
||||
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
|
||||
pub fn try_extract_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
|
||||
let dtype = self.dtype()?;
|
||||
match dtype {
|
||||
ValueType::Tensor { ty, dimensions } => {
|
||||
let device = self.memory_info()?.allocation_device()?;
|
||||
if !device.is_cpu_accessible() {
|
||||
return Err(Error::TensorNotOnCpu(device.as_str()));
|
||||
}
|
||||
|
||||
let res = {
|
||||
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
|
||||
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
let data_type: TensorElementType = type_sys.into();
|
||||
if data_type == T::into_tensor_element_type() {
|
||||
let mut num_dims = 0;
|
||||
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
|
||||
|
||||
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
|
||||
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
|
||||
let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::<Vec<_>>());
|
||||
|
||||
let mut len = 0;
|
||||
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
|
||||
|
||||
Ok(extract_primitive_array_mut(shape, self.ptr())?)
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: data_type,
|
||||
requested: T::into_tensor_element_type()
|
||||
})
|
||||
if ty == T::into_tensor_element_type() {
|
||||
Ok(extract_primitive_array_mut(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>()), self.ptr())?)
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: ty,
|
||||
requested: T::into_tensor_element_type()
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
|
||||
res
|
||||
t => Err(Error::NotTensor(t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an
|
||||
@@ -159,40 +190,32 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the
|
||||
/// infallible [`Tensor::extract_raw_tensor`] instead)*
|
||||
/// - The provided type `T` does not match the tensor's element type.
|
||||
pub fn try_extract_raw_tensor<T: IntoTensorElementType>(&self) -> Result<(Vec<i64>, &[T])> {
|
||||
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
|
||||
pub fn try_extract_raw_tensor<T: PrimitiveTensorElementType>(&self) -> Result<(Vec<i64>, &[T])> {
|
||||
let dtype = self.dtype()?;
|
||||
match dtype {
|
||||
ValueType::Tensor { ty, dimensions } => {
|
||||
let device = self.memory_info()?.allocation_device()?;
|
||||
if !device.is_cpu_accessible() {
|
||||
return Err(Error::TensorNotOnCpu(device.as_str()));
|
||||
}
|
||||
|
||||
let res = {
|
||||
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
|
||||
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
let data_type: TensorElementType = type_sys.into();
|
||||
if data_type == T::into_tensor_element_type() {
|
||||
let mut num_dims = 0;
|
||||
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
|
||||
if ty == T::into_tensor_element_type() {
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
|
||||
|
||||
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
|
||||
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
|
||||
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
|
||||
|
||||
let mut len = 0;
|
||||
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
|
||||
|
||||
Ok((node_dims, unsafe { std::slice::from_raw_parts(output_array_ptr, len as _) }))
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: data_type,
|
||||
requested: T::into_tensor_element_type()
|
||||
})
|
||||
let len = calculate_tensor_size(&dimensions);
|
||||
Ok((dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) }))
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: ty,
|
||||
requested: T::into_tensor_element_type()
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
|
||||
res
|
||||
t => Err(Error::NotTensor(t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a
|
||||
@@ -218,50 +241,41 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the
|
||||
/// infallible [`Tensor::extract_raw_tensor_mut`] instead)*
|
||||
/// - The provided type `T` does not match the tensor's element type.
|
||||
pub fn try_extract_raw_tensor_mut<T: IntoTensorElementType>(&mut self) -> Result<(Vec<i64>, &mut [T])> {
|
||||
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
|
||||
pub fn try_extract_raw_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<(Vec<i64>, &mut [T])> {
|
||||
let dtype = self.dtype()?;
|
||||
match dtype {
|
||||
ValueType::Tensor { ty, dimensions } => {
|
||||
let device = self.memory_info()?.allocation_device()?;
|
||||
if !device.is_cpu_accessible() {
|
||||
return Err(Error::TensorNotOnCpu(device.as_str()));
|
||||
}
|
||||
|
||||
let res = {
|
||||
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
|
||||
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
let data_type: TensorElementType = type_sys.into();
|
||||
if data_type == T::into_tensor_element_type() {
|
||||
let mut num_dims = 0;
|
||||
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
|
||||
if ty == T::into_tensor_element_type() {
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
|
||||
|
||||
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
|
||||
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
|
||||
|
||||
let mut output_array_ptr: *mut T = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
|
||||
|
||||
let mut len = 0;
|
||||
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
|
||||
|
||||
Ok((node_dims, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len as _) }))
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: data_type,
|
||||
requested: T::into_tensor_element_type()
|
||||
})
|
||||
let len = calculate_tensor_size(&dimensions);
|
||||
Ok((dimensions, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len) }))
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: ty,
|
||||
requested: T::into_tensor_element_type()
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
|
||||
res
|
||||
t => Err(Error::NotTensor(t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to extract the underlying data into a Rust `ndarray`.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{Allocator, Session, DynTensor, TensorElementType};
|
||||
/// # use ort::{Session, Tensor, TensorElementType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let allocator = Allocator::default();
|
||||
/// let array = ndarray::Array1::from_vec(vec!["hello", "world"]);
|
||||
/// let tensor = DynTensor::from_string_array(&allocator, array.clone())?;
|
||||
/// let tensor = Tensor::from_string_array(array.clone())?;
|
||||
///
|
||||
/// let extracted = tensor.try_extract_string_tensor()?;
|
||||
/// assert_eq!(array.into_dyn(), extracted);
|
||||
@@ -271,78 +285,68 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
#[cfg(feature = "ndarray")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
pub fn try_extract_string_tensor(&self) -> Result<ndarray::ArrayD<String>> {
|
||||
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
|
||||
let dtype = self.dtype()?;
|
||||
match dtype {
|
||||
ValueType::Tensor { ty, dimensions } => {
|
||||
let device = self.memory_info()?.allocation_device()?;
|
||||
if !device.is_cpu_accessible() {
|
||||
return Err(Error::TensorNotOnCpu(device.as_str()));
|
||||
}
|
||||
|
||||
let res = {
|
||||
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
|
||||
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
let data_type: TensorElementType = type_sys.into();
|
||||
if data_type == TensorElementType::String {
|
||||
let mut num_dims = 0;
|
||||
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
|
||||
if ty == TensorElementType::String {
|
||||
let len = calculate_tensor_size(&dimensions);
|
||||
|
||||
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
|
||||
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
|
||||
let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::<Vec<_>>());
|
||||
// Total length of string data, not including \0 suffix
|
||||
let mut total_length: ort_sys::size_t = 0;
|
||||
ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength];
|
||||
|
||||
let mut len: ort_sys::size_t = 0;
|
||||
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
|
||||
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
|
||||
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
|
||||
// don't seem to be written to in practice either.
|
||||
// If the string data actually did go farther, it would panic below when using the offset
|
||||
// data to get slices for each string.
|
||||
let mut string_contents = vec![0u8; total_length as _];
|
||||
// one extra slot so that the total length can go in the last one, making all per-string
|
||||
// length calculations easy
|
||||
let mut offsets = vec![0; (len + 1) as _];
|
||||
|
||||
// Total length of string data, not including \0 suffix
|
||||
let mut total_length: ort_sys::size_t = 0;
|
||||
ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength];
|
||||
ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent];
|
||||
|
||||
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
|
||||
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
|
||||
// don't seem to be written to in practice either.
|
||||
// If the string data actually did go farther, it would panic below when using the offset
|
||||
// data to get slices for each string.
|
||||
let mut string_contents = vec![0u8; total_length as _];
|
||||
// one extra slot so that the total length can go in the last one, making all per-string
|
||||
// length calculations easy
|
||||
let mut offsets = vec![0; (len + 1) as _];
|
||||
// final offset = overall length so that per-string length calculations work for the last string
|
||||
debug_assert_eq!(0, offsets[len]);
|
||||
offsets[len] = total_length;
|
||||
|
||||
ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent];
|
||||
let strings = offsets
|
||||
// offsets has 1 extra offset past the end so that all windows work
|
||||
.windows(2)
|
||||
.map(|w| {
|
||||
let slice = &string_contents[w[0] as _..w[1] as _];
|
||||
String::from_utf8(slice.into())
|
||||
})
|
||||
.collect::<Result<Vec<String>, FromUtf8Error>>()
|
||||
.map_err(Error::StringFromUtf8Error)?;
|
||||
|
||||
// final offset = overall length so that per-string length calculations work for the last string
|
||||
debug_assert_eq!(0, offsets[len as usize]);
|
||||
offsets[len as usize] = total_length;
|
||||
|
||||
let strings = offsets
|
||||
// offsets has 1 extra offset past the end so that all windows work
|
||||
.windows(2)
|
||||
.map(|w| {
|
||||
let slice = &string_contents[w[0] as _..w[1] as _];
|
||||
String::from_utf8(slice.into())
|
||||
Ok(ndarray::Array::from_shape_vec(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>()), strings)
|
||||
.expect("Shape extracted from tensor didn't match tensor contents"))
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: ty,
|
||||
requested: TensorElementType::String
|
||||
})
|
||||
.collect::<Result<Vec<String>, FromUtf8Error>>()
|
||||
.map_err(Error::StringFromUtf8Error)?;
|
||||
|
||||
Ok(ndarray::Array::from_shape_vec(shape, strings)
|
||||
.expect("Shape extracted from tensor didn't match tensor contents")
|
||||
.into_dyn())
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: data_type,
|
||||
requested: TensorElementType::String
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
|
||||
res
|
||||
t => Err(Error::NotTensor(t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's dimensions and
|
||||
/// an owned `Vec` of its data.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{Allocator, Session, DynTensor, TensorElementType};
|
||||
/// # use ort::{Session, Tensor, TensorElementType};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// # let allocator = Allocator::default();
|
||||
/// let array = vec!["hello", "world"];
|
||||
/// let tensor = DynTensor::from_string_array(&allocator, ([array.len()], array.clone().into_boxed_slice()))?;
|
||||
/// let tensor = Tensor::from_string_array(([array.len()], array.clone().into_boxed_slice()))?;
|
||||
///
|
||||
/// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?;
|
||||
/// assert_eq!(extracted_data, array);
|
||||
@@ -351,68 +355,57 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn try_extract_raw_string_tensor(&self) -> Result<(Vec<i64>, Vec<String>)> {
|
||||
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
|
||||
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
|
||||
let dtype = self.dtype()?;
|
||||
match dtype {
|
||||
ValueType::Tensor { ty, dimensions } => {
|
||||
let device = self.memory_info()?.allocation_device()?;
|
||||
if !device.is_cpu_accessible() {
|
||||
return Err(Error::TensorNotOnCpu(device.as_str()));
|
||||
}
|
||||
|
||||
let res = {
|
||||
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
|
||||
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
let data_type: TensorElementType = type_sys.into();
|
||||
if data_type == TensorElementType::String {
|
||||
let mut num_dims = 0;
|
||||
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
|
||||
if ty == TensorElementType::String {
|
||||
let len = calculate_tensor_size(&dimensions);
|
||||
|
||||
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
|
||||
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
|
||||
// Total length of string data, not including \0 suffix
|
||||
let mut total_length: ort_sys::size_t = 0;
|
||||
ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength];
|
||||
|
||||
let mut output_array_ptr: *mut c_char = ptr::null_mut();
|
||||
let output_array_ptr_ptr: *mut *mut c_char = &mut output_array_ptr;
|
||||
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
|
||||
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
|
||||
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
|
||||
// don't seem to be written to in practice either.
|
||||
// If the string data actually did go farther, it would panic below when using the offset
|
||||
// data to get slices for each string.
|
||||
let mut string_contents = vec![0u8; total_length as _];
|
||||
// one extra slot so that the total length can go in the last one, making all per-string
|
||||
// length calculations easy
|
||||
let mut offsets = vec![0; (len + 1) as _];
|
||||
|
||||
let mut len: ort_sys::size_t = 0;
|
||||
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
|
||||
// Total length of string data, not including \0 suffix
|
||||
let mut total_length = 0;
|
||||
ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength];
|
||||
ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent];
|
||||
|
||||
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
|
||||
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
|
||||
// don't seem to be written to in practice either.
|
||||
// If the string data actually did go farther, it would panic below when using the offset
|
||||
// data to get slices for each string.
|
||||
let mut string_contents = vec![0u8; total_length as _];
|
||||
// one extra slot so that the total length can go in the last one, making all per-string
|
||||
// length calculations easy
|
||||
let mut offsets = vec![0; len as usize + 1];
|
||||
// final offset = overall length so that per-string length calculations work for the last string
|
||||
debug_assert_eq!(0, offsets[len]);
|
||||
offsets[len] = total_length;
|
||||
|
||||
ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length as _, offsets.as_mut_ptr(), len as _) -> Error::GetStringTensorContent];
|
||||
let strings = offsets
|
||||
// offsets has 1 extra offset past the end so that all windows work
|
||||
.windows(2)
|
||||
.map(|w| {
|
||||
let slice = &string_contents[w[0] as _..w[1] as _];
|
||||
String::from_utf8(slice.into())
|
||||
})
|
||||
.collect::<Result<Vec<String>, FromUtf8Error>>()
|
||||
.map_err(Error::StringFromUtf8Error)?;
|
||||
|
||||
// final offset = overall length so that per-string length calculations work for the last string
|
||||
debug_assert_eq!(0, offsets[len as usize]);
|
||||
offsets[len as usize] = total_length;
|
||||
|
||||
let strings = offsets
|
||||
// offsets has 1 extra offset past the end so that all windows work
|
||||
.windows(2)
|
||||
.map(|w| {
|
||||
let slice = &string_contents[w[0] as _..w[1] as _];
|
||||
String::from_utf8(slice.into())
|
||||
Ok((dimensions, strings))
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: ty,
|
||||
requested: TensorElementType::String
|
||||
})
|
||||
.collect::<Result<Vec<String>, FromUtf8Error>>()
|
||||
.map_err(Error::StringFromUtf8Error)?;
|
||||
|
||||
Ok((node_dims, strings))
|
||||
} else {
|
||||
Err(Error::DataTypeMismatch {
|
||||
actual: data_type,
|
||||
requested: TensorElementType::String
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
|
||||
res
|
||||
t => Err(Error::NotTensor(t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the shape of the tensor.
|
||||
@@ -445,7 +438,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||
impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
|
||||
/// Extracts the underlying data into a read-only [`ndarray::ArrayView`].
|
||||
///
|
||||
/// ```
|
||||
|
||||
@@ -11,41 +11,96 @@ use std::{
|
||||
use super::{DowncastableTarget, Value, ValueInner, ValueTypeMarker};
|
||||
use crate::{ortsys, DynValue, Error, IntoTensorElementType, MemoryInfo, Result, ValueRef, ValueRefMut, ValueType};
|
||||
|
||||
pub trait TensorValueTypeMarker: ValueTypeMarker {}
|
||||
pub trait TensorValueTypeMarker: ValueTypeMarker {
|
||||
crate::private_trait!();
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DynTensorValueType;
|
||||
impl ValueTypeMarker for DynTensorValueType {}
|
||||
impl TensorValueTypeMarker for DynTensorValueType {}
|
||||
impl ValueTypeMarker for DynTensorValueType {
|
||||
crate::private_impl!();
|
||||
}
|
||||
impl TensorValueTypeMarker for DynTensorValueType {
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TensorValueType<T: IntoTensorElementType + Debug>(PhantomData<T>);
|
||||
impl<T: IntoTensorElementType + Debug> ValueTypeMarker for TensorValueType<T> {}
|
||||
impl<T: IntoTensorElementType + Debug> TensorValueTypeMarker for TensorValueType<T> {}
|
||||
impl<T: IntoTensorElementType + Debug> ValueTypeMarker for TensorValueType<T> {
|
||||
crate::private_impl!();
|
||||
}
|
||||
impl<T: IntoTensorElementType + Debug> TensorValueTypeMarker for TensorValueType<T> {
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
/// A tensor [`Value`] whose data type is unknown.
|
||||
pub type DynTensor = Value<DynTensorValueType>;
|
||||
/// A strongly-typed tensor [`Value`].
|
||||
pub type Tensor<T> = Value<TensorValueType<T>>;
|
||||
|
||||
/// A reference to a tensor [`Value`] whose data type is unknown.
|
||||
pub type DynTensorRef<'v> = ValueRef<'v, DynTensorValueType>;
|
||||
/// A mutable reference to a tensor [`Value`] whose data type is unknown.
|
||||
pub type DynTensorRefMut<'v> = ValueRefMut<'v, DynTensorValueType>;
|
||||
/// A reference to a strongly-typed tensor [`Value`].
|
||||
pub type TensorRef<'v, T> = ValueRef<'v, TensorValueType<T>>;
|
||||
/// A mutable reference to a strongly-typed tensor [`Value`].
|
||||
pub type TensorRefMut<'v, T> = ValueRefMut<'v, TensorValueType<T>>;
|
||||
|
||||
impl DowncastableTarget for DynTensorValueType {
|
||||
fn can_downcast(dtype: &ValueType) -> bool {
|
||||
matches!(dtype, ValueType::Tensor { .. })
|
||||
}
|
||||
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
/// Returns a mutable pointer to the tensor's data.
|
||||
///
|
||||
/// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a
|
||||
/// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be
|
||||
/// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access.
|
||||
/// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before
|
||||
/// accessing it.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{Allocator, Tensor};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mut tensor = Tensor::<i64>::from_array((vec![5], vec![0, 1, 2, 3, 4]))?;
|
||||
/// let ptr = tensor.data_ptr_mut()?.cast::<i64>();
|
||||
/// unsafe {
|
||||
/// *ptr.add(3) = 42;
|
||||
/// };
|
||||
///
|
||||
/// let (_, extracted) = tensor.extract_raw_tensor();
|
||||
/// assert_eq!(&extracted, &[0, 1, 2, 42, 4]);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn data_ptr_mut(&mut self) -> Result<*mut ort_sys::c_void> {
|
||||
let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)];
|
||||
Ok(buffer_ptr)
|
||||
}
|
||||
|
||||
/// Returns a pointer to the tensor's data.
|
||||
/// Returns an immutable pointer to the tensor's underlying data.
|
||||
///
|
||||
/// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a
|
||||
/// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be
|
||||
/// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access.
|
||||
/// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before
|
||||
/// accessing it.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{Allocator, Tensor};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let tensor = Tensor::<i64>::from_array((vec![5], vec![0, 1, 2, 3, 4]))?;
|
||||
/// let ptr = tensor.data_ptr()?.cast::<i64>();
|
||||
/// assert_eq!(unsafe { *ptr.add(3) }, 3);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn data_ptr(&self) -> Result<*const ort_sys::c_void> {
|
||||
let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut();
|
||||
ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)];
|
||||
@@ -53,6 +108,26 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
}
|
||||
|
||||
/// Returns information about the device this tensor is allocated on.
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType, Session, Tensor};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let tensor = Tensor::<f32>::new(&Allocator::default(), [1, 3, 224, 224])?;
|
||||
/// // Tensors are allocated on CPU by default.
|
||||
/// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CPU);
|
||||
///
|
||||
/// # if false {
|
||||
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
|
||||
/// let cuda_allocator = Allocator::new(
|
||||
/// &session,
|
||||
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?
|
||||
/// )?;
|
||||
/// let tensor = Tensor::<f32>::new(&cuda_allocator, [1, 3, 224, 224])?;
|
||||
/// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CUDA);
|
||||
/// # }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn memory_info(&self) -> Result<MemoryInfo> {
|
||||
let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut();
|
||||
ortsys![unsafe GetTensorMemoryInfo(self.ptr(), &mut memory_info_ptr) -> Error::GetTensorMemoryInfo; nonNull(memory_info_ptr)];
|
||||
@@ -62,29 +137,68 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
|
||||
|
||||
impl<T: IntoTensorElementType + Debug> Tensor<T> {
|
||||
/// Converts from a strongly-typed [`Tensor<T>`] to a type-erased [`DynTensor`].
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{Allocator, DynTensor, Tensor};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let tensor = Tensor::<f32>::new(&Allocator::default(), [1, 3, 224, 224])?;
|
||||
/// let tensor_dyn = tensor.upcast();
|
||||
/// assert!(tensor_dyn.try_extract_raw_tensor::<f32>().is_ok());
|
||||
/// assert!(tensor_dyn.try_extract_raw_tensor::<i64>().is_err());
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn upcast(self) -> DynTensor {
|
||||
unsafe { std::mem::transmute(self) }
|
||||
}
|
||||
|
||||
/// Converts from a strongly-typed [`Tensor<T>`] to a reference to a type-erased [`DynTensor`].
|
||||
/// Creates a type-erased [`DynTensorRef`] from a strongly-typed [`Tensor<T>`].
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{Allocator, DynTensor, Tensor};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let tensor = Tensor::<f32>::new(&Allocator::default(), [1, 3, 224, 224])?;
|
||||
/// let tensor_dyn = tensor.upcast_ref();
|
||||
///
|
||||
/// let (_, original_extract) = tensor.extract_raw_tensor();
|
||||
/// let (_, ref_extract) = tensor_dyn.try_extract_raw_tensor::<f32>()?;
|
||||
/// assert_eq!(original_extract, ref_extract);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn upcast_ref(&self) -> DynTensorRef {
|
||||
DynTensorRef::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Converts from a strongly-typed [`Tensor<T>`] to a mutable reference to a type-erased [`DynTensor`].
|
||||
///
|
||||
/// ```
|
||||
/// # use ort::{Allocator, DynTensor, Tensor};
|
||||
/// # fn main() -> ort::Result<()> {
|
||||
/// let mut tensor = Tensor::<i64>::from_array((vec![5], vec![1, 2, 3, 4, 5]))?;
|
||||
/// let mut tensor_dyn = tensor.upcast_mut();
|
||||
///
|
||||
/// let (_, mut_view) = tensor_dyn.try_extract_raw_tensor_mut::<i64>()?;
|
||||
/// mut_view[3] = 0;
|
||||
///
|
||||
/// let (_, original_view) = tensor.extract_raw_tensor();
|
||||
/// assert_eq!(original_view, &[1, 2, 3, 0, 5]);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn upcast_mut(&mut self) -> DynTensorRefMut {
|
||||
DynTensorRefMut::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -97,6 +211,8 @@ impl<T: IntoTensorElementType + Debug> DowncastableTarget for TensorValueType<T>
|
||||
_ => false
|
||||
}
|
||||
}
|
||||
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
impl<T: IntoTensorElementType + Debug> From<Value<TensorValueType<T>>> for DynValue {
|
||||
@@ -113,6 +229,17 @@ impl From<Value<DynTensorValueType>> for DynValue {
|
||||
impl<T: IntoTensorElementType + Clone + Debug, const N: usize> Index<[i64; N]> for Tensor<T> {
|
||||
type Output = T;
|
||||
fn index(&self, index: [i64; N]) -> &Self::Output {
|
||||
// Interestingly, the `TensorAt` API doesn't check if the tensor is on CPU, so we have to perform the check ourselves.
|
||||
if !self
|
||||
.memory_info()
|
||||
.expect("could not retrieve tensor memory info")
|
||||
.allocation_device()
|
||||
.expect("could not retrieve tensor allocation device")
|
||||
.is_cpu_accessible()
|
||||
{
|
||||
panic!("Cannot directly index a tensor which is not allocated on the CPU.");
|
||||
}
|
||||
|
||||
let mut out: *mut ort_sys::c_void = std::ptr::null_mut();
|
||||
ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")];
|
||||
unsafe { &*out.cast::<T>() }
|
||||
@@ -120,25 +247,46 @@ impl<T: IntoTensorElementType + Clone + Debug, const N: usize> Index<[i64; N]> f
|
||||
}
|
||||
impl<T: IntoTensorElementType + Clone + Debug, const N: usize> IndexMut<[i64; N]> for Tensor<T> {
|
||||
fn index_mut(&mut self, index: [i64; N]) -> &mut Self::Output {
|
||||
if !self
|
||||
.memory_info()
|
||||
.expect("could not retrieve tensor memory info")
|
||||
.allocation_device()
|
||||
.expect("could not retrieve tensor allocation device")
|
||||
.is_cpu_accessible()
|
||||
{
|
||||
panic!("Cannot directly index a tensor which is not allocated on the CPU.");
|
||||
}
|
||||
|
||||
let mut out: *mut ort_sys::c_void = std::ptr::null_mut();
|
||||
ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")];
|
||||
unsafe { &mut *out.cast::<T>() }
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the number of elements described by a tensor `shape`.
///
/// Returns the product of all dimensions. An empty shape yields `1`. If any dimension is zero or
/// negative the result is `0`, regardless of how large the other dimensions are.
///
/// The product saturates at `usize::MAX` instead of overflowing, so the result does not depend on
/// whether the build has overflow checks enabled.
pub(crate) fn calculate_tensor_size(shape: &[i64]) -> usize {
    // Settle the "no elements" case before multiplying, so a zero or negative dimension that
    // appears after very large ones cannot be preceded by an overflow.
    if shape.iter().any(|&dim| dim <= 0) {
        return 0;
    }

    // Every dimension is strictly positive here, so `as u64` is lossless.
    let elements = shape.iter().fold(1u64, |acc, &dim| acc.saturating_mul(dim as u64));

    // Clamp rather than truncate on targets where `usize` is narrower than 64 bits.
    if elements > usize::MAX as u64 { usize::MAX } else { elements as usize }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use ndarray::{ArcArray1, Array1, CowArray};
|
||||
|
||||
use crate::{Allocator, DynTensor, TensorElementType, Value, ValueType};
|
||||
use crate::{Tensor, TensorElementType, ValueType};
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "ndarray")]
|
||||
fn test_tensor_value() -> crate::Result<()> {
|
||||
let v: Vec<f32> = vec![1., 2., 3., 4., 5.];
|
||||
let value = Value::from_array(Array1::from_vec(v.clone()))?;
|
||||
let value = Tensor::from_array(Array1::from_vec(v.clone()))?;
|
||||
assert!(value.is_tensor()?);
|
||||
assert_eq!(value.dtype()?.tensor_type(), Some(TensorElementType::Float32));
|
||||
assert_eq!(
|
||||
@@ -163,17 +311,17 @@ mod tests {
|
||||
|
||||
let arc1 = ArcArray1::from_vec(v.clone());
|
||||
let mut arc2 = ArcArray1::clone(&arc1);
|
||||
let value = Value::from_array(&mut arc2)?;
|
||||
let value = Tensor::from_array(&mut arc2)?;
|
||||
drop((arc1, arc2));
|
||||
|
||||
assert_eq!(value.extract_raw_tensor().1, &v);
|
||||
|
||||
let cow = CowArray::from(Array1::from_vec(v.clone()));
|
||||
let value = Value::from_array(&cow)?;
|
||||
let value = Tensor::from_array(&cow)?;
|
||||
assert_eq!(value.extract_raw_tensor().1, &v);
|
||||
|
||||
let owned = Array1::from_vec(v.clone());
|
||||
let value = Value::from_array(owned.view())?;
|
||||
let value = Tensor::from_array(owned.view())?;
|
||||
drop(owned);
|
||||
assert_eq!(value.extract_raw_tensor().1, &v);
|
||||
|
||||
@@ -186,7 +334,7 @@ mod tests {
|
||||
|
||||
let arc = Arc::new(v.clone().into_boxed_slice());
|
||||
let shape = vec![v.len() as i64];
|
||||
let value = Value::from_array((shape, Arc::clone(&arc)))?;
|
||||
let value = Tensor::from_array((shape, Arc::clone(&arc)))?;
|
||||
drop(arc);
|
||||
assert_eq!(value.try_extract_raw_tensor::<f32>()?.1, &v);
|
||||
|
||||
@@ -196,10 +344,9 @@ mod tests {
|
||||
#[test]
|
||||
#[cfg(feature = "ndarray")]
|
||||
fn test_string_tensor_ndarray() -> crate::Result<()> {
|
||||
let allocator = Allocator::default();
|
||||
let v = Array1::from_vec(vec!["hello world".to_string(), "こんにちは世界".to_string()]);
|
||||
|
||||
let value = DynTensor::from_string_array(&allocator, v.view())?;
|
||||
let value = Tensor::from_string_array(v.view())?;
|
||||
let extracted = value.try_extract_string_tensor()?;
|
||||
assert_eq!(extracted, v.into_dyn());
|
||||
|
||||
@@ -208,10 +355,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_string_tensor_raw() -> crate::Result<()> {
|
||||
let allocator = Allocator::default();
|
||||
let v = vec!["hello world".to_string(), "こんにちは世界".to_string()];
|
||||
|
||||
let value = DynTensor::from_string_array(&allocator, (vec![v.len() as i64], v.clone().into_boxed_slice()))?;
|
||||
let value = Tensor::from_string_array((vec![v.len() as i64], v.clone().into_boxed_slice()))?;
|
||||
let (extracted_shape, extracted_view) = value.try_extract_raw_string_tensor()?;
|
||||
assert_eq!(extracted_shape, [v.len() as i64]);
|
||||
assert_eq!(extracted_view, v);
|
||||
@@ -224,10 +370,10 @@ mod tests {
|
||||
let v: Vec<f32> = vec![1., 2., 3., 4., 5.];
|
||||
|
||||
let shape = [v.len()];
|
||||
let value_arc_box = Value::from_array((shape, Arc::new(v.clone().into_boxed_slice())))?;
|
||||
let value_box = Value::from_array((shape, v.clone().into_boxed_slice()))?;
|
||||
let value_vec = Value::from_array((shape, v.clone()))?;
|
||||
let value_slice = Value::from_array((shape, &v[..]))?;
|
||||
let value_arc_box = Tensor::from_array((shape, Arc::new(v.clone().into_boxed_slice())))?;
|
||||
let value_box = Tensor::from_array((shape, v.clone().into_boxed_slice()))?;
|
||||
let value_vec = Tensor::from_array((shape, v.clone()))?;
|
||||
let value_slice = Tensor::from_array((shape, &v[..]))?;
|
||||
|
||||
assert_eq!(value_arc_box.extract_raw_tensor().1, &v);
|
||||
assert_eq!(value_box.extract_raw_tensor().1, &v);
|
||||
|
||||
@@ -166,6 +166,14 @@ pub(crate) enum ValueInner {
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueInner {
|
||||
pub(crate) fn ptr(&self) -> *mut ort_sys::OrtValue {
|
||||
match self {
|
||||
ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A temporary version of a [`Value`] with a lifetime specifier.
|
||||
#[derive(Debug)]
|
||||
pub struct ValueRef<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {
|
||||
@@ -277,8 +285,8 @@ impl<'v, Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'v, Type> {
|
||||
/// - [`Tensor::extract_tensor`], [`Tensor::extract_raw_tensor`]
|
||||
#[derive(Debug)]
|
||||
pub struct Value<Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {
|
||||
inner: ValueInner,
|
||||
_markers: PhantomData<Type>
|
||||
pub(crate) inner: Arc<ValueInner>,
|
||||
pub(crate) _markers: PhantomData<Type>
|
||||
}
|
||||
|
||||
/// A dynamic value, which could be a [`Tensor`], [`Sequence`], or [`Map`].
|
||||
@@ -291,11 +299,15 @@ pub type DynValue = Value<DynValueTypeMarker>;
|
||||
///
|
||||
/// For example, [`Tensor::try_extract_tensor`] can only be used on [`Value`]s with the [`TensorValueTypeMarker`] (which
|
||||
/// inherits this trait), i.e. [`Tensor`]s, [`DynTensor`]s, and [`DynValue`]s.
|
||||
pub trait ValueTypeMarker: Debug {}
|
||||
pub trait ValueTypeMarker: Debug {
|
||||
crate::private_trait!();
|
||||
}
|
||||
|
||||
/// Represents a type that a [`DynValue`] can be downcast to.
|
||||
pub trait DowncastableTarget: ValueTypeMarker {
|
||||
/// Reports whether a value whose runtime type is described by `dtype` can be downcast to the
/// value type represented by the implementing marker.
fn can_downcast(dtype: &ValueType) -> bool;
|
||||
|
||||
// NOTE(review): presumably seals this trait so it cannot be implemented outside the crate;
// confirm against the definition of `private_trait!`.
crate::private_trait!();
|
||||
}
|
||||
|
||||
// this implementation is used in case we want to extract `DynValue`s from a [`Sequence`]; see `try_extract_sequence`
|
||||
@@ -303,15 +315,25 @@ impl DowncastableTarget for DynValueTypeMarker {
|
||||
fn can_downcast(_: &ValueType) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
/// The dynamic type marker, used for values which can be of any type.
|
||||
#[derive(Debug)]
|
||||
pub struct DynValueTypeMarker;
|
||||
impl ValueTypeMarker for DynValueTypeMarker {}
|
||||
impl MapValueTypeMarker for DynValueTypeMarker {}
|
||||
impl SequenceValueTypeMarker for DynValueTypeMarker {}
|
||||
impl TensorValueTypeMarker for DynValueTypeMarker {}
|
||||
impl ValueTypeMarker for DynValueTypeMarker {
|
||||
crate::private_impl!();
|
||||
}
|
||||
impl MapValueTypeMarker for DynValueTypeMarker {
|
||||
crate::private_impl!();
|
||||
}
|
||||
impl SequenceValueTypeMarker for DynValueTypeMarker {
|
||||
crate::private_impl!();
|
||||
}
|
||||
impl TensorValueTypeMarker for DynValueTypeMarker {
|
||||
crate::private_impl!();
|
||||
}
|
||||
|
||||
unsafe impl Send for Value {}
|
||||
|
||||
@@ -350,7 +372,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
|
||||
///
|
||||
/// If the value belongs to a session (i.e. if it is returned from [`crate::Session::run`] or
|
||||
/// [`crate::IoBinding::run`]), you must provide the [`SharedSessionInner`] (acquired from
|
||||
/// [`crate::Session::inner`]). This ensures the session is not dropped until the value is.
|
||||
/// [`crate::Session::inner`]). This ensures the session is not dropped until all values owned by it are.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
@@ -359,7 +381,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
|
||||
#[must_use]
|
||||
pub unsafe fn from_ptr(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
|
||||
Value {
|
||||
inner: ValueInner::CppOwned { ptr, drop: true, _session: session },
|
||||
inner: Arc::new(ValueInner::CppOwned { ptr, drop: true, _session: session }),
|
||||
_markers: PhantomData
|
||||
}
|
||||
}
|
||||
@@ -369,16 +391,14 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
|
||||
#[must_use]
|
||||
pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
|
||||
Value {
|
||||
inner: ValueInner::CppOwned { ptr, drop: false, _session: session },
|
||||
inner: Arc::new(ValueInner::CppOwned { ptr, drop: false, _session: session }),
|
||||
_markers: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the underlying [`ort_sys::OrtValue`] pointer.
|
||||
pub fn ptr(&self) -> *mut ort_sys::OrtValue {
|
||||
match &self.inner {
|
||||
ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr()
|
||||
}
|
||||
self.inner.ptr()
|
||||
}
|
||||
|
||||
/// Create a view of this value's data.
|
||||
@@ -386,7 +406,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
|
||||
ValueRef::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -396,7 +416,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
|
||||
ValueRefMut::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -442,7 +462,7 @@ impl Value<DynValueTypeMarker> {
|
||||
Ok(ValueRef::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
}))
|
||||
} else {
|
||||
@@ -450,7 +470,7 @@ impl Value<DynValueTypeMarker> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed
|
||||
/// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed
|
||||
/// mutable-reference variant, like [`TensorRefMut<T>`].
|
||||
#[inline]
|
||||
pub fn downcast_mut<OtherType: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(&mut self) -> Result<ValueRefMut<'_, OtherType>> {
|
||||
@@ -459,7 +479,7 @@ impl Value<DynValueTypeMarker> {
|
||||
Ok(ValueRefMut::new(unsafe {
|
||||
Value::from_ptr_nodrop(
|
||||
NonNull::new_unchecked(self.ptr()),
|
||||
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
|
||||
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
|
||||
)
|
||||
}))
|
||||
} else {
|
||||
@@ -468,17 +488,17 @@ impl Value<DynValueTypeMarker> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type: ValueTypeMarker + ?Sized> Drop for Value<Type> {
|
||||
impl Drop for ValueInner {
|
||||
fn drop(&mut self) {
|
||||
let ptr = self.ptr();
|
||||
tracing::trace!(
|
||||
"dropping {} value at {ptr:p}",
|
||||
match &self.inner {
|
||||
match self {
|
||||
ValueInner::RustOwned { .. } => "rust-owned",
|
||||
ValueInner::CppOwned { .. } => "cpp-owned"
|
||||
}
|
||||
);
|
||||
if !matches!(&self.inner, ValueInner::CppOwned { drop: false, .. }) {
|
||||
if !matches!(self, ValueInner::CppOwned { drop: false, .. }) {
|
||||
ortsys![unsafe ReleaseValue(ptr)];
|
||||
}
|
||||
}
|
||||
|
||||
14
src/wasm.rs
14
src/wasm.rs
@@ -1,6 +1,6 @@
|
||||
//! Utilities for using `ort` in WebAssembly.
|
||||
//!
|
||||
//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs:
|
||||
//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs in WASM:
|
||||
//! ```
|
||||
//! # use ort::Session;
|
||||
//! # static MODEL_BYTES: &[u8] = include_bytes!("../tests/data/upsample.ort");
|
||||
@@ -223,12 +223,12 @@ mod emscripten_shims {
|
||||
#[no_mangle]
|
||||
#[export_name = "_initialize"]
|
||||
pub fn initialize() {
|
||||
// No idea what the hell this does, but the presence of an `_initialize` function prevents the linker from calling
|
||||
// `__wasm_call_ctors` at the top of every function - including the functions `wasm-bindgen` interprets to generate
|
||||
// JS glue code. The `__wasm_call_ctors` call was calling complex functions that the interpreter isn't equipped to
|
||||
// handle, which was preventing wbg from outputting anything. I don't know what specific constructors this is calling,
|
||||
// and most basic ONNX Runtime APIs *do* work without calling this, but we encourage the user to perform this
|
||||
// initialization at program start anyways to be safe.
|
||||
// The presence of an `_initialize` function prevents the linker from calling `__wasm_call_ctors` at the top of every
|
||||
// function - including the functions `wasm-bindgen` interprets to generate JS glue code. `__wasm_call_ctors` calls
|
||||
// complex functions that wbg's interpreter isn't equipped to handle, which was preventing wbg from outputting
|
||||
// anything.
|
||||
// I'm not entirely sure what `__wasm_call_ctors` is initializing, but it seems to have something to do with C++
|
||||
// vtables, and it's crucial for proper operation.
|
||||
extern "C" {
|
||||
fn __wasm_call_ctors();
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use std::path::Path;
|
||||
|
||||
use ndarray::{ArrayD, IxDyn};
|
||||
use ort::{inputs, DynTensor, GraphOptimizationLevel, Session};
|
||||
use ort::{inputs, GraphOptimizationLevel, Session, Tensor};
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
@@ -22,7 +22,7 @@ fn vectorizer() -> ort::Result<()> {
|
||||
let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap());
|
||||
|
||||
// Just one input
|
||||
let input_tensor_values = inputs![DynTensor::from_string_array(session.allocator(), &array)?]?;
|
||||
let input_tensor_values = inputs![Tensor::from_string_array(&array)?]?;
|
||||
|
||||
// Perform the inference
|
||||
let outputs = session.run(input_tensor_values)?;
|
||||
|
||||
Reference in New Issue
Block a user