refactor: usability

aka The Cleanening, part 2

- Add clearer documentation and examples for more things.
- Rework string tensors by introducing `PrimitiveTensorElementType` for primitive (e.g. `f32`) types, and again re-implementing `IntoTensorElementType` for `String`. This allows string tensors to be used via `Tensor<String>` instead of exclusively via `DynTensor`. Additionally, string tensors no longer require an `Allocator` in order to be created (which didn't make sense, since string data in Rust can only ever be stored on the CPU anyway). This also now applies to `Map`s, since their data also needs to be on the CPU anyway. (`Sequence`s are currently unaffected because I think a custom allocator could be useful for them?)
- Rework the `IoBinding` interface, and add an example clarifying the intended usage of it (ref #209). Thanks to AAce from the pyke Discord for pointing out the mutability issue in the old interface, which should be addressed now.
- Refactor `OperatorDomain::add` from the slightly-nicer-looking-but-more-confusing `fn<T>(t: T)` to just `fn<T>()` to further enforce the fact that `Operator`s are zero-sized.
- Maps can now have `String` keys.
- Remove some unused errors.
This commit is contained in:
Carson M.
2024-06-21 15:37:39 -05:00
parent 19d66de302
commit c64b8ea990
20 changed files with 1090 additions and 544 deletions

View File

@@ -82,7 +82,7 @@ impl Kernel for CustomOpTwoKernel {
fn main() -> ort::Result<()> {
let session = Session::builder()?
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?
.with_operators(OperatorDomain::new("test.customop")?.add::<CustomOpOne>()?.add::<CustomOpTwo>()?)?
.commit_from_file("tests/data/custom_op_test.onnx")?;
let values = session.run(ort::inputs![Array2::<f32>::zeros((3, 5)), Array2::<f32>::ones((3, 5))]?)?;

View File

@@ -34,7 +34,7 @@ pub struct Environment {
}
impl Environment {
/// Loads the underlying [`ort_sys::OrtEnv`] pointer.
/// Returns the underlying [`ort_sys::OrtEnv`] pointer.
pub fn ptr(&self) -> *mut ort_sys::OrtEnv {
self.env_ptr.load(Ordering::Relaxed)
}
@@ -52,13 +52,14 @@ impl Drop for Environment {
}
}
/// Gets a reference to the global environment, creating one if an environment has been committed yet.
/// Gets a reference to the global environment, creating one if an environment has not been
/// [`commit`](EnvironmentBuilder::commit)ted yet.
pub fn get_environment() -> Result<&'static Arc<Environment>> {
if let Some(c) = unsafe { &*G_ENV.cell.get() } {
Ok(c)
} else {
debug!("Environment not yet initialized, creating a new one");
EnvironmentBuilder::default().commit()?;
EnvironmentBuilder::new().commit()?;
Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() })
}
@@ -72,7 +73,7 @@ pub struct EnvironmentGlobalThreadPoolOptions {
pub intra_op_thread_affinity: Option<String>
}
/// Struct used to build an `Environment`.
/// Struct used to build an [`Environment`]; see [`crate::init`].
pub struct EnvironmentBuilder {
name: String,
telemetry: bool,
@@ -80,8 +81,8 @@ pub struct EnvironmentBuilder {
global_thread_pool_options: Option<EnvironmentGlobalThreadPoolOptions>
}
impl Default for EnvironmentBuilder {
fn default() -> Self {
impl EnvironmentBuilder {
pub(crate) fn new() -> Self {
EnvironmentBuilder {
name: "default".to_string(),
telemetry: true,
@@ -89,11 +90,9 @@ impl Default for EnvironmentBuilder {
global_thread_pool_options: None
}
}
}
impl EnvironmentBuilder {
/// Configure the environment with a given name for logging purposes.
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn with_name<S>(mut self, name: S) -> Self
where
S: Into<String>
@@ -102,7 +101,17 @@ impl EnvironmentBuilder {
self
}
#[must_use]
/// Enable or disable sending telemetry events to Microsoft.
///
/// Typically, only Windows builds of ONNX Runtime provided by Microsoft will have telemetry enabled.
/// Pre-built binaries provided by pyke, or binaries compiled from source, won't have telemetry enabled.
///
/// The exact kind of telemetry data sent can be found [here](https://github.com/microsoft/onnxruntime/blob/v1.18.0/onnxruntime/core/platform/windows/telemetry.cc).
/// Currently, this includes (but is not limited to): ONNX graph version, model producer name & version, whether or
/// not FP16 is used, operator domains & versions, model graph name & custom metadata, execution provider names,
/// error messages, and the total number & time of session inference runs. The ONNX Runtime team uses this data to
/// better understand how customers use ONNX Runtime and where performance can be improved.
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn with_telemetry(mut self, enable: bool) -> Self {
self.telemetry = enable;
self
@@ -116,14 +125,14 @@ impl EnvironmentBuilder {
/// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built
/// with support for the corresponding execution provider. Execution providers that do not have their corresponding
/// feature enabled will emit a warning.
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Self {
self.execution_providers = execution_providers.as_ref().to_vec();
self
}
/// Enables the global thread pool for this environment.
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> Self {
self.global_thread_pool_options = Some(options);
self
@@ -158,14 +167,17 @@ impl EnvironmentBuilder {
ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment];
}
ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
ortsys![
unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
logging_function,
logger_param,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
thread_options,
&mut env_ptr
) -> Error::CreateEnvironment; nonNull(env_ptr)];
) -> Error::CreateEnvironment;
nonNull(env_ptr)
];
ortsys![unsafe ReleaseThreadingOptions(thread_options)];
(env_ptr, true)
} else {
@@ -174,13 +186,16 @@ impl EnvironmentBuilder {
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());
ortsys![unsafe CreateEnvWithCustomLogger(
ortsys![
unsafe CreateEnvWithCustomLogger(
logging_function,
logger_param,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
&mut env_ptr
) -> Error::CreateEnvironment; nonNull(env_ptr)];
) -> Error::CreateEnvironment;
nonNull(env_ptr)
];
(env_ptr, false)
};
debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");
@@ -205,15 +220,25 @@ impl EnvironmentBuilder {
/// Creates an ONNX Runtime environment.
///
/// ```
/// # use ort::CUDAExecutionProvider;
/// # fn main() -> ort::Result<()> {
/// ort::init()
/// .with_execution_providers([CUDAExecutionProvider::default().build()])
/// .commit()?;
/// # Ok(())
/// # }
/// ```
///
/// # Notes
/// - It is not required to call this function. If this is not called by the time any other `ort` APIs are used, a
/// default environment will be created.
/// - Library crates that use `ort` shouldn't create their own environment. Let downstream applications create it.
/// - **Library crates that use `ort` shouldn't create their own environment.** Let downstream applications create it.
/// - In order for environment settings to apply, this must be called **before** you use other APIs like
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn init() -> EnvironmentBuilder {
EnvironmentBuilder::default()
EnvironmentBuilder::new()
}
/// Creates an ONNX Runtime environment, dynamically loading ONNX Runtime from the library file (`.dll`/`.so`/`.dylib`)
@@ -221,15 +246,26 @@ pub fn init() -> EnvironmentBuilder {
///
/// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded.
///
/// ```no_run
/// # use ort::CUDAExecutionProvider;
/// # fn main() -> ort::Result<()> {
/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib");
/// ort::init_from(lib_path.join("onnxruntime.dll"))
/// .with_execution_providers([CUDAExecutionProvider::default().build()])
/// .commit()?;
/// # Ok(())
/// # }
/// ```
///
/// # Notes
/// - In order for environment settings to apply, this must be called **before** you use other APIs like
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
#[cfg(feature = "load-dynamic")]
#[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))]
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string()));
EnvironmentBuilder::default()
EnvironmentBuilder::new()
}
/// ONNX's logger sends the code location where the log occurred, which will be parsed into this struct.
@@ -325,7 +361,7 @@ mod tests {
assert!(!is_env_initialized());
assert_eq!(env_ptr(), None);
EnvironmentBuilder::default().with_name("env_is_initialized").commit()?;
EnvironmentBuilder::new().with_name("env_is_initialized").commit()?;
assert!(is_env_initialized());
assert_ne!(env_ptr(), None);
Ok(())

View File

@@ -121,9 +121,6 @@ pub enum Error {
/// Error occurred when filling a tensor with string data
#[error("Failed to fill string tensor: {0}")]
FillStringTensor(ErrorInternal),
/// Error occurred when checking if a value is a tensor
#[error("Failed to check if value is a tensor: {0}")]
FailedTensorCheck(ErrorInternal),
/// Error occurred when getting tensor type and shape
#[error("Failed to get tensor type and shape: {0}")]
GetTensorTypeAndShape(ErrorInternal),
@@ -159,12 +156,6 @@ pub enum Error {
/// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models).
#[error("Failed to download ONNX model: {0}")]
DownloadError(#[from] FetchModelError),
/// Type of input data and the ONNX model do not match.
#[error("Data types do not match: expected {model:?}, got {input:?}")]
NonMatchingDataTypes { input: TensorElementType, model: TensorElementType },
/// Dimensions of input data and the ONNX model do not match.
#[error("Dimensions do not match: {0:?}")]
NonMatchingDimensions(NonMatchingDimensionsError),
/// File does not exist
#[error("File `{filename:?}` does not exist")]
FileDoesNotExist {
@@ -186,9 +177,6 @@ pub enum Error {
/// ORT pointer should not have been null
#[error("`{0}` should not be a null pointer")]
PointerShouldNotBeNull(&'static str),
/// The runtime type was undefined.
#[error("Undefined tensor element type")]
UndefinedTensorElementType,
/// Could not retrieve model metadata.
#[error("Failed to retrieve model metadata: {0}")]
GetModelMetadata(ErrorInternal),
@@ -208,8 +196,8 @@ pub enum Error {
ExecutionProviderNotRegistered(&'static str),
#[error("Expected tensor to be on CPU in order to get data, but had allocation device `{0}`.")]
TensorNotOnCpu(&'static str),
#[error("String tensors require the session's allocator to be provided through `Value::from_array`.")]
StringTensorRequiresAllocator,
#[error("Cannot extract scalar value from a {0}-dimensional tensor")]
TensorNot0Dimensional(usize),
#[error("Failed to create memory info: {0}")]
CreateMemoryInfo(ErrorInternal),
#[error("Could not get allocation device from `MemoryInfo`: {0}")]
@@ -222,10 +210,10 @@ pub enum Error {
BindInput(ErrorInternal),
#[error("Error when binding output: {0}")]
BindOutput(ErrorInternal),
#[error("Failed to clear IO binding: {0}")]
ClearBinding(ErrorInternal),
#[error("Error when retrieving session outputs from `IoBinding`: {0}")]
GetBoundOutputs(ErrorInternal),
#[error("Cannot use `extract_tensor` on a value that is {0:?}")]
NotTensor(ValueType),
#[error("Cannot use `extract_sequence` on a value that is {0:?}")]
NotSequence(ValueType),
#[error("Cannot use `extract_map` on a value that is {0:?}")]
@@ -252,6 +240,8 @@ pub enum Error {
GetOperatorInput(ErrorInternal),
#[error("Failed to get operator output: {0}")]
GetOperatorOutput(ErrorInternal),
#[error("Failed to retrieve GPU compute stream from kernel context: {0}")]
GetOperatorGPUComputeStream(ErrorInternal),
#[error("{0}")]
CustomError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("String tensors cannot be borrowed as mutable")]
@@ -266,37 +256,20 @@ pub enum Error {
GetDeviceId(ErrorInternal)
}
impl From<Infallible> for Error {
fn from(_: Infallible) -> Self {
Error::Infallible
impl Error {
/// Wrap a custom, user-provided error in an [`ort::Error`](Error). The resulting error will be the
/// [`Error::CustomError`] variant.
///
/// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort`
/// related operation fails.
pub fn wrap<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
Error::CustomError(Box::new(err) as Box<dyn std::error::Error + Send + Sync + 'static>)
}
}
/// Error used when the input dimensions defined in the model and passed from an inference call do not match.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
/// Number of inputs from model does not match the number of inputs from inference call.
#[error(
"Non-matching number of inputs: {inference_input_count:?} provided vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})"
)]
InputsCount {
/// Number of input dimensions used by inference call
inference_input_count: usize,
/// Number of input dimensions defined in model
model_input_count: usize,
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>
},
/// Inputs length from model does not match the expected input from inference call
#[error("Different input lengths; expected input: {model_input:?}, received input: {inference_input:?}")]
InputsLength {
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>
impl From<Infallible> for Error {
fn from(_: Infallible) -> Self {
Error::Infallible
}
}

View File

@@ -1,6 +1,8 @@
use std::{
collections::HashMap,
ffi::CString,
fmt::Debug,
marker::PhantomData,
ptr::{self, NonNull},
sync::Arc
};
@@ -9,24 +11,87 @@ use crate::{
memory::MemoryInfo,
ortsys,
session::{output::SessionOutputs, RunOptions},
value::{Value, ValueRefMut},
Error, Result, Session, ValueTypeMarker
value::{Value, ValueInner},
DynValue, Error, Result, Session, ValueTypeMarker
};
/// Enables binding of session inputs and/or outputs to pre-allocated memory.
///
/// Note that this arrangement is designed to minimize data copies, and to that effect, your memory allocations must
/// match what is expected by the model, whether you run on CPU or GPU. Data will still be copied if the
/// pre-allocated memory location does not match the one expected by the model. However, copies with `IoBinding`s are
/// only done once, at the time of the binding, not at run time. This means, that if your input data required a copy,
/// your further input modifications would not be seen by ONNX Runtime unless you rebind it, even if it is the same
/// buffer. If your scenario requires that the data is copied, `IoBinding` may not be the best match for your use case.
/// The fact that data copy is not made during runtime may also have performance implications.
/// [`IoBinding`] minimizes copies between a device (like a GPU) and the host (CPU) by allowing the user to bind a
/// certain input/output to a pre-allocated value on a specific device.
///
/// [`IoBinding`] is most suitable for:
/// - An ensemble of models in which the output from one model is the input of another and does not need to pass through
/// the CPU to perform additional processing.
/// - Situations where the output should stay on a device (e.g. to perform additional processing with CUDA).
/// - Diffusion models, for instance, that accept an unchanging embedding for conditioning.
///
/// [`IoBinding`] will not provide any meaningful benefit for:
/// - Models where every input changes with each invocation, such as a causal language model or object recognition
/// model.
/// - Pipelines that go straight from CPU -> GPU -> CPU.
///
/// # Example
/// A diffusion model which takes a text condition input.
///
/// ```no_run
/// # use ort::{Allocator, AllocatorType, AllocationDevice, CUDAExecutionProvider, MemoryInfo, MemoryType, Session, Tensor, IoBinding};
/// # fn main() -> ort::Result<()> {
/// let text_encoder = Session::builder()?
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
/// .commit_from_file("text_encoder.onnx")?;
/// let unet = Session::builder()?
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
/// .commit_from_file("unet.onnx")?;
///
/// let text_condition = text_encoder
/// .run(ort::inputs![Tensor::<i64>::from_array((
/// vec![27],
/// vec![
/// 23763, 15460, 473, 68, 312, 265, 17463, 4098, 304, 1077, 283, 198, 7676, 5976, 272, 285, 3609, 435,
/// 21680, 321, 265, 300, 1689, 64, 285, 4763, 64
/// ]
/// ))?]?)?
/// .remove("output0")
/// .unwrap();
///
/// let input_allocator = Allocator::new(
/// &unet,
/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)?
/// )?;
/// let mut latents = Tensor::<f32>::new(&input_allocator, [1, 4, 64, 64])?;
///
/// let mut io_binding = unet.create_binding()?;
/// io_binding.bind_input("condition", &text_condition)?;
///
/// let output_allocator = Allocator::new(
/// &unet,
/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUOutput)?
/// )?;
/// io_binding.bind_output("noise_pred", Tensor::<f32>::new(&output_allocator, [1, 4, 64, 64])?)?;
///
/// for _ in 0..20 {
/// io_binding.bind_input("latents", &latents)?;
/// let noise_pred = io_binding.run()?.remove("noise_pred").unwrap();
///
/// let mut latents = latents.extract_tensor_mut();
/// latents += &noise_pred.try_extract_tensor::<f32>()?;
/// }
/// # Ok(())
/// # }
/// ```
///
/// [`IoBinding`] may provide a decent speedup in this example since the `condition` tensor is unchanging between runs.
/// If we were to use normal session inference, the `condition` tensor would be needlessly copied with each invocation
/// of `unet.run()`, and this copying can come with significant latency & overhead. With [`IoBinding`], the `condition`
/// tensor is only copied to the device once instead of 20 times.
#[derive(Debug)]
pub struct IoBinding<'s> {
pub(crate) ptr: NonNull<ort_sys::OrtIoBinding>,
session: &'s Session,
output_names: Vec<String>
held_inputs: HashMap<String, Arc<ValueInner>>,
output_names: Vec<String>,
output_values: HashMap<String, DynValue>
}
impl<'s> IoBinding<'s> {
@@ -36,25 +101,47 @@ impl<'s> IoBinding<'s> {
Ok(Self {
ptr: unsafe { NonNull::new_unchecked(ptr) },
session,
output_names: Vec::new()
held_inputs: HashMap::new(),
output_names: Vec::new(),
output_values: HashMap::new()
})
}
/// Bind a [`Value`] to a session input.
pub fn bind_input<'i: 's, T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: &'i mut Value<T>) -> Result<ValueRefMut<'i, T>> {
///
/// Upon invocation, the value's data will be copied to the device the session is allocated on. The copied data will
/// be used as an input (specified by `name`) in all future invocations of [`IoBinding::run`] until the input is
/// overridden (by calling [`IoBinding::bind_input`] again) or until all inputs are cleared (via
/// [`IoBinding::clear_inputs`] or [`IoBinding::clear`]).
///
/// The data is only copied **once**, immediately upon invocation of this function. Any changes to the given
/// value afterwards will not affect the data seen by the session until the value is re-bound. Subsequent re-binds
/// will still copy data, hence why [`IoBinding`] is really only suitable when one or more inputs do not change
/// between runs.
pub fn bind_input<T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: &Value<T>) -> Result<()> {
let name = name.as_ref();
let cname = CString::new(name)?;
ortsys![unsafe BindInput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindInput];
Ok(ort_value.view_mut())
self.held_inputs.insert(name.to_string(), Arc::clone(&ort_value.inner));
Ok(())
}
/// Bind a session output to a pre-allocated [`Value`].
pub fn bind_output<'o: 's, T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: &'o mut Value<T>) -> Result<ValueRefMut<'o, T>> {
///
/// This allows for the pre-allocation and reuse of memory in the session output (see [`crate::Tensor::new`]). Any
/// subsequent runs via [`IoBinding::run`] will reuse the same tensor to store the output instead of creating a new
/// one each time.
///
/// The output will be accessible in the value returned by [`IoBinding::run`], under the name specified by `name`.
pub fn bind_output<T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: Value<T>) -> Result<()> {
let name = name.as_ref();
let cname = CString::new(name)?;
ortsys![unsafe BindOutput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindOutput];
self.output_names.push(name.to_string());
Ok(ort_value.view_mut())
// Clear the old bound output if we have any.
drop(self.output_values.remove(name));
self.output_values.insert(name.to_string(), ort_value.into_dyn());
Ok(())
}
/// Bind a session output to a device which is specified by `mem_info`.
@@ -66,15 +153,35 @@ impl<'s> IoBinding<'s> {
Ok(())
}
pub fn run<'i: 's>(&'i self) -> Result<SessionOutputs<'s>> {
/// Clears all bound inputs specified by [`IoBinding::bind_input`].
pub fn clear_inputs(&mut self) {
ortsys![unsafe ClearBoundInputs(self.ptr.as_ptr())];
drop(self.held_inputs.drain());
}
/// Clears all bound outputs specified by [`IoBinding::bind_output`] or [`IoBinding::bind_output_to_device`].
pub fn clear_outputs(&mut self) {
ortsys![unsafe ClearBoundOutputs(self.ptr.as_ptr())];
drop(self.output_names.drain(..));
drop(self.output_values.drain());
}
/// Clears both the bound inputs & outputs; equivalent to [`IoBinding::clear_inputs`] followed by
/// [`IoBinding::clear_outputs`].
pub fn clear(&mut self) {
self.clear_inputs();
self.clear_outputs();
}
/// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`].
pub fn run(&mut self) -> Result<SessionOutputs<'_>> {
self.run_inner(None)
}
pub fn run_with_options<'i: 's>(&'i self, run_options: Arc<RunOptions>) -> Result<SessionOutputs<'s>> {
/// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`].
pub fn run_with_options(&mut self, run_options: Arc<RunOptions>) -> Result<SessionOutputs<'_>> {
self.run_inner(Some(run_options))
}
fn run_inner<'i: 's>(&'i self, run_options: Option<Arc<RunOptions>>) -> Result<SessionOutputs<'s>> {
fn run_inner(&mut self, run_options: Option<Arc<RunOptions>>) -> Result<SessionOutputs<'_>> {
let run_options_ptr = if let Some(run_options) = run_options {
run_options.run_options_ptr.as_ptr()
} else {
@@ -82,6 +189,7 @@ impl<'s> IoBinding<'s> {
};
ortsys![unsafe RunWithBinding(self.session.inner.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr()) -> Error::SessionRunWithIoBinding];
let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc<ValueInner>> = self.output_values.values().map(|c| (c.ptr(), &c.inner)).collect();
let mut count = self.output_names.len() as ort_sys::size_t;
if count > 0 {
let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut();
@@ -91,10 +199,17 @@ impl<'s> IoBinding<'s> {
let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count as _).to_vec() }
.into_iter()
.map(|v| unsafe {
Value::from_ptr(
NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"),
Some(Arc::clone(&self.session.inner))
)
if let Some(inner) = owned_ptrs.get(&v) {
DynValue {
inner: Arc::clone(*inner),
_markers: PhantomData
}
} else {
DynValue::from_ptr(
NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"),
Some(Arc::clone(&self.session.inner))
)
}
});
// output values will be freed when the `Value`s in `SessionOutputs` drop

View File

@@ -65,7 +65,7 @@ pub use self::session::{
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub use self::tensor::ArrayExtensions;
pub use self::tensor::{IntoTensorElementType, TensorElementType};
pub use self::tensor::{IntoTensorElementType, Utf8Data, PrimitiveTensorElementType, TensorElementType};
pub use self::value::{
DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor,
DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence,
@@ -143,6 +143,23 @@ pub(crate) static G_ORT_API: OnceLock<AtomicPtr<ort_sys::OrtApi>> = OnceLock::ne
/// May panic if:
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
///
/// # Examples
/// The primary (public-facing) use case for this function is accessing APIs that do not have a corresponding safe
/// implementation in `ort`. For example, [`GetBuildInfoString`](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0a7dba37b0017c0ef3a0ab4e266a967d):
///
/// ```
/// # use std::ffi::CStr;
/// # fn main() -> ort::Result<()> {
/// let api = ort::api().as_ptr();
/// let build_info = unsafe { CStr::from_ptr((*api).GetBuildInfoString.unwrap()()) };
/// println!("{}", build_info.to_string_lossy());
/// // ORT Build Info: git-branch=HEAD, git-commit-id=4573740, build type=Release, cmake cxx flags: /DWIN32 /D_WINDOWS /EHsc /EHsc /wd26812 -DEIGEN_HAS_C99_MATH -DCPUINFO_SUPPORTED
/// # Ok(())
/// # }
/// ```
///
/// For the full list of ONNX Runtime APIs, consult the [`ort_sys::OrtApi`] struct and the [ONNX Runtime C API](https://onnxruntime.ai/docs/api/c/struct_ort_api.html).
pub fn api() -> NonNull<ort_sys::OrtApi> {
unsafe {
NonNull::new_unchecked(
@@ -252,6 +269,26 @@ pub(crate) fn char_p_to_string(raw: *const c_char) -> Result<String> {
.map_err(Error::FfiStringConversion)
}
pub(crate) struct PrivateTraitMarker;
macro_rules! private_trait {
() => {
#[doc(hidden)]
#[allow(private_interfaces)]
fn _private() -> crate::PrivateTraitMarker;
};
}
macro_rules! private_impl {
() => {
#[allow(private_interfaces)]
fn _private() -> crate::PrivateTraitMarker {
crate::PrivateTraitMarker
}
};
}
pub(crate) use private_impl;
pub(crate) use private_trait;
#[cfg(test)]
mod test {
use std::ffi::CString;

View File

@@ -1,20 +1,75 @@
use std::{
ffi::{c_char, c_int, CString},
ptr::NonNull
ptr::NonNull,
sync::Arc
};
use super::{
error::{Error, Result},
ortsys
};
use crate::{char_p_to_string, error::status_to_result, Session};
use crate::{char_p_to_string, error::status_to_result, Session, SharedSessionInner};
/// An ONNX Runtime allocator, used to manage the allocation of [`crate::Value`]s.
/// A device allocator used to manage the allocation of [`crate::Value`]s.
///
/// # Direct allocation
/// [`Allocator`] can be used to directly allocate device memory. This can be useful if you have a
/// postprocessing step that runs on the GPU.
/// ```no_run
/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let allocator = Allocator::new(
/// &session,
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?
/// )?;
///
/// let mut tensor = Tensor::<f32>::new(&allocator, [1, 3, 224, 224])?;
/// // Here, `data_ptr` is a pointer to **device memory** inaccessible to the CPU; we'll need another crate, like
/// // `cudarc`, to access it.
/// let data_ptr = tensor.data_ptr_mut()?;
/// # Ok(())
/// # }
/// ```
///
/// Note that `ort` does not facilitate the transfer of data between host & device outside of session inputs &
/// outputs; you'll need to use a separate crate for that, like [`cudarc`](https://crates.io/crates/cudarc) for CUDA.
///
/// # Pinned allocation
/// Memory allocated on the host CPU is often *pageable* and may reside on the disk (swap memory). Transferring
/// pageable memory to another device is slow because the device has to go through the CPU to access the
/// memory. Many execution providers thus provide a "pinned" allocator type, which allocates *unpaged* CPU memory
/// that the device is able to access directly, bypassing the CPU and allowing for faster host-to-device data
/// transfer.
///
/// If you create a session with a device allocator that supports pinned memory, like CUDA or ROCm, you can create
/// an allocator for that session, and use it to allocate tensors with faster pinned memory:
/// ```no_run
/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let allocator = Allocator::new(
/// &session,
/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)?
/// )?;
///
/// // Create a tensor with our pinned allocator.
/// let mut tensor = Tensor::<f32>::new(&allocator, [1, 3, 224, 224])?;
/// let data = tensor.extract_tensor_mut();
/// // ...fill `data` with data...
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct Allocator {
pub(crate) ptr: NonNull<ort_sys::OrtAllocator>,
/// The 'default' CPU allocator, provided by `GetAllocatorWithDefaultOptions` and implemented by
/// [`Allocator::default`], should **not** be released, so this field marks whether or not we should call
/// `ReleaseAllocator` on drop.
is_default: bool,
_info: Option<MemoryInfo>
_info: Option<MemoryInfo>,
/// Hold a reference to the session if this allocator is tied to one.
_session_inner: Option<Arc<SharedSessionInner>>
}
impl Allocator {
@@ -22,47 +77,46 @@ impl Allocator {
Allocator {
ptr: NonNull::new_unchecked(ptr),
is_default: false,
// currently, this function is only ever used in session creation, where we call `CreateAllocator` manually and store the allocator resulting from
// this function in the `SharedSessionInner` - we don't need to hold onto the session, because the session is holding onto us.
_session_inner: None,
_info: None
}
}
/// Frees an object allocated by this allocator, given the object's C pointer.
pub(crate) unsafe fn free<T>(&self, ptr: *mut T) {
self.ptr.as_ref().Free.unwrap_or_else(|| unreachable!("Allocator method `Free` is null"))(self.ptr.as_ptr(), ptr.cast());
}
/// Creates a new [`Allocator`] for the given session, to allocate memory on the device described in the
/// [`MemoryInfo`].
///
/// For example, to create an allocator to allocate pinned memory for CUDA:
/// ```no_run
/// # use ort::{Allocator, Session, MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let allocator = Allocator::new(
/// 	&session,
/// 	MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)?
/// )?;
/// # Ok(())
/// # }
/// ```
///
/// # Errors
/// Returns [`Error::CreateAllocator`] if the underlying `CreateAllocator` call fails.
pub fn new(session: &Session, memory_info: MemoryInfo) -> Result<Self> {
	let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
	ortsys![unsafe CreateAllocator(session.ptr(), memory_info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)];
	Ok(Self {
		ptr: unsafe { NonNull::new_unchecked(allocator_ptr) },
		is_default: false,
		// Hold a reference to the session: the C allocator was created from it and must not outlive it.
		_session_inner: Some(session.inner()),
		_info: Some(memory_info)
	})
}
}
impl Default for Allocator {
	/// Returns the default CPU allocator; equivalent to `MemoryInfo::new(AllocationDevice::CPU, 0,
	/// AllocatorType::Device, MemoryType::Default)`.
	///
	/// The allocator returned by this function is actually shared across all invocations (though this behavior is
	/// transparent to the user).
	fn default() -> Self {
		let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
		// The default allocator is owned and shared by ONNX Runtime itself; if it cannot be obtained,
		// nothing can be allocated at all, so a panic here is appropriate rather than a recoverable error.
		status_to_result(ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr); nonNull(allocator_ptr)]).expect("Failed to get default allocator");
		Self {
			ptr: unsafe { NonNull::new_unchecked(allocator_ptr) },
			// Marking this as the default allocator makes `Drop` skip `ReleaseAllocator` — per the
			// `GetAllocatorWithDefaultOptions` docs, the returned value must NOT be freed.
			is_default: true,
			// The default allocator isn't tied to a session.
			_session_inner: None,
			_info: None
		}
	}
@@ -70,8 +124,6 @@ impl Default for Allocator {
impl Drop for Allocator {
fn drop(&mut self) {
// per GetAllocatorWithDefaultOptions docs: Returned value should NOT be freed
// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a8dec797ae52ee1a681e4f88be1fb4bb3
if !self.is_default {
ortsys![unsafe ReleaseAllocator(self.ptr.as_ptr())];
}
@@ -81,7 +133,8 @@ impl Drop for Allocator {
/// Represents possible devices that have their own device allocator.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AllocationDevice {
// https://github.com/microsoft/onnxruntime/blob/v1.17.0/include/onnxruntime/core/framework/allocator.h#L43-L53
// https://github.com/microsoft/onnxruntime/blob/v1.18.0/include/onnxruntime/core/framework/allocator.h#L43-L53
// ort will likely never support WebGPU, so I think it's best to leave `WebGPU_Buffer` out entirely to reduce confusion
CPU,
CUDA,
CUDAPinned,
@@ -91,12 +144,10 @@ pub enum AllocationDevice {
HIP,
HIPPinned,
OpenVINOCPU,
OpenVINOGPU,
WebGPUBuffer
OpenVINOGPU
}
impl AllocationDevice {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Self::CPU => "Cpu",
@@ -108,10 +159,15 @@ impl AllocationDevice {
Self::HIP => "Hip",
Self::HIPPinned => "HipPinned",
Self::OpenVINOCPU => "OpenVINO_CPU",
Self::OpenVINOGPU => "OpenVINO_GPU",
Self::WebGPUBuffer => "WebGPU_Buffer"
Self::OpenVINOGPU => "OpenVINO_GPU"
}
}
/// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device,
/// it could be extracted to an `ndarray` or slice.
///
/// This holds for CPU memory itself and for device memory pinned into the CPU's address space
/// (e.g. [`AllocationDevice::CUDAPinned`]).
#[must_use]
pub fn is_cpu_accessible(&self) -> bool {
	matches!(self, Self::CPU | Self::CUDAPinned | Self::CANNPinned | Self::HIPPinned | Self::OpenVINOCPU)
}
}
impl TryFrom<String> for AllocationDevice {
@@ -129,14 +185,13 @@ impl TryFrom<String> for AllocationDevice {
"HipPinned" => Ok(AllocationDevice::HIPPinned),
"OpenVINO_CPU" => Ok(AllocationDevice::OpenVINOCPU),
"OpenVINO_GPU" => Ok(AllocationDevice::OpenVINOGPU),
"WebGPUBuffer" => Ok(AllocationDevice::WebGPUBuffer),
_ => Err(value)
}
}
}
/// Execution provider allocator type.
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AllocatorType {
/// Default device-specific allocator.
Device,
@@ -154,11 +209,11 @@ impl From<AllocatorType> for ort_sys::OrtAllocatorType {
}
/// Memory types for allocated memory.
#[derive(Default, Debug, Copy, Clone)]
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
pub enum MemoryType {
/// Any CPU memory used by non-CPU execution provider.
CPUInput,
/// CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED.
/// CPU-accessible memory output by a non-CPU execution provider, i.e. [`AllocationDevice::CUDAPinned`].
CPUOutput,
/// The default allocator for an execution provider.
#[default]
@@ -190,6 +245,12 @@ impl From<ort_sys::OrtMemType> for MemoryType {
}
}
/// Structure describing a memory location - the device on which the memory resides, the type of allocator (device
/// default, or arena) used, and the type of memory allocated (device-only, or CPU accessible).
///
/// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which
/// device value data should reside, and how that data should be accessible with regard to the CPU (if a non-CPU device
/// is requested).
#[derive(Debug)]
pub struct MemoryInfo {
pub(crate) ptr: NonNull<ort_sys::OrtMemoryInfo>,
@@ -197,24 +258,24 @@ pub struct MemoryInfo {
}
impl MemoryInfo {
pub(crate) fn from_raw(ptr: NonNull<ort_sys::OrtMemoryInfo>, should_release: bool) -> Self {
MemoryInfo { ptr, should_release }
}
#[tracing::instrument]
pub fn new_cpu(allocator: AllocatorType, memory_type: MemoryType) -> Result<Self> {
let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
ortsys![
unsafe CreateCpuMemoryInfo(allocator.into(), memory_type.into(), &mut memory_info_ptr) -> Error::CreateMemoryInfo;
nonNull(memory_info_ptr)
];
Ok(Self {
ptr: unsafe { NonNull::new_unchecked(memory_info_ptr) },
should_release: true
})
}
#[tracing::instrument]
/// Creates a [`MemoryInfo`], describing a memory location on a device allocator.
///
/// # Examples
/// `MemoryInfo` can be used to specify the device & memory type used by an [`Allocator`] to allocate tensors.
/// See [`Allocator`] for more information & potential applications.
/// ```no_run
/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let allocator = Allocator::new(
/// &session,
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?
/// )?;
///
/// let mut tensor = Tensor::<f32>::new(&allocator, [1, 3, 224, 224])?;
/// # Ok(())
/// # }
/// ```
pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemoryType) -> Result<Self> {
let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
let allocator_name = CString::new(allocation_device.as_str()).unwrap_or_else(|_| unreachable!());
@@ -229,7 +290,19 @@ impl MemoryInfo {
})
}
pub(crate) fn from_raw(ptr: NonNull<ort_sys::OrtMemoryInfo>, should_release: bool) -> Self {
MemoryInfo { ptr, should_release }
}
/// Returns the [`MemoryType`] described by this struct.
/// ```
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
/// assert_eq!(mem.memory_type()?, MemoryType::Default);
/// # Ok(())
/// # }
/// ```
pub fn memory_type(&self) -> Result<MemoryType> {
let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault;
ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetMemoryType];
@@ -237,6 +310,14 @@ impl MemoryInfo {
}
/// Returns the [`AllocatorType`] described by this struct.
/// ```
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
/// assert_eq!(mem.allocator_type()?, AllocatorType::Device);
/// # Ok(())
/// # }
/// ```
pub fn allocator_type(&self) -> Result<AllocatorType> {
let mut raw_type: ort_sys::OrtAllocatorType = ort_sys::OrtAllocatorType::OrtInvalidAllocator;
ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetAllocatorType];
@@ -248,6 +329,14 @@ impl MemoryInfo {
}
/// Returns the [`AllocationDevice`] this struct was created with.
/// ```
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
/// assert_eq!(mem.allocation_device()?, AllocationDevice::CPU);
/// # Ok(())
/// # }
/// ```
pub fn allocation_device(&self) -> Result<AllocationDevice> {
let mut name_ptr: *const c_char = std::ptr::null_mut();
ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr) -> Error::GetAllocationDevice; nonNull(name_ptr)];
@@ -258,6 +347,14 @@ impl MemoryInfo {
}
/// Returns the ID of the [`AllocationDevice`] described by this struct.
/// ```
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
/// assert_eq!(mem.device_id()?, 0);
/// # Ok(())
/// # }
/// ```
pub fn device_id(&self) -> Result<i32> {
let mut raw: ort_sys::c_int = 0;
ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw) -> Error::GetDeviceId];
@@ -266,24 +363,9 @@ impl MemoryInfo {
}
impl Drop for MemoryInfo {
	#[tracing::instrument]
	fn drop(&mut self) {
		// `should_release` is false for `OrtMemoryInfo` pointers we did not create ourselves (see
		// `from_raw`); presumably those are owned by ONNX Runtime, and releasing them here would be
		// a double-free — TODO(review): confirm against the callers of `from_raw`.
		if self.should_release {
			ortsys![unsafe ReleaseMemoryInfo(self.ptr.as_ptr())];
		}
	}
}
#[cfg(test)]
mod tests {
	use test_log::test;

	use super::*;

	/// Smoke test: a default CPU `MemoryInfo` can be constructed and dropped without error.
	#[test]
	fn create_memory_info() -> crate::Result<()> {
		let info = MemoryInfo::new_cpu(AllocatorType::Device, MemoryType::Default)?;
		drop(info);
		Ok(())
	}
}

View File

@@ -11,8 +11,7 @@ use super::{
};
use crate::error::IntoStatus;
#[repr(C)]
#[derive(Clone)]
#[repr(C)] // <- important! a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later
pub(crate) struct BoundOperator<O: Operator> {
implementation: ort_sys::OrtCustomOp,
name: CString,
@@ -184,7 +183,10 @@ unsafe impl Send for ErasedBoundOperator {}
impl ErasedBoundOperator {
pub(crate) fn new<O: Operator>(bound: BoundOperator<O>) -> Self {
ErasedBoundOperator(NonNull::from(unsafe { &mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ()) }))
ErasedBoundOperator(NonNull::from(unsafe {
// horrible horrible horrible horrible horrible horrible horrible horrible horrible
&mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ())
}))
}
pub(crate) fn op_ptr(&self) -> *mut ort_sys::OrtCustomOp {

View File

@@ -138,4 +138,13 @@ impl KernelContext {
ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx as ort_sys::size_t, shape.as_ptr(), shape.len() as _, &mut value_ptr) -> Error::GetOperatorOutput];
Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) })))
}
/// Returns a pointer to the GPU compute stream (i.e. `cudaStream_t`) used by the execution provider, if this
/// kernel's operator was configured to use said execution provider (see
/// [`super::Operator::execution_provider_type`]).
///
/// Returns `Ok(None)` when ONNX Runtime reports a null stream pointer — presumably when no GPU stream is
/// associated with this kernel (e.g. CPU execution).
pub fn compute_stream(&self) -> Result<Option<NonNull<ort_sys::c_void>>> {
	let mut stream_ptr: *mut ort_sys::c_void = ptr::null_mut();
	ortsys![unsafe KernelContext_GetGPUComputeStream(self.ptr.as_ptr(), &mut stream_ptr) -> Error::GetOperatorGPUComputeStream];
	// A null stream is mapped to `None` rather than an error.
	Ok(NonNull::new(stream_ptr))
}
}

View File

@@ -16,11 +16,26 @@ use crate::{operator::bound::BoundOperator, ortsys, Error, Result};
pub type InferShapeFn = dyn FnMut(*mut ort_sys::OrtShapeInferContext) -> crate::Result<()>;
/// A custom operator descriptor, which describes the expected inputs & outputs of a graph operator.
///
/// [`Operator`]s are bound to [`OperatorDomain`]s. Multiple operators can have the same name as long as they have
/// different input/output types, in which case the exact operator will be picked depending on the input/output
/// types. If you want to, for example, define a `Sort` operator that can accept either a single `f32` or `i64` tensor
/// input, you'll need to define 2 separate operators (which can be done via a macro); but both of these
/// [`Operator`] structs can return the same name in [`Operator::name`] so that they are usable as simply
/// `my.domain:Sort` in the graph.
pub trait Operator: Send {
type Kernel: Kernel;
/// Returns the name of the operator.
fn name() -> &'static str;
/// Returns the execution provider this operator runs on, e.g. `CUDAExecutionProvider`.
///
/// If the returned value is not `None`, and the execution provider used by the session matches this operator's
/// EP type, the value will not be copied to the CPU and you may use functions like [`crate::Tensor::data_ptr`] to
/// access the underlying device memory, and [`super::KernelContext::compute_stream`] to access the GPU compute
/// stream.
fn execution_provider_type() -> Option<&'static str> {
None
}
@@ -42,6 +57,7 @@ pub trait Operator: Send {
}
}
/// Dummy type implementing [`Operator`] used by [`ErasedBoundOperator`] to cheat the type system.
struct DummyOperator;
impl Operator for DummyOperator {
@@ -84,7 +100,7 @@ impl OperatorDomain {
}
#[allow(clippy::should_implement_trait)]
pub fn add<O: Operator>(mut self, _operator: O) -> Result<Self> {
pub fn add<O: Operator>(mut self) -> Result<Self> {
let name = O::name();
let bound = BoundOperator::<O>::new(CString::new(name)?, O::execution_provider_type().map(CString::new).transpose()?);

View File

@@ -92,16 +92,15 @@ impl<'i, 'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs<
/// # }
/// ```
///
/// Note that string tensors must be created manually with [`Value::from_string_array`].
/// Note that string tensors must be created manually with [`crate::Tensor::from_string_array`].
///
/// ```no_run
/// # use std::{error::Error, sync::Arc};
/// # use ndarray::Array1;
/// # use ort::{GraphOptimizationLevel, Session, Value};
/// # use ort::{GraphOptimizationLevel, Session, Tensor};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// # let mut session = Session::builder()?.commit_from_file("model.onnx")?;
/// let _ = session
/// .run(ort::inputs![Value::from_string_array(session.allocator(), Array1::from_vec(vec!["hello", "world"]))?]?);
/// let _ = session.run(ort::inputs![Tensor::from_string_array(Array1::from_vec(vec!["hello", "world"]))?]?);
/// # Ok(())
/// # }
/// ```

View File

@@ -22,4 +22,4 @@ mod types;
pub use self::ndarray::ArrayExtensions;
#[cfg(feature = "ndarray")]
pub(crate) use self::types::{extract_primitive_array, extract_primitive_array_mut};
pub use self::types::{IntoTensorElementType, TensorElementType, Utf8Data};
pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data};

View File

@@ -91,6 +91,12 @@ impl From<ort_sys::ONNXTensorElementDataType> for TensorElementType {
pub trait IntoTensorElementType {
/// Returns the ONNX tensor element data type corresponding to the given Rust type.
fn into_tensor_element_type() -> TensorElementType;
crate::private_trait!();
}
pub trait PrimitiveTensorElementType: IntoTensorElementType {
crate::private_trait!();
}
macro_rules! impl_type_trait {
@@ -99,6 +105,12 @@ macro_rules! impl_type_trait {
fn into_tensor_element_type() -> TensorElementType {
TensorElementType::$variant
}
crate::private_impl!();
}
impl PrimitiveTensorElementType for $type_ {
crate::private_impl!();
}
};
}
@@ -121,6 +133,14 @@ impl_type_trait!(u64, Uint64);
#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
impl_type_trait!(half::bf16, Bfloat16);
impl IntoTensorElementType for String {
fn into_tensor_element_type() -> TensorElementType {
TensorElementType::String
}
crate::private_impl!();
}
/// Adapter for common Rust string types to ONNX strings.
pub trait Utf8Data {
/// Returns the contents of this value as a slice of UTF-8 bytes.

View File

@@ -3,25 +3,39 @@ use std::{
fmt::Debug,
hash::Hash,
marker::PhantomData,
ptr::{self, NonNull}
ptr::{self, NonNull},
sync::Arc
};
use super::{ValueInner, ValueTypeMarker};
use crate::{
memory::Allocator, ortsys, value::impl_tensor::DynTensor, DynValue, Error, IntoTensorElementType, Result, Tensor, Value, ValueRef, ValueRefMut, ValueType
memory::Allocator,
ortsys,
value::impl_tensor::{calculate_tensor_size, DynTensor},
DynValue, Error, IntoTensorElementType, PrimitiveTensorElementType, Result, Tensor, TensorElementType, Value, ValueRef, ValueRefMut, ValueType
};
pub trait MapValueTypeMarker: ValueTypeMarker {}
pub trait MapValueTypeMarker: ValueTypeMarker {
crate::private_trait!();
}
#[derive(Debug)]
pub struct DynMapValueType;
impl ValueTypeMarker for DynMapValueType {}
impl MapValueTypeMarker for DynMapValueType {}
impl ValueTypeMarker for DynMapValueType {
crate::private_impl!();
}
impl MapValueTypeMarker for DynMapValueType {
crate::private_impl!();
}
#[derive(Debug)]
pub struct MapValueType<K: IntoTensorElementType + Clone + Hash + Eq, V: IntoTensorElementType + Debug>(PhantomData<(K, V)>);
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> ValueTypeMarker for MapValueType<K, V> {}
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> MapValueTypeMarker for MapValueType<K, V> {}
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> ValueTypeMarker for MapValueType<K, V> {
crate::private_impl!();
}
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> MapValueTypeMarker for MapValueType<K, V> {
crate::private_impl!();
}
pub type DynMap = Value<DynMapValueType>;
pub type Map<K, V> = Value<MapValueType<K, V>>;
@@ -32,10 +46,7 @@ pub type MapRef<'v, K, V> = ValueRef<'v, MapValueType<K, V>>;
pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType<K, V>>;
impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
pub fn try_extract_map<K: IntoTensorElementType + Clone + Hash + Eq, V: IntoTensorElementType + Clone>(
&self,
allocator: &Allocator
) -> Result<HashMap<K, V>> {
pub fn try_extract_map<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<HashMap<K, V>> {
match self.dtype()? {
ValueType::Map { key, value } => {
let k_type = K::into_tensor_element_type();
@@ -47,47 +58,95 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
return Err(Error::InvalidMapValueType { expected: v_type, actual: value });
}
let allocator = Allocator::default();
let mut key_tensor_ptr = ptr::null_mut();
ortsys![unsafe GetValue(self.ptr(), 0, allocator.ptr.as_ptr(), &mut key_tensor_ptr) -> Error::ExtractMap; nonNull(key_tensor_ptr)];
let key_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(key_tensor_ptr), None) };
let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_tensor::<K>()?;
if K::into_tensor_element_type() != TensorElementType::String {
let dtype = key_value.dtype()?;
let (key_tensor_shape, key_tensor) = match dtype {
ValueType::Tensor { ty, dimensions } => {
let device = key_value.memory_info()?.allocation_device()?;
if !device.is_cpu_accessible() {
return Err(Error::TensorNotOnCpu(device.as_str()));
}
let mut value_tensor_ptr = ptr::null_mut();
ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)];
let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) };
let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::<V>()?;
if ty == K::into_tensor_element_type() {
let mut output_array_ptr: *mut K = ptr::null_mut();
let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr;
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
assert_eq!(key_tensor_shape.len(), 1);
assert_eq!(value_tensor_shape.len(), 1);
assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
let len = calculate_tensor_size(&dimensions);
(dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) })
} else {
return Err(Error::DataTypeMismatch {
actual: ty,
requested: K::into_tensor_element_type()
});
}
}
_ => unreachable!()
};
let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
for i in 0..key_tensor_shape[0] as usize {
vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
let mut value_tensor_ptr = ptr::null_mut();
ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)];
let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) };
let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::<V>()?;
assert_eq!(key_tensor_shape.len(), 1);
assert_eq!(value_tensor_shape.len(), 1);
assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
for i in 0..key_tensor_shape[0] as usize {
vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
}
Ok(vec.into_iter().collect())
} else {
let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_string_tensor()?;
// SAFETY: `IntoTensorElementType` is a private trait, and we only map the `String` type to `TensorElementType::String`,
// so at this point, `K` is **always** the `String` type, and this transmute really does nothing but please the type
// checker.
let key_tensor: Vec<K> = unsafe { std::mem::transmute(key_tensor) };
let mut value_tensor_ptr = ptr::null_mut();
ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)];
let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) };
let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::<V>()?;
assert_eq!(key_tensor_shape.len(), 1);
assert_eq!(value_tensor_shape.len(), 1);
assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
for i in 0..key_tensor_shape[0] as usize {
vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
}
Ok(vec.into_iter().collect())
}
Ok(vec.into_iter().collect())
}
t => Err(Error::NotMap(t))
}
}
}
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
impl<K: PrimitiveTensorElementType + Debug + Clone + Hash + Eq + 'static, V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
/// Creates a [`Map`] from an iterable emitting `K` and `V`.
///
/// ```
/// # use std::collections::HashMap;
/// # use ort::{Allocator, Map};
/// # use ort::Map;
/// # fn main() -> ort::Result<()> {
/// # let allocator = Allocator::default();
/// let mut map = HashMap::<i64, f32>::new();
/// map.insert(0, 1.0);
/// map.insert(1, 2.0);
/// map.insert(2, 3.0);
///
/// let value = Map::new(map)?;
/// let value = Map::<i64, f32>::new(map)?;
///
/// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0);
/// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0);
/// # Ok(())
/// # }
/// ```
@@ -95,20 +154,45 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTens
let (keys, values): (Vec<K>, Vec<V>) = data.into_iter().unzip();
Self::new_kv(Tensor::from_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?)
}
}
impl<V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType<String, V>> {
/// Creates a [`Map`] with `String` keys from an iterable emitting `(String, V)` pairs.
///
/// ```
/// # use std::collections::HashMap;
/// # use ort::Map;
/// # fn main() -> ort::Result<()> {
/// let mut map = HashMap::<String, f32>::new();
/// map.insert("one".to_string(), 1.0);
/// map.insert("two".to_string(), 2.0);
/// map.insert("three".to_string(), 3.0);
///
/// let value = Map::<String, f32>::new(map)?;
///
/// assert_eq!(*value.extract_map().get("one").unwrap(), 1.0);
/// # Ok(())
/// # }
/// ```
pub fn new(data: impl IntoIterator<Item = (String, V)>) -> Result<Self> {
let (keys, values): (Vec<String>, Vec<V>) = data.into_iter().unzip();
Self::new_kv(Tensor::from_string_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?)
}
}
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
/// Creates a [`Map`] from two tensors of keys & values respectively.
///
/// ```
/// # use std::collections::HashMap;
/// # use ort::{Allocator, Map, Tensor};
/// # use ort::{Map, Tensor};
/// # fn main() -> ort::Result<()> {
/// # let allocator = Allocator::default();
/// let keys = Tensor::<i64>::from_array(([4], vec![0, 1, 2, 3]))?;
/// let values = Tensor::<f32>::from_array(([4], vec![1., 2., 3., 4.]))?;
///
/// let value = Map::new_kv(keys, values)?;
///
/// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0);
/// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0);
/// # Ok(())
/// # }
/// ```
@@ -122,21 +206,23 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTens
nonNull(value_ptr)
];
Ok(Value {
inner: ValueInner::RustOwned {
inner: Arc::new(ValueInner::RustOwned {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
_array: Box::new(values),
_memory_info: None
},
}),
_markers: PhantomData
})
}
}
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
pub fn extract_map(&self, allocator: &Allocator) -> HashMap<K, V> {
self.try_extract_map(allocator).expect("Failed to extract map")
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: PrimitiveTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
pub fn extract_map(&self) -> HashMap<K, V> {
self.try_extract_map().expect("Failed to extract map")
}
}
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
/// Converts from a strongly-typed [`Map<K, V>`] to a type-erased [`DynMap`].
#[inline]
pub fn upcast(self) -> DynMap {
@@ -149,7 +235,7 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementT
DynMapRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
})
}
@@ -160,7 +246,7 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementT
DynMapRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
})
}

View File

@@ -1,23 +1,34 @@
use std::{
fmt::Debug,
marker::PhantomData,
ptr::{self, NonNull}
ptr::{self, NonNull},
sync::Arc
};
use super::{DowncastableTarget, ValueInner, ValueTypeMarker};
use crate::{memory::Allocator, ortsys, Error, Result, Value, ValueRef, ValueRefMut, ValueType};
pub trait SequenceValueTypeMarker: ValueTypeMarker {}
pub trait SequenceValueTypeMarker: ValueTypeMarker {
crate::private_trait!();
}
#[derive(Debug)]
pub struct DynSequenceValueType;
impl ValueTypeMarker for DynSequenceValueType {}
impl SequenceValueTypeMarker for DynSequenceValueType {}
impl ValueTypeMarker for DynSequenceValueType {
crate::private_impl!();
}
impl SequenceValueTypeMarker for DynSequenceValueType {
crate::private_impl!();
}
#[derive(Debug)]
pub struct SequenceValueType<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(PhantomData<T>);
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> ValueTypeMarker for SequenceValueType<T> {}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> SequenceValueTypeMarker for SequenceValueType<T> {}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> ValueTypeMarker for SequenceValueType<T> {
crate::private_impl!();
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> SequenceValueTypeMarker for SequenceValueType<T> {
crate::private_impl!();
}
pub type DynSequence = Value<DynSequenceValueType>;
pub type Sequence<T> = Value<SequenceValueType<T>>;
@@ -89,11 +100,11 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized + 'static> Value<Se
nonNull(value_ptr)
];
Ok(Value {
inner: ValueInner::RustOwned {
inner: Arc::new(ValueInner::RustOwned {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
_array: Box::new(values),
_memory_info: None
},
}),
_markers: PhantomData
})
}
@@ -116,7 +127,7 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
DynSequenceRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
})
}
@@ -127,7 +138,7 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
DynSequenceRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
})
}

View File

@@ -15,15 +15,15 @@ use crate::{
error::assert_non_null_pointer,
memory::{Allocator, MemoryInfo},
ortsys,
tensor::{IntoTensorElementType, TensorElementType, Utf8Data},
value::ValueInner,
AllocatorType, DynValue, Error, MemoryType, Result, TensorRefMut, Value
tensor::{TensorElementType, Utf8Data},
value::{impl_tensor::calculate_tensor_size, ValueInner},
AllocationDevice, AllocatorType, DynValue, Error, MemoryType, PrimitiveTensorElementType, Result, TensorRefMut, Value
};
impl DynTensor {
/// Construct a [`Value`] from an array of strings.
impl Tensor<String> {
/// Construct a [`Tensor<String>`] from an array of strings.
///
/// Just like numeric tensors, string tensor `Value`s can be created from:
/// Just like numeric tensors, string tensors can be created from:
/// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`);
/// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray<T, D>`);
/// - (with feature `ndarray`) an owned [`ndarray::Array`];
@@ -36,26 +36,19 @@ impl DynTensor {
/// ```
/// # use ort::{Session, Value};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/vectorizer.onnx")?;
/// // You'll need to obtain an `Allocator` from a session in order to create string tensors.
/// let allocator = session.allocator();
///
/// // Create a string tensor from a raw data vector
/// let data = vec!["hello", "world"];
/// let value = Value::from_string_array(allocator, ([data.len()], data.into_boxed_slice()))?;
/// let value = Value::from_string_array(([data.len()], data.into_boxed_slice()))?;
///
/// // Create a string tensor from an `ndarray::Array`
/// #[cfg(feature = "ndarray")]
/// let value = Value::from_string_array(
/// allocator,
/// ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap()
/// )?;
/// let value = Value::from_string_array(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap())?;
/// # Ok(())
/// # }
/// ```
///
/// Note that string data will *always* be copied, no matter what form the data is provided in.
pub fn from_string_array<T: Utf8Data>(allocator: &Allocator, input: impl IntoValueTensor<Item = T>) -> Result<DynTensor> {
pub fn from_string_array<T: Utf8Data>(input: impl IntoValueTensor<Item = T>) -> Result<Tensor<String>> {
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
let (shape, data) = input.ref_parts()?;
@@ -64,7 +57,7 @@ impl DynTensor {
// create tensor without data -- data is filled in later
ortsys![
unsafe CreateTensorAsOrtValue(allocator.ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr)
unsafe CreateTensorAsOrtValue(Allocator::default().ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr)
-> Error::CreateTensor;
nonNull(value_ptr)
];
@@ -84,18 +77,18 @@ impl DynTensor {
ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> Error::FillStringTensor];
Ok(Value {
inner: ValueInner::RustOwned {
inner: Arc::new(ValueInner::RustOwned {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
_array: Box::new(()),
_memory_info: None
},
}),
_markers: PhantomData
})
}
}
impl<T: IntoTensorElementType + Debug> Tensor<T> {
/// Construct a tensor [`Value`] in a given allocator with a given shape and datatype. The data contained in the
impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
/// Construct a tensor in a given allocator with a given shape and datatype. The data contained in the
/// value will be zero-allocated on the allocation device.
///
/// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned
@@ -132,18 +125,18 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
];
Ok(Value {
inner: ValueInner::RustOwned {
inner: Arc::new(ValueInner::RustOwned {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
_array: Box::new(()),
_memory_info: None
},
}),
_markers: PhantomData
})
}
/// Construct a tensor [`Value`] from an array of data.
/// Construct a tensor from an array of data.
///
/// Tensor `Value`s can be created from:
/// Tensors can be created from:
/// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`);
/// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray<T, D>`);
/// - (with feature `ndarray`) an owned [`ndarray::Array`];
@@ -154,19 +147,19 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
/// * and `data` is one of `Vec<T>`, `Box<[T]>`, `Arc<Box<[T]>>`, or `&[T]`.
///
/// ```
/// # use ort::Value;
/// # use ort::Tensor;
/// # fn main() -> ort::Result<()> {
/// // Create a tensor from a raw data vector
/// let value = Value::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?;
/// let tensor = Tensor::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?;
///
/// // Create a tensor from an `ndarray::Array`
/// #[cfg(feature = "ndarray")]
/// let value = Value::from_array(ndarray::Array4::<f32>::zeros((1, 16, 16, 3)))?;
/// let tensor = Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 16, 16, 3)))?;
/// # Ok(())
/// # }
/// ```
///
/// Creating string tensors requires a separate method; see [`Value::from_string_array`].
/// Creating string tensors requires a separate method; see [`DynTensor::from_string_array`].
///
/// Note that data provided in an `ndarray` may be copied in some circumstances:
/// - `&CowArray<'_, T, D>` will always be copied regardless of whether it is uniquely owned or borrowed.
@@ -177,7 +170,7 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
/// Raw data provided as a `Arc<Box<[T]>>`, `Box<[T]>`, or `Vec<T>` will never be copied. Raw data is expected to be
/// in standard, contigous layout.
pub fn from_array(input: impl IntoValueTensor<Item = T>) -> Result<Tensor<T>> {
let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemoryType::Default)?;
let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::Default)?;
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
@@ -203,17 +196,17 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
];
Ok(Value {
inner: ValueInner::RustOwned {
inner: Arc::new(ValueInner::RustOwned {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
_array: guard,
_memory_info: Some(memory_info)
},
}),
_markers: PhantomData
})
}
}
impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> {
impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
/// Create a mutable tensor view from a raw pointer and shape.
///
/// The length of data is determined by `T` and the given shape, so the given buffer must be at least
@@ -260,11 +253,11 @@ impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> {
];
Ok(TensorRefMut::new(Value {
inner: ValueInner::CppOwned {
inner: Arc::new(ValueInner::CppOwned {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
drop: true,
_session: None
},
}),
_markers: PhantomData
}))
}
@@ -290,7 +283,7 @@ macro_rules! impl_to_dimensions {
.enumerate()
.map(|(i, c)| if *c >= 1 { Ok(*c as i64) } else { Err(Error::InvalidDimension(i)) })
.collect::<Result<_>>()?;
let sum = v.iter().product::<i64>() as usize;
let sum = calculate_tensor_size(&v);
if let Some(expected_size) = expected_size {
if sum != expected_size {
Err(Error::TensorShapeMismatch {
@@ -318,6 +311,14 @@ macro_rules! impl_to_dimensions {
};
}
impl ToDimensions for () {
fn to_dimensions(&self, expected_size: Option<usize>) -> Result<Vec<i64>> {
match expected_size {
Some(1) | None => Ok(vec![]),
Some(x) => Err(Error::TensorShapeMismatch { input: vec![], total: 1, expected: x })
}
}
}
impl_to_dimensions!(for &[usize], for &[i32], for &[i64], for Vec<usize>, for Vec<i32>, for Vec<i64>);
impl_to_dimensions!(<N> for [usize; N], for [i32; N], for [i64; N]);
@@ -500,7 +501,7 @@ impl<T: Clone + Debug + 'static, D: ToDimensions> IntoValueTensor for (D, Arc<Bo
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Tensor<T>
impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Tensor<T>
where
'i: 'v
{
@@ -512,7 +513,7 @@ where
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for Tensor<T> {
impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for Tensor<T> {
type Error = Error;
fn try_from(arr: ArrayView<'v, T, D>) -> Result<Self, Self::Error> {
Tensor::from_array(arr)
@@ -521,7 +522,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor
impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor
where
'i: 'v
{
@@ -533,7 +534,7 @@ where
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for DynTensor {
impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for DynTensor {
type Error = Error;
fn try_from(arr: ArrayView<'v, T, D>) -> Result<Self, Self::Error> {
Tensor::from_array(arr).map(|c| c.upcast())
@@ -542,7 +543,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue
impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue
where
'i: 'v
{
@@ -554,7 +555,7 @@ where
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for DynValue {
impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for DynValue {
type Error = Error;
fn try_from(arr: ArrayView<'v, T, D>) -> Result<Self, Self::Error> {
Tensor::from_array(arr).map(|c| c.into_dyn())
@@ -564,19 +565,19 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta
macro_rules! impl_try_from {
(@T,I $($t:ty),+) => {
$(
impl<T: IntoTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for Tensor<T> {
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for Tensor<T> {
type Error = Error;
fn try_from(value: $t) -> Result<Self, Self::Error> {
Tensor::from_array(value)
}
}
impl<T: IntoTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for DynTensor {
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for DynTensor {
type Error = Error;
fn try_from(value: $t) -> Result<Self, Self::Error> {
Tensor::from_array(value).map(|c| c.upcast())
}
}
impl<T: IntoTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for crate::DynValue {
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for crate::DynValue {
type Error = Error;
fn try_from(value: $t) -> Result<Self, Self::Error> {
Tensor::from_array(value).map(|c| c.into_dyn())
@@ -587,21 +588,21 @@ macro_rules! impl_try_from {
(@T,D $($t:ty),+) => {
$(
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: IntoTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for Tensor<T> {
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for Tensor<T> {
type Error = Error;
fn try_from(value: $t) -> Result<Self, Self::Error> {
Tensor::from_array(value)
}
}
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: IntoTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for DynTensor {
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for DynTensor {
type Error = Error;
fn try_from(value: $t) -> Result<Self, Self::Error> {
Tensor::from_array(value).map(|c| c.upcast())
}
}
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: IntoTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for crate::DynValue {
impl<T: PrimitiveTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for crate::DynValue {
type Error = Error;
fn try_from(value: $t) -> Result<Self, Self::Error> {
Tensor::from_array(value).map(|c| c.into_dyn())

View File

@@ -1,4 +1,4 @@
use std::{fmt::Debug, os::raw::c_char, ptr, string::FromUtf8Error};
use std::{fmt::Debug, ptr, string::FromUtf8Error};
#[cfg(feature = "ndarray")]
use ndarray::IxDyn;
@@ -7,9 +7,7 @@ use super::TensorValueTypeMarker;
#[cfg(feature = "ndarray")]
use crate::tensor::{extract_primitive_array, extract_primitive_array_mut};
use crate::{
ortsys,
tensor::{IntoTensorElementType, TensorElementType},
Error, Result, Tensor, Value
ortsys, tensor::TensorElementType, value::impl_tensor::calculate_tensor_size, Error, PrimitiveTensorElementType, Result, Tensor, Value, ValueType
};
impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
@@ -38,38 +36,81 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the
/// infallible [`Tensor::extract_tensor`] instead)*
/// - The provided type `T` does not match the tensor's element type.
/// - The tensor's data is not allocated in CPU memory.
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub fn try_extract_tensor<T: IntoTensorElementType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
pub fn try_extract_tensor<T: PrimitiveTensorElementType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
let dtype = self.dtype()?;
match dtype {
ValueType::Tensor { ty, dimensions } => {
let device = self.memory_info()?.allocation_device()?;
if !device.is_cpu_accessible() {
return Err(Error::TensorNotOnCpu(device.as_str()));
}
let res = {
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
let data_type: TensorElementType = type_sys.into();
if data_type == T::into_tensor_element_type() {
let mut num_dims = 0;
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::<Vec<_>>());
let mut len = 0;
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
Ok(extract_primitive_array(shape, self.ptr())?)
} else {
Err(Error::DataTypeMismatch {
actual: data_type,
requested: T::into_tensor_element_type()
})
if ty == T::into_tensor_element_type() {
Ok(extract_primitive_array(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>()), self.ptr())?)
} else {
Err(Error::DataTypeMismatch {
actual: ty,
requested: T::into_tensor_element_type()
})
}
}
};
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
res
t => Err(Error::NotTensor(t))
}
}
/// Attempt to extract the scalar from a tensor of type `T`.
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{Session, Value};
/// # fn main() -> ort::Result<()> {
/// let value = Value::from_array(((), vec![3.14_f32]))?;
///
/// let extracted = value.try_extract_scalar::<f32>()?;
/// assert_eq!(extracted, 3.14);
/// # Ok(())
/// # }
/// ```
///
/// # Errors
/// May return an error if:
/// - The tensor is not 0-dimensional.
/// - The provided type `T` does not match the tensor's element type.
/// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the
/// infallible [`Tensor::extract_tensor`] instead)*
/// - The tensor's data is not allocated in CPU memory.
pub fn try_extract_scalar<T: PrimitiveTensorElementType + Copy>(&self) -> Result<T> {
let dtype = self.dtype()?;
match dtype {
ValueType::Tensor { ty, dimensions } => {
let device = self.memory_info()?.allocation_device()?;
if !device.is_cpu_accessible() {
return Err(Error::TensorNotOnCpu(device.as_str()));
}
if !dimensions.is_empty() {
return Err(Error::TensorNot0Dimensional(dimensions.len()));
}
if ty == T::into_tensor_element_type() {
let mut output_array_ptr: *mut T = ptr::null_mut();
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
Ok(unsafe { *output_array_ptr })
} else {
Err(Error::DataTypeMismatch {
actual: ty,
requested: T::into_tensor_element_type()
})
}
}
t => Err(Error::NotTensor(t))
}
}
/// Attempt to extract the underlying data of type `T` into a mutable read-only [`ndarray::ArrayViewMut`].
@@ -101,36 +142,26 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// - The provided type `T` does not match the tensor's element type.
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub fn try_extract_tensor_mut<T: IntoTensorElementType>(&mut self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
pub fn try_extract_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
let dtype = self.dtype()?;
match dtype {
ValueType::Tensor { ty, dimensions } => {
let device = self.memory_info()?.allocation_device()?;
if !device.is_cpu_accessible() {
return Err(Error::TensorNotOnCpu(device.as_str()));
}
let res = {
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
let data_type: TensorElementType = type_sys.into();
if data_type == T::into_tensor_element_type() {
let mut num_dims = 0;
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::<Vec<_>>());
let mut len = 0;
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
Ok(extract_primitive_array_mut(shape, self.ptr())?)
} else {
Err(Error::DataTypeMismatch {
actual: data_type,
requested: T::into_tensor_element_type()
})
if ty == T::into_tensor_element_type() {
Ok(extract_primitive_array_mut(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>()), self.ptr())?)
} else {
Err(Error::DataTypeMismatch {
actual: ty,
requested: T::into_tensor_element_type()
})
}
}
};
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
res
t => Err(Error::NotTensor(t))
}
}
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an
@@ -159,40 +190,32 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the
/// infallible [`Tensor::extract_raw_tensor`] instead)*
/// - The provided type `T` does not match the tensor's element type.
pub fn try_extract_raw_tensor<T: IntoTensorElementType>(&self) -> Result<(Vec<i64>, &[T])> {
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
pub fn try_extract_raw_tensor<T: PrimitiveTensorElementType>(&self) -> Result<(Vec<i64>, &[T])> {
let dtype = self.dtype()?;
match dtype {
ValueType::Tensor { ty, dimensions } => {
let device = self.memory_info()?.allocation_device()?;
if !device.is_cpu_accessible() {
return Err(Error::TensorNotOnCpu(device.as_str()));
}
let res = {
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
let data_type: TensorElementType = type_sys.into();
if data_type == T::into_tensor_element_type() {
let mut num_dims = 0;
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
if ty == T::into_tensor_element_type() {
let mut output_array_ptr: *mut T = ptr::null_mut();
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
let mut output_array_ptr: *mut T = ptr::null_mut();
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
let mut len = 0;
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
Ok((node_dims, unsafe { std::slice::from_raw_parts(output_array_ptr, len as _) }))
} else {
Err(Error::DataTypeMismatch {
actual: data_type,
requested: T::into_tensor_element_type()
})
let len = calculate_tensor_size(&dimensions);
Ok((dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) }))
} else {
Err(Error::DataTypeMismatch {
actual: ty,
requested: T::into_tensor_element_type()
})
}
}
};
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
res
t => Err(Error::NotTensor(t))
}
}
/// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a
@@ -218,50 +241,41 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the
/// infallible [`Tensor::extract_raw_tensor_mut`] instead)*
/// - The provided type `T` does not match the tensor's element type.
pub fn try_extract_raw_tensor_mut<T: IntoTensorElementType>(&mut self) -> Result<(Vec<i64>, &mut [T])> {
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
pub fn try_extract_raw_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<(Vec<i64>, &mut [T])> {
let dtype = self.dtype()?;
match dtype {
ValueType::Tensor { ty, dimensions } => {
let device = self.memory_info()?.allocation_device()?;
if !device.is_cpu_accessible() {
return Err(Error::TensorNotOnCpu(device.as_str()));
}
let res = {
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
let data_type: TensorElementType = type_sys.into();
if data_type == T::into_tensor_element_type() {
let mut num_dims = 0;
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
if ty == T::into_tensor_element_type() {
let mut output_array_ptr: *mut T = ptr::null_mut();
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
let mut output_array_ptr: *mut T = ptr::null_mut();
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
let mut len = 0;
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
Ok((node_dims, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len as _) }))
} else {
Err(Error::DataTypeMismatch {
actual: data_type,
requested: T::into_tensor_element_type()
})
let len = calculate_tensor_size(&dimensions);
Ok((dimensions, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len) }))
} else {
Err(Error::DataTypeMismatch {
actual: ty,
requested: T::into_tensor_element_type()
})
}
}
};
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
res
t => Err(Error::NotTensor(t))
}
}
/// Attempt to extract the underlying data into a Rust `ndarray`.
///
/// ```
/// # use ort::{Allocator, Session, DynTensor, TensorElementType};
/// # use ort::{Session, Tensor, TensorElementType};
/// # fn main() -> ort::Result<()> {
/// # let allocator = Allocator::default();
/// let array = ndarray::Array1::from_vec(vec!["hello", "world"]);
/// let tensor = DynTensor::from_string_array(&allocator, array.clone())?;
/// let tensor = Tensor::from_string_array(array.clone())?;
///
/// let extracted = tensor.try_extract_string_tensor()?;
/// assert_eq!(array.into_dyn(), extracted);
@@ -271,78 +285,68 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub fn try_extract_string_tensor(&self) -> Result<ndarray::ArrayD<String>> {
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
let dtype = self.dtype()?;
match dtype {
ValueType::Tensor { ty, dimensions } => {
let device = self.memory_info()?.allocation_device()?;
if !device.is_cpu_accessible() {
return Err(Error::TensorNotOnCpu(device.as_str()));
}
let res = {
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
let data_type: TensorElementType = type_sys.into();
if data_type == TensorElementType::String {
let mut num_dims = 0;
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
if ty == TensorElementType::String {
let len = calculate_tensor_size(&dimensions);
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::<Vec<_>>());
// Total length of string data, not including \0 suffix
let mut total_length: ort_sys::size_t = 0;
ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength];
let mut len: ort_sys::size_t = 0;
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
// don't seem to be written to in practice either.
// If the string data actually did go farther, it would panic below when using the offset
// data to get slices for each string.
let mut string_contents = vec![0u8; total_length as _];
// one extra slot so that the total length can go in the last one, making all per-string
// length calculations easy
let mut offsets = vec![0; (len + 1) as _];
// Total length of string data, not including \0 suffix
let mut total_length: ort_sys::size_t = 0;
ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength];
ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent];
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
// don't seem to be written to in practice either.
// If the string data actually did go farther, it would panic below when using the offset
// data to get slices for each string.
let mut string_contents = vec![0u8; total_length as _];
// one extra slot so that the total length can go in the last one, making all per-string
// length calculations easy
let mut offsets = vec![0; (len + 1) as _];
// final offset = overall length so that per-string length calculations work for the last string
debug_assert_eq!(0, offsets[len]);
offsets[len] = total_length;
ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent];
let strings = offsets
// offsets has 1 extra offset past the end so that all windows work
.windows(2)
.map(|w| {
let slice = &string_contents[w[0] as _..w[1] as _];
String::from_utf8(slice.into())
})
.collect::<Result<Vec<String>, FromUtf8Error>>()
.map_err(Error::StringFromUtf8Error)?;
// final offset = overall length so that per-string length calculations work for the last string
debug_assert_eq!(0, offsets[len as usize]);
offsets[len as usize] = total_length;
let strings = offsets
// offsets has 1 extra offset past the end so that all windows work
.windows(2)
.map(|w| {
let slice = &string_contents[w[0] as _..w[1] as _];
String::from_utf8(slice.into())
Ok(ndarray::Array::from_shape_vec(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<_>>()), strings)
.expect("Shape extracted from tensor didn't match tensor contents"))
} else {
Err(Error::DataTypeMismatch {
actual: ty,
requested: TensorElementType::String
})
.collect::<Result<Vec<String>, FromUtf8Error>>()
.map_err(Error::StringFromUtf8Error)?;
Ok(ndarray::Array::from_shape_vec(shape, strings)
.expect("Shape extracted from tensor didn't match tensor contents")
.into_dyn())
} else {
Err(Error::DataTypeMismatch {
actual: data_type,
requested: TensorElementType::String
})
}
}
};
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
res
t => Err(Error::NotTensor(t))
}
}
/// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's dimensions and
/// an owned `Vec` of its data.
///
/// ```
/// # use ort::{Allocator, Session, DynTensor, TensorElementType};
/// # use ort::{Session, Tensor, TensorElementType};
/// # fn main() -> ort::Result<()> {
/// # let allocator = Allocator::default();
/// let array = vec!["hello", "world"];
/// let tensor = DynTensor::from_string_array(&allocator, ([array.len()], array.clone().into_boxed_slice()))?;
/// let tensor = Tensor::from_string_array(([array.len()], array.clone().into_boxed_slice()))?;
///
/// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?;
/// assert_eq!(extracted_data, array);
@@ -351,68 +355,57 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// # }
/// ```
pub fn try_extract_raw_string_tensor(&self) -> Result<(Vec<i64>, Vec<String>)> {
let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape];
let dtype = self.dtype()?;
match dtype {
ValueType::Tensor { ty, dimensions } => {
let device = self.memory_info()?.allocation_device()?;
if !device.is_cpu_accessible() {
return Err(Error::TensorNotOnCpu(device.as_str()));
}
let res = {
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType];
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
let data_type: TensorElementType = type_sys.into();
if data_type == TensorElementType::String {
let mut num_dims = 0;
ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount];
if ty == TensorElementType::String {
let len = calculate_tensor_size(&dimensions);
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];
// Total length of string data, not including \0 suffix
let mut total_length: ort_sys::size_t = 0;
ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength];
let mut output_array_ptr: *mut c_char = ptr::null_mut();
let output_array_ptr_ptr: *mut *mut c_char = &mut output_array_ptr;
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)];
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
// don't seem to be written to in practice either.
// If the string data actually did go farther, it would panic below when using the offset
// data to get slices for each string.
let mut string_contents = vec![0u8; total_length as _];
// one extra slot so that the total length can go in the last one, making all per-string
// length calculations easy
let mut offsets = vec![0; (len + 1) as _];
let mut len: ort_sys::size_t = 0;
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount];
// Total length of string data, not including \0 suffix
let mut total_length = 0;
ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength];
ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent];
// In the JNI impl of this, tensor_element_len was included in addition to total_length,
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
// don't seem to be written to in practice either.
// If the string data actually did go farther, it would panic below when using the offset
// data to get slices for each string.
let mut string_contents = vec![0u8; total_length as _];
// one extra slot so that the total length can go in the last one, making all per-string
// length calculations easy
let mut offsets = vec![0; len as usize + 1];
// final offset = overall length so that per-string length calculations work for the last string
debug_assert_eq!(0, offsets[len]);
offsets[len] = total_length;
ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length as _, offsets.as_mut_ptr(), len as _) -> Error::GetStringTensorContent];
let strings = offsets
// offsets has 1 extra offset past the end so that all windows work
.windows(2)
.map(|w| {
let slice = &string_contents[w[0] as _..w[1] as _];
String::from_utf8(slice.into())
})
.collect::<Result<Vec<String>, FromUtf8Error>>()
.map_err(Error::StringFromUtf8Error)?;
// final offset = overall length so that per-string length calculations work for the last string
debug_assert_eq!(0, offsets[len as usize]);
offsets[len as usize] = total_length;
let strings = offsets
// offsets has 1 extra offset past the end so that all windows work
.windows(2)
.map(|w| {
let slice = &string_contents[w[0] as _..w[1] as _];
String::from_utf8(slice.into())
Ok((dimensions, strings))
} else {
Err(Error::DataTypeMismatch {
actual: ty,
requested: TensorElementType::String
})
.collect::<Result<Vec<String>, FromUtf8Error>>()
.map_err(Error::StringFromUtf8Error)?;
Ok((node_dims, strings))
} else {
Err(Error::DataTypeMismatch {
actual: data_type,
requested: TensorElementType::String
})
}
}
};
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)];
res
t => Err(Error::NotTensor(t))
}
}
/// Returns the shape of the tensor.
@@ -445,7 +438,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
}
}
impl<T: IntoTensorElementType + Debug> Tensor<T> {
impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
/// Extracts the underlying data into a read-only [`ndarray::ArrayView`].
///
/// ```

View File

@@ -11,41 +11,96 @@ use std::{
use super::{DowncastableTarget, Value, ValueInner, ValueTypeMarker};
use crate::{ortsys, DynValue, Error, IntoTensorElementType, MemoryInfo, Result, ValueRef, ValueRefMut, ValueType};
pub trait TensorValueTypeMarker: ValueTypeMarker {}
pub trait TensorValueTypeMarker: ValueTypeMarker {
crate::private_trait!();
}
#[derive(Debug)]
pub struct DynTensorValueType;
impl ValueTypeMarker for DynTensorValueType {}
impl TensorValueTypeMarker for DynTensorValueType {}
impl ValueTypeMarker for DynTensorValueType {
crate::private_impl!();
}
impl TensorValueTypeMarker for DynTensorValueType {
crate::private_impl!();
}
#[derive(Debug)]
pub struct TensorValueType<T: IntoTensorElementType + Debug>(PhantomData<T>);
impl<T: IntoTensorElementType + Debug> ValueTypeMarker for TensorValueType<T> {}
impl<T: IntoTensorElementType + Debug> TensorValueTypeMarker for TensorValueType<T> {}
impl<T: IntoTensorElementType + Debug> ValueTypeMarker for TensorValueType<T> {
crate::private_impl!();
}
impl<T: IntoTensorElementType + Debug> TensorValueTypeMarker for TensorValueType<T> {
crate::private_impl!();
}
/// A tensor [`Value`] whose data type is unknown.
pub type DynTensor = Value<DynTensorValueType>;
/// A strongly-typed tensor [`Value`].
pub type Tensor<T> = Value<TensorValueType<T>>;
/// A reference to a tensor [`Value`] whose data type is unknown.
pub type DynTensorRef<'v> = ValueRef<'v, DynTensorValueType>;
/// A mutable reference to a tensor [`Value`] whose data type is unknown.
pub type DynTensorRefMut<'v> = ValueRefMut<'v, DynTensorValueType>;
/// A reference to a strongly-typed tensor [`Value`].
pub type TensorRef<'v, T> = ValueRef<'v, TensorValueType<T>>;
/// A mutable reference to a strongly-typed tensor [`Value`].
pub type TensorRefMut<'v, T> = ValueRefMut<'v, TensorValueType<T>>;
impl DowncastableTarget for DynTensorValueType {
	// A `DynTensor` can hold a tensor of any element type, so any `ValueType::Tensor`
	// (regardless of its `ty`/`dimensions`) is an acceptable downcast source.
	fn can_downcast(dtype: &ValueType) -> bool {
		matches!(dtype, ValueType::Tensor { .. })
	}
	crate::private_impl!();
}
impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// Returns a mutable pointer to the tensor's data.
///
/// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a
/// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be
/// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access.
/// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before
/// accessing it.
///
/// ```
/// # use ort::{Allocator, Tensor};
/// # fn main() -> ort::Result<()> {
/// let mut tensor = Tensor::<i64>::from_array((vec![5], vec![0, 1, 2, 3, 4]))?;
/// let ptr = tensor.data_ptr_mut()?.cast::<i64>();
/// unsafe {
/// *ptr.add(3) = 42;
/// };
///
/// let (_, extracted) = tensor.extract_raw_tensor();
/// assert_eq!(&extracted, &[0, 1, 2, 42, 4]);
/// # Ok(())
/// # }
/// ```
pub fn data_ptr_mut(&mut self) -> Result<*mut ort_sys::c_void> {
let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut();
ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)];
Ok(buffer_ptr)
}
/// Returns a pointer to the tensor's data.
/// Returns an immutable pointer to the tensor's underlying data.
///
/// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a
/// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be
/// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access.
/// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before
/// accessing it.
///
/// ```
/// # use ort::{Allocator, Tensor};
/// # fn main() -> ort::Result<()> {
/// let tensor = Tensor::<i64>::from_array((vec![5], vec![0, 1, 2, 3, 4]))?;
/// let ptr = tensor.data_ptr()?.cast::<i64>();
/// assert_eq!(unsafe { *ptr.add(3) }, 3);
/// # Ok(())
/// # }
/// ```
pub fn data_ptr(&self) -> Result<*const ort_sys::c_void> {
let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut();
ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)];
@@ -53,6 +108,26 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
}
/// Returns information about the device this tensor is allocated on.
///
/// ```
/// # use ort::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType, Session, Tensor};
/// # fn main() -> ort::Result<()> {
/// let tensor = Tensor::<f32>::new(&Allocator::default(), [1, 3, 224, 224])?;
/// // Tensors are allocated on CPU by default.
/// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CPU);
///
/// # if false {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let cuda_allocator = Allocator::new(
/// &session,
/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?
/// )?;
/// let tensor = Tensor::<f32>::new(&cuda_allocator, [1, 3, 224, 224])?;
/// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CUDA);
/// # }
/// # Ok(())
/// # }
/// ```
pub fn memory_info(&self) -> Result<MemoryInfo> {
let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut();
ortsys![unsafe GetTensorMemoryInfo(self.ptr(), &mut memory_info_ptr) -> Error::GetTensorMemoryInfo; nonNull(memory_info_ptr)];
@@ -62,29 +137,68 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
impl<T: IntoTensorElementType + Debug> Tensor<T> {
/// Converts from a strongly-typed [`Tensor<T>`] to a type-erased [`DynTensor`].
///
/// ```
/// # use ort::{Allocator, DynTensor, Tensor};
/// # fn main() -> ort::Result<()> {
/// let tensor = Tensor::<f32>::new(&Allocator::default(), [1, 3, 224, 224])?;
/// let tensor_dyn = tensor.upcast();
/// assert!(tensor_dyn.try_extract_raw_tensor::<f32>().is_ok());
/// assert!(tensor_dyn.try_extract_raw_tensor::<i64>().is_err());
/// # Ok(())
/// # }
/// ```
#[inline]
pub fn upcast(self) -> DynTensor {
unsafe { std::mem::transmute(self) }
}
/// Converts from a strongly-typed [`Tensor<T>`] to a reference to a type-erased [`DynTensor`].
/// Creates a type-erased [`DynTensorRef`] from a strongly-typed [`Tensor<T>`].
///
/// ```
/// # use ort::{Allocator, DynTensor, Tensor};
/// # fn main() -> ort::Result<()> {
/// let tensor = Tensor::<f32>::new(&Allocator::default(), [1, 3, 224, 224])?;
/// let tensor_dyn = tensor.upcast_ref();
///
/// let (_, original_extract) = tensor.extract_raw_tensor();
/// let (_, ref_extract) = tensor_dyn.try_extract_raw_tensor::<f32>()?;
/// assert_eq!(original_extract, ref_extract);
/// # Ok(())
/// # }
/// ```
#[inline]
pub fn upcast_ref(&self) -> DynTensorRef {
DynTensorRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
})
}
/// Converts from a strongly-typed [`Tensor<T>`] to a mutable reference to a type-erased [`DynTensor`].
///
/// ```
/// # use ort::{Allocator, DynTensor, Tensor};
/// # fn main() -> ort::Result<()> {
/// let mut tensor = Tensor::<i64>::from_array((vec![5], vec![1, 2, 3, 4, 5]))?;
/// let mut tensor_dyn = tensor.upcast_mut();
///
/// let (_, mut_view) = tensor_dyn.try_extract_raw_tensor_mut::<i64>()?;
/// mut_view[3] = 0;
///
/// let (_, original_view) = tensor.extract_raw_tensor();
/// assert_eq!(original_view, &[1, 2, 3, 0, 5]);
/// # Ok(())
/// # }
/// ```
#[inline]
pub fn upcast_mut(&mut self) -> DynTensorRefMut {
DynTensorRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
})
}
@@ -97,6 +211,8 @@ impl<T: IntoTensorElementType + Debug> DowncastableTarget for TensorValueType<T>
_ => false
}
}
crate::private_impl!();
}
impl<T: IntoTensorElementType + Debug> From<Value<TensorValueType<T>>> for DynValue {
@@ -113,6 +229,17 @@ impl From<Value<DynTensorValueType>> for DynValue {
impl<T: IntoTensorElementType + Clone + Debug, const N: usize> Index<[i64; N]> for Tensor<T> {
type Output = T;
fn index(&self, index: [i64; N]) -> &Self::Output {
// Interestingly, the `TensorAt` API doesn't check if the tensor is on CPU, so we have to perform the check ourselves.
if !self
.memory_info()
.expect("could not retrieve tensor memory info")
.allocation_device()
.expect("could not retrieve tensor allocation device")
.is_cpu_accessible()
{
panic!("Cannot directly index a tensor which is not allocated on the CPU.");
}
let mut out: *mut ort_sys::c_void = std::ptr::null_mut();
ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")];
unsafe { &*out.cast::<T>() }
@@ -120,25 +247,46 @@ impl<T: IntoTensorElementType + Clone + Debug, const N: usize> Index<[i64; N]> f
}
impl<T: IntoTensorElementType + Clone + Debug, const N: usize> IndexMut<[i64; N]> for Tensor<T> {
	/// Returns a mutable reference to the element at `index`.
	///
	/// # Panics
	/// Panics if the tensor's data is not allocated in CPU-accessible memory, if the tensor's
	/// memory info or allocation device cannot be retrieved, or if `TensorAt` fails (e.g. the
	/// index is invalid for this tensor).
	fn index_mut(&mut self, index: [i64; N]) -> &mut Self::Output {
		// The `TensorAt` API doesn't check if the tensor is on CPU, so we have to perform the
		// check ourselves before dereferencing the returned pointer.
		if !self
			.memory_info()
			.expect("could not retrieve tensor memory info")
			.allocation_device()
			.expect("could not retrieve tensor allocation device")
			.is_cpu_accessible()
		{
			panic!("Cannot directly index a tensor which is not allocated on the CPU.");
		}
		let mut out: *mut ort_sys::c_void = std::ptr::null_mut();
		ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")];
		unsafe { &mut *out.cast::<T>() }
	}
}
/// Computes the total number of elements described by `shape` (the product of all
/// dimensions; an empty shape yields 1, the size of a scalar).
///
/// If any dimension is negative (i.e. a symbolic/unknown dimension), no concrete
/// element count exists, so 0 is returned.
pub(crate) fn calculate_tensor_size(shape: &[i64]) -> usize {
	if shape.iter().any(|&dim| dim < 0) {
		return 0;
	}
	shape.iter().map(|&dim| dim as usize).product()
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use ndarray::{ArcArray1, Array1, CowArray};
use crate::{Allocator, DynTensor, TensorElementType, Value, ValueType};
use crate::{Tensor, TensorElementType, ValueType};
#[test]
#[cfg(feature = "ndarray")]
fn test_tensor_value() -> crate::Result<()> {
let v: Vec<f32> = vec![1., 2., 3., 4., 5.];
let value = Value::from_array(Array1::from_vec(v.clone()))?;
let value = Tensor::from_array(Array1::from_vec(v.clone()))?;
assert!(value.is_tensor()?);
assert_eq!(value.dtype()?.tensor_type(), Some(TensorElementType::Float32));
assert_eq!(
@@ -163,17 +311,17 @@ mod tests {
let arc1 = ArcArray1::from_vec(v.clone());
let mut arc2 = ArcArray1::clone(&arc1);
let value = Value::from_array(&mut arc2)?;
let value = Tensor::from_array(&mut arc2)?;
drop((arc1, arc2));
assert_eq!(value.extract_raw_tensor().1, &v);
let cow = CowArray::from(Array1::from_vec(v.clone()));
let value = Value::from_array(&cow)?;
let value = Tensor::from_array(&cow)?;
assert_eq!(value.extract_raw_tensor().1, &v);
let owned = Array1::from_vec(v.clone());
let value = Value::from_array(owned.view())?;
let value = Tensor::from_array(owned.view())?;
drop(owned);
assert_eq!(value.extract_raw_tensor().1, &v);
@@ -186,7 +334,7 @@ mod tests {
let arc = Arc::new(v.clone().into_boxed_slice());
let shape = vec![v.len() as i64];
let value = Value::from_array((shape, Arc::clone(&arc)))?;
let value = Tensor::from_array((shape, Arc::clone(&arc)))?;
drop(arc);
assert_eq!(value.try_extract_raw_tensor::<f32>()?.1, &v);
@@ -196,10 +344,9 @@ mod tests {
#[test]
#[cfg(feature = "ndarray")]
fn test_string_tensor_ndarray() -> crate::Result<()> {
let allocator = Allocator::default();
let v = Array1::from_vec(vec!["hello world".to_string(), "こんにちは世界".to_string()]);
let value = DynTensor::from_string_array(&allocator, v.view())?;
let value = Tensor::from_string_array(v.view())?;
let extracted = value.try_extract_string_tensor()?;
assert_eq!(extracted, v.into_dyn());
@@ -208,10 +355,9 @@ mod tests {
#[test]
fn test_string_tensor_raw() -> crate::Result<()> {
let allocator = Allocator::default();
let v = vec!["hello world".to_string(), "こんにちは世界".to_string()];
let value = DynTensor::from_string_array(&allocator, (vec![v.len() as i64], v.clone().into_boxed_slice()))?;
let value = Tensor::from_string_array((vec![v.len() as i64], v.clone().into_boxed_slice()))?;
let (extracted_shape, extracted_view) = value.try_extract_raw_string_tensor()?;
assert_eq!(extracted_shape, [v.len() as i64]);
assert_eq!(extracted_view, v);
@@ -224,10 +370,10 @@ mod tests {
let v: Vec<f32> = vec![1., 2., 3., 4., 5.];
let shape = [v.len()];
let value_arc_box = Value::from_array((shape, Arc::new(v.clone().into_boxed_slice())))?;
let value_box = Value::from_array((shape, v.clone().into_boxed_slice()))?;
let value_vec = Value::from_array((shape, v.clone()))?;
let value_slice = Value::from_array((shape, &v[..]))?;
let value_arc_box = Tensor::from_array((shape, Arc::new(v.clone().into_boxed_slice())))?;
let value_box = Tensor::from_array((shape, v.clone().into_boxed_slice()))?;
let value_vec = Tensor::from_array((shape, v.clone()))?;
let value_slice = Tensor::from_array((shape, &v[..]))?;
assert_eq!(value_arc_box.extract_raw_tensor().1, &v);
assert_eq!(value_box.extract_raw_tensor().1, &v);

View File

@@ -166,6 +166,14 @@ pub(crate) enum ValueInner {
}
}
impl ValueInner {
	/// Returns the raw [`ort_sys::OrtValue`] pointer held by this value, regardless of
	/// whether the value is Rust-owned or C++-owned.
	pub(crate) fn ptr(&self) -> *mut ort_sys::OrtValue {
		match self {
			ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr()
		}
	}
}
/// A temporary version of a [`Value`] with a lifetime specifier.
#[derive(Debug)]
pub struct ValueRef<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {
@@ -277,8 +285,8 @@ impl<'v, Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'v, Type> {
/// - [`Tensor::extract_tensor`], [`Tensor::extract_raw_tensor`]
#[derive(Debug)]
pub struct Value<Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {
inner: ValueInner,
_markers: PhantomData<Type>
pub(crate) inner: Arc<ValueInner>,
pub(crate) _markers: PhantomData<Type>
}
/// A dynamic value, which could be a [`Tensor`], [`Sequence`], or [`Map`].
@@ -291,11 +299,15 @@ pub type DynValue = Value<DynValueTypeMarker>;
///
/// For example, [`Tensor::try_extract_tensor`] can only be used on [`Value`]s with the [`TensorValueTypeMarker`] (which
/// inherits this trait), i.e. [`Tensor`]s, [`DynTensor`]s, and [`DynValue`]s.
pub trait ValueTypeMarker: Debug {}
pub trait ValueTypeMarker: Debug {
crate::private_trait!();
}
/// Represents a type that a [`DynValue`] can be downcast to.
pub trait DowncastableTarget: ValueTypeMarker {
fn can_downcast(dtype: &ValueType) -> bool;
crate::private_trait!();
}
// this implementation is used in case we want to extract `DynValue`s from a [`Sequence`]; see `try_extract_sequence`
@@ -303,15 +315,25 @@ impl DowncastableTarget for DynValueTypeMarker {
fn can_downcast(_: &ValueType) -> bool {
true
}
crate::private_impl!();
}
/// The dynamic type marker, used for values which can be of any type.
#[derive(Debug)]
pub struct DynValueTypeMarker;
impl ValueTypeMarker for DynValueTypeMarker {}
impl MapValueTypeMarker for DynValueTypeMarker {}
impl SequenceValueTypeMarker for DynValueTypeMarker {}
impl TensorValueTypeMarker for DynValueTypeMarker {}
impl ValueTypeMarker for DynValueTypeMarker {
crate::private_impl!();
}
impl MapValueTypeMarker for DynValueTypeMarker {
crate::private_impl!();
}
impl SequenceValueTypeMarker for DynValueTypeMarker {
crate::private_impl!();
}
impl TensorValueTypeMarker for DynValueTypeMarker {
crate::private_impl!();
}
unsafe impl Send for Value {}
@@ -350,7 +372,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
///
/// If the value belongs to a session (i.e. if it is returned from [`crate::Session::run`] or
/// [`crate::IoBinding::run`]), you must provide the [`SharedSessionInner`] (acquired from
/// [`crate::Session::inner`]). This ensures the session is not dropped until the value is.
/// [`crate::Session::inner`]). This ensures the session is not dropped until all values it owns are dropped.
///
/// # Safety
///
@@ -359,7 +381,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
#[must_use]
pub unsafe fn from_ptr(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
Value {
inner: ValueInner::CppOwned { ptr, drop: true, _session: session },
inner: Arc::new(ValueInner::CppOwned { ptr, drop: true, _session: session }),
_markers: PhantomData
}
}
@@ -369,16 +391,14 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
#[must_use]
pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
Value {
inner: ValueInner::CppOwned { ptr, drop: false, _session: session },
inner: Arc::new(ValueInner::CppOwned { ptr, drop: false, _session: session }),
_markers: PhantomData
}
}
/// Returns the underlying [`ort_sys::OrtValue`] pointer.
pub fn ptr(&self) -> *mut ort_sys::OrtValue {
match &self.inner {
ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr()
}
self.inner.ptr()
}
/// Create a view of this value's data.
@@ -386,7 +406,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
ValueRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
})
}
@@ -396,7 +416,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
ValueRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
})
}
@@ -442,7 +462,7 @@ impl Value<DynValueTypeMarker> {
Ok(ValueRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
}))
} else {
@@ -450,7 +470,7 @@ impl Value<DynValueTypeMarker> {
}
}
/// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed
/// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed
/// mutable-reference variant, like [`TensorRefMut<T>`].
#[inline]
pub fn downcast_mut<OtherType: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(&mut self) -> Result<ValueRefMut<'_, OtherType>> {
@@ -459,7 +479,7 @@ impl Value<DynValueTypeMarker> {
Ok(ValueRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None }
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
}))
} else {
@@ -468,17 +488,17 @@ impl Value<DynValueTypeMarker> {
}
}
impl<Type: ValueTypeMarker + ?Sized> Drop for Value<Type> {
impl Drop for ValueInner {
fn drop(&mut self) {
let ptr = self.ptr();
tracing::trace!(
"dropping {} value at {ptr:p}",
match &self.inner {
match self {
ValueInner::RustOwned { .. } => "rust-owned",
ValueInner::CppOwned { .. } => "cpp-owned"
}
);
if !matches!(&self.inner, ValueInner::CppOwned { drop: false, .. }) {
if !matches!(self, ValueInner::CppOwned { drop: false, .. }) {
ortsys![unsafe ReleaseValue(ptr)];
}
}

View File

@@ -1,6 +1,6 @@
//! Utilities for using `ort` in WebAssembly.
//!
//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs:
//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs in WASM:
//! ```
//! # use ort::Session;
//! # static MODEL_BYTES: &[u8] = include_bytes!("../tests/data/upsample.ort");
@@ -223,12 +223,12 @@ mod emscripten_shims {
#[no_mangle]
#[export_name = "_initialize"]
pub fn initialize() {
// No idea what the hell this does, but the presence of an `_initialize` function prevents the linker from calling
// `__wasm_call_ctors` at the top of every function - including the functions `wasm-bindgen` interprets to generate
// JS glue code. The `__wasm_call_ctors` call was calling complex functions that the interpreter isn't equipped to
// handle, which was preventing wbg from outputting anything. I don't know what specific constructors this is calling,
// and most basic ONNX Runtime APIs *do* work without calling this, but we encourage the user to perform this
// initialization at program start anyways to be safe.
// The presence of an `_initialize` function prevents the linker from calling `__wasm_call_ctors` at the top of every
// function - including the functions `wasm-bindgen` interprets to generate JS glue code. `__wasm_call_ctors` calls
// complex functions that wbg's interpreter isn't equipped to handle, which was preventing wbg from outputting
// anything.
// I'm not entirely sure what `__wasm_call_ctors` is initializing, but it seems to have something to do with C++
// vtables, and it's crucial for proper operation.
extern "C" {
fn __wasm_call_ctors();
}

View File

@@ -3,7 +3,7 @@
use std::path::Path;
use ndarray::{ArrayD, IxDyn};
use ort::{inputs, DynTensor, GraphOptimizationLevel, Session};
use ort::{inputs, GraphOptimizationLevel, Session, Tensor};
use test_log::test;
#[test]
@@ -22,7 +22,7 @@ fn vectorizer() -> ort::Result<()> {
let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap());
// Just one input
let input_tensor_values = inputs![DynTensor::from_string_array(session.allocator(), &array)?]?;
let input_tensor_values = inputs![Tensor::from_string_array(&array)?]?;
// Perform the inference
let outputs = session.run(input_tensor_values)?;