diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index 2d590f0..1206860 100644 --- a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -82,7 +82,7 @@ impl Kernel for CustomOpTwoKernel { fn main() -> ort::Result<()> { let session = Session::builder()? - .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)? + .with_operators(OperatorDomain::new("test.customop")?.add::()?.add::()?)? .commit_from_file("tests/data/custom_op_test.onnx")?; let values = session.run(ort::inputs![Array2::::zeros((3, 5)), Array2::::ones((3, 5))]?)?; diff --git a/src/environment.rs b/src/environment.rs index 343cdc9..810f9e9 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -34,7 +34,7 @@ pub struct Environment { } impl Environment { - /// Loads the underlying [`ort_sys::OrtEnv`] pointer. + /// Returns the underlying [`ort_sys::OrtEnv`] pointer. pub fn ptr(&self) -> *mut ort_sys::OrtEnv { self.env_ptr.load(Ordering::Relaxed) } @@ -52,13 +52,14 @@ impl Drop for Environment { } } -/// Gets a reference to the global environment, creating one if an environment has been committed yet. +/// Gets a reference to the global environment, creating one if an environment has not been +/// [`commit`](EnvironmentBuilder::commit)ted yet. pub fn get_environment() -> Result<&'static Arc> { if let Some(c) = unsafe { &*G_ENV.cell.get() } { Ok(c) } else { debug!("Environment not yet initialized, creating a new one"); - EnvironmentBuilder::default().commit()?; + EnvironmentBuilder::new().commit()?; Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() }) } @@ -72,7 +73,7 @@ pub struct EnvironmentGlobalThreadPoolOptions { pub intra_op_thread_affinity: Option } -/// Struct used to build an `Environment`. +/// Struct used to build an [`Environment`]; see [`crate::init`]. 
pub struct EnvironmentBuilder { name: String, telemetry: bool, @@ -80,8 +81,8 @@ pub struct EnvironmentBuilder { global_thread_pool_options: Option } -impl Default for EnvironmentBuilder { - fn default() -> Self { +impl EnvironmentBuilder { + pub(crate) fn new() -> Self { EnvironmentBuilder { name: "default".to_string(), telemetry: true, @@ -89,11 +90,9 @@ impl Default for EnvironmentBuilder { global_thread_pool_options: None } } -} -impl EnvironmentBuilder { /// Configure the environment with a given name for logging purposes. - #[must_use] + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_name(mut self, name: S) -> Self where S: Into @@ -102,7 +101,17 @@ impl EnvironmentBuilder { self } - #[must_use] + /// Enable or disable sending telemetry events to Microsoft. + /// + /// Typically, only Windows builds of ONNX Runtime provided by Microsoft will have telemetry enabled. + /// Pre-built binaries provided by pyke, or binaries compiled from source, won't have telemetry enabled. + /// + /// The exact kind of telemetry data sent can be found [here](https://github.com/microsoft/onnxruntime/blob/v1.18.0/onnxruntime/core/platform/windows/telemetry.cc). + /// Currently, this includes (but is not limited to): ONNX graph version, model producer name & version, whether or + /// not FP16 is used, operator domains & versions, model graph name & custom metadata, execution provider names, + /// error messages, and the total number & time of session inference runs. The ONNX Runtime team uses this data to + /// better understand how customers use ONNX Runtime and where performance can be improved. 
+ #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_telemetry(mut self, enable: bool) -> Self { self.telemetry = enable; self @@ -116,14 +125,14 @@ impl EnvironmentBuilder { /// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built /// with support for the corresponding execution provider. Execution providers that do not have their corresponding /// feature enabled will emit a warning. - #[must_use] + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Self { self.execution_providers = execution_providers.as_ref().to_vec(); self } /// Enables the global thread pool for this environment. - #[must_use] + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> Self { self.global_thread_pool_options = Some(options); self @@ -158,14 +167,17 @@ impl EnvironmentBuilder { ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment]; } - ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( + ortsys![ + unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( logging_function, logger_param, ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), thread_options, &mut env_ptr - ) -> Error::CreateEnvironment; nonNull(env_ptr)]; + ) -> Error::CreateEnvironment; + nonNull(env_ptr) + ]; ortsys![unsafe ReleaseThreadingOptions(thread_options)]; (env_ptr, true) } else { @@ -174,13 +186,16 @@ impl EnvironmentBuilder { // FIXME: What should go here? 
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!()); - ortsys![unsafe CreateEnvWithCustomLogger( + ortsys![ + unsafe CreateEnvWithCustomLogger( logging_function, logger_param, ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), &mut env_ptr - ) -> Error::CreateEnvironment; nonNull(env_ptr)]; + ) -> Error::CreateEnvironment; + nonNull(env_ptr) + ]; (env_ptr, false) }; debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created"); @@ -205,15 +220,25 @@ impl EnvironmentBuilder { /// Creates an ONNX Runtime environment. /// +/// ``` +/// # use ort::CUDAExecutionProvider; +/// # fn main() -> ort::Result<()> { +/// ort::init() +/// .with_execution_providers([CUDAExecutionProvider::default().build()]) +/// .commit()?; +/// # Ok(()) +/// # } +/// ``` +/// /// # Notes /// - It is not required to call this function. If this is not called by the time any other `ort` APIs are used, a /// default environment will be created. -/// - Library crates that use `ort` shouldn't create their own environment. Let downstream applications create it. +/// - **Library crates that use `ort` shouldn't create their own environment.** Let downstream applications create it. /// - In order for environment settings to apply, this must be called **before** you use other APIs like /// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function. -#[must_use] +#[must_use = "commit() must be called in order for the environment to take effect"] pub fn init() -> EnvironmentBuilder { - EnvironmentBuilder::default() + EnvironmentBuilder::new() } /// Creates an ONNX Runtime environment, dynamically loading ONNX Runtime from the library file (`.dll`/`.so`/`.dylib`) @@ -221,15 +246,26 @@ pub fn init() -> EnvironmentBuilder { /// /// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded. 
/// +/// ```no_run +/// # use ort::CUDAExecutionProvider; +/// # fn main() -> ort::Result<()> { +/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib"); +/// ort::init_from(lib_path.join("onnxruntime.dll")) +/// .with_execution_providers([CUDAExecutionProvider::default().build()]) +/// .commit()?; +/// # Ok(()) +/// # } +/// ``` +/// /// # Notes /// - In order for environment settings to apply, this must be called **before** you use other APIs like /// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function. #[cfg(feature = "load-dynamic")] #[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))] -#[must_use] +#[must_use = "commit() must be called in order for the environment to take effect"] pub fn init_from(path: impl ToString) -> EnvironmentBuilder { let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string())); - EnvironmentBuilder::default() + EnvironmentBuilder::new() } /// ONNX's logger sends the code location where the log occurred, which will be parsed into this struct. 
@@ -325,7 +361,7 @@ mod tests { assert!(!is_env_initialized()); assert_eq!(env_ptr(), None); - EnvironmentBuilder::default().with_name("env_is_initialized").commit()?; + EnvironmentBuilder::new().with_name("env_is_initialized").commit()?; assert!(is_env_initialized()); assert_ne!(env_ptr(), None); Ok(()) diff --git a/src/error.rs b/src/error.rs index 7bdb2ba..fb25f20 100644 --- a/src/error.rs +++ b/src/error.rs @@ -121,9 +121,6 @@ pub enum Error { /// Error occurred when filling a tensor with string data #[error("Failed to fill string tensor: {0}")] FillStringTensor(ErrorInternal), - /// Error occurred when checking if a value is a tensor - #[error("Failed to check if value is a tensor: {0}")] - FailedTensorCheck(ErrorInternal), /// Error occurred when getting tensor type and shape #[error("Failed to get tensor type and shape: {0}")] GetTensorTypeAndShape(ErrorInternal), @@ -159,12 +156,6 @@ pub enum Error { /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models). #[error("Failed to download ONNX model: {0}")] DownloadError(#[from] FetchModelError), - /// Type of input data and the ONNX model do not match. - #[error("Data types do not match: expected {model:?}, got {input:?}")] - NonMatchingDataTypes { input: TensorElementType, model: TensorElementType }, - /// Dimensions of input data and the ONNX model do not match. - #[error("Dimensions do not match: {0:?}")] - NonMatchingDimensions(NonMatchingDimensionsError), /// File does not exist #[error("File `{filename:?}` does not exist")] FileDoesNotExist { @@ -186,9 +177,6 @@ pub enum Error { /// ORT pointer should not have been null #[error("`{0}` should not be a null pointer")] PointerShouldNotBeNull(&'static str), - /// The runtime type was undefined. - #[error("Undefined tensor element type")] - UndefinedTensorElementType, /// Could not retrieve model metadata. 
#[error("Failed to retrieve model metadata: {0}")] GetModelMetadata(ErrorInternal), @@ -208,8 +196,8 @@ pub enum Error { ExecutionProviderNotRegistered(&'static str), #[error("Expected tensor to be on CPU in order to get data, but had allocation device `{0}`.")] TensorNotOnCpu(&'static str), - #[error("String tensors require the session's allocator to be provided through `Value::from_array`.")] - StringTensorRequiresAllocator, + #[error("Cannot extract scalar value from a {0}-dimensional tensor")] + TensorNot0Dimensional(usize), #[error("Failed to create memory info: {0}")] CreateMemoryInfo(ErrorInternal), #[error("Could not get allocation device from `MemoryInfo`: {0}")] @@ -222,10 +210,10 @@ pub enum Error { BindInput(ErrorInternal), #[error("Error when binding output: {0}")] BindOutput(ErrorInternal), - #[error("Failed to clear IO binding: {0}")] - ClearBinding(ErrorInternal), #[error("Error when retrieving session outputs from `IoBinding`: {0}")] GetBoundOutputs(ErrorInternal), + #[error("Cannot use `extract_tensor` on a value that is {0:?}")] + NotTensor(ValueType), #[error("Cannot use `extract_sequence` on a value that is {0:?}")] NotSequence(ValueType), #[error("Cannot use `extract_map` on a value that is {0:?}")] @@ -252,6 +240,8 @@ pub enum Error { GetOperatorInput(ErrorInternal), #[error("Failed to get operator output: {0}")] GetOperatorOutput(ErrorInternal), + #[error("Failed to retrieve GPU compute stream from kernel context: {0}")] + GetOperatorGPUComputeStream(ErrorInternal), #[error("{0}")] CustomError(#[from] Box), #[error("String tensors cannot be borrowed as mutable")] @@ -266,37 +256,20 @@ pub enum Error { GetDeviceId(ErrorInternal) } -impl From for Error { - fn from(_: Infallible) -> Self { - Error::Infallible +impl Error { + /// Wrap a custom, user-provided error in an [`ort::Error`](Error). The resulting error will be the + /// [`Error::CustomError`] variant. + /// + /// This can be used to return custom errors from e.g. 
training dataloaders or custom operators if a non-`ort` + /// related operation fails. + pub fn wrap(err: T) -> Self { + Error::CustomError(Box::new(err) as Box) } } -/// Error used when the input dimensions defined in the model and passed from an inference call do not match. -#[non_exhaustive] -#[derive(Error, Debug)] -pub enum NonMatchingDimensionsError { - /// Number of inputs from model does not match the number of inputs from inference call. - #[error( - "Non-matching number of inputs: {inference_input_count:?} provided vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})" - )] - InputsCount { - /// Number of input dimensions used by inference call - inference_input_count: usize, - /// Number of input dimensions defined in model - model_input_count: usize, - /// Input dimensions used by inference call - inference_input: Vec>, - /// Input dimensions defined in model - model_input: Vec>> - }, - /// Inputs length from model does not match the expected input from inference call - #[error("Different input lengths; expected input: {model_input:?}, received input: {inference_input:?}")] - InputsLength { - /// Input dimensions used by inference call - inference_input: Vec>, - /// Input dimensions defined in model - model_input: Vec>> +impl From for Error { + fn from(_: Infallible) -> Self { + Error::Infallible } } diff --git a/src/io_binding.rs b/src/io_binding.rs index a94f291..36b11e6 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -1,6 +1,8 @@ use std::{ + collections::HashMap, ffi::CString, fmt::Debug, + marker::PhantomData, ptr::{self, NonNull}, sync::Arc }; @@ -9,24 +11,87 @@ use crate::{ memory::MemoryInfo, ortsys, session::{output::SessionOutputs, RunOptions}, - value::{Value, ValueRefMut}, - Error, Result, Session, ValueTypeMarker + value::{Value, ValueInner}, + DynValue, Error, Result, Session, ValueTypeMarker }; /// Enables binding of session inputs and/or outputs to pre-allocated memory. 
/// -/// Note that this arrangement is designed to minimize data copies, and to that effect, your memory allocations must -/// match what is expected by the model, whether you run on CPU or GPU. Data will still be copied if the -/// pre-allocated memory location does not match the one expected by the model. However, copies with `IoBinding`s are -/// only done once, at the time of the binding, not at run time. This means, that if your input data required a copy, -/// your further input modifications would not be seen by ONNX Runtime unless you rebind it, even if it is the same -/// buffer. If your scenario requires that the data is copied, `IoBinding` may not be the best match for your use case. -/// The fact that data copy is not made during runtime may also have performance implications. +/// [`IoBinding`] minimizes copies between a device (like a GPU) and the host (CPU) by allowing the user to bind a +/// certain input/output to a pre-allocated value on a specific device. +/// +/// [`IoBinding`] is most suitable for: +/// - An ensemble of models in which the output from one model is the input of another and does not need to pass through +/// the CPU to perform additional processing. +/// - Situations where the output should stay on a device (e.g. to perform additional processing with CUDA). +/// - Diffusion models, for instance, that accept an unchanging embedding for conditioning. +/// +/// [`IoBinding`] will not provide any meaningful benefit for: +/// - Models where every input changes with each invocation, such as a causal language model or object recognition +/// model. +/// - Pipelines that go straight from CPU -> GPU -> CPU. +/// +/// # Example +/// A diffusion model which takes a text condition input. +/// +/// ```no_run +/// # use ort::{Allocator, AllocatorType, AllocationDevice, CUDAExecutionProvider, MemoryInfo, MemoryType, Session, Tensor, IoBinding}; +/// # fn main() -> ort::Result<()> { +/// let text_encoder = Session::builder()? 
+/// .with_execution_providers([CUDAExecutionProvider::default().build()])? +/// .commit_from_file("text_encoder.onnx")?; +/// let unet = Session::builder()? +/// .with_execution_providers([CUDAExecutionProvider::default().build()])? +/// .commit_from_file("unet.onnx")?; +/// +/// let text_condition = text_encoder +/// .run(ort::inputs![Tensor::::from_array(( +/// vec![27], +/// vec![ +/// 23763, 15460, 473, 68, 312, 265, 17463, 4098, 304, 1077, 283, 198, 7676, 5976, 272, 285, 3609, 435, +/// 21680, 321, 265, 300, 1689, 64, 285, 4763, 64 +/// ] +/// ))?]?)? +/// .remove("output0") +/// .unwrap(); +/// +/// let input_allocator = Allocator::new( +/// &unet, +/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? +/// )?; +/// let mut latents = Tensor::::new(&input_allocator, [1, 4, 64, 64])?; +/// +/// let mut io_binding = unet.create_binding()?; +/// io_binding.bind_input("condition", &text_condition)?; +/// +/// let output_allocator = Allocator::new( +/// &unet, +/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUOutput)? +/// )?; +/// io_binding.bind_output("noise_pred", Tensor::::new(&output_allocator, [1, 4, 64, 64])?)?; +/// +/// for _ in 0..20 { +/// io_binding.bind_input("latents", &latents)?; +/// let noise_pred = io_binding.run()?.remove("noise_pred").unwrap(); +/// +/// let mut latents = latents.extract_tensor_mut(); +/// latents += &noise_pred.try_extract_tensor::()?; +/// } +/// # Ok(()) +/// # } +/// ``` +/// +/// [`IoBinding`] may provide a decent speedup in this example since the `condition` tensor is unchanging between runs. +/// If we were to use normal session inference, the `condition` tensor would be needlessly copied with each invocation +/// of `unet.run()`, and this copying can come with significant latency & overhead. With [`IoBinding`], the `condition` +/// tensor is only copied to the device once instead of 20 times. 
#[derive(Debug)] pub struct IoBinding<'s> { pub(crate) ptr: NonNull, session: &'s Session, - output_names: Vec + held_inputs: HashMap>, + output_names: Vec, + output_values: HashMap } impl<'s> IoBinding<'s> { @@ -36,25 +101,47 @@ impl<'s> IoBinding<'s> { Ok(Self { ptr: unsafe { NonNull::new_unchecked(ptr) }, session, - output_names: Vec::new() + held_inputs: HashMap::new(), + output_names: Vec::new(), + output_values: HashMap::new() }) } /// Bind a [`Value`] to a session input. - pub fn bind_input<'i: 's, T: ValueTypeMarker, S: AsRef>(&mut self, name: S, ort_value: &'i mut Value) -> Result> { + /// + /// Upon invocation, the value's data will be copied to the device the session is allocated on. The copied data will + /// be used as an input (specified by `name`) in all future invocations of [`IoBinding::run`] until the input is + /// overridden (by calling [`IoBinding::bind_input`] again) or until all inputs are cleared (via + /// [`IoBinding::clear_inputs`] or [`IoBinding::clear`]). + /// + /// The data is only copied **once**, immediately upon invocation of this function. Any changes to the given + /// value afterwards will not affect the data seen by the session until the value is re-bound. Subsequent re-binds + /// will still copy data, hence why [`IoBinding`] is really only suitable when one or more inputs do not change + /// between runs. + pub fn bind_input>(&mut self, name: S, ort_value: &Value) -> Result<()> { let name = name.as_ref(); let cname = CString::new(name)?; ortsys![unsafe BindInput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindInput]; - Ok(ort_value.view_mut()) + self.held_inputs.insert(name.to_string(), Arc::clone(&ort_value.inner)); + Ok(()) } /// Bind a session output to a pre-allocated [`Value`]. 
- pub fn bind_output<'o: 's, T: ValueTypeMarker, S: AsRef>(&mut self, name: S, ort_value: &'o mut Value) -> Result> { + /// + /// This allows for the pre-allocation and reuse of memory in the session output (see [`crate::Tensor::new`]). Any + /// subsequent runs via [`IoBinding::run`] will reuse the same tensor to store the output instead of creating a new + /// one each time. + /// + /// The output will be accessible in the value returned by [`IoBinding::run`], under the name specified by `name`. + pub fn bind_output>(&mut self, name: S, ort_value: Value) -> Result<()> { let name = name.as_ref(); let cname = CString::new(name)?; ortsys![unsafe BindOutput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindOutput]; self.output_names.push(name.to_string()); - Ok(ort_value.view_mut()) + // Clear the old bound output if we have any. + drop(self.output_values.remove(name)); + self.output_values.insert(name.to_string(), ort_value.into_dyn()); + Ok(()) } /// Bind a session output to a device which is specified by `mem_info`. @@ -66,15 +153,35 @@ impl<'s> IoBinding<'s> { Ok(()) } - pub fn run<'i: 's>(&'i self) -> Result> { + /// Clears all bound inputs specified by [`IoBinding::bind_input`]. + pub fn clear_inputs(&mut self) { + ortsys![unsafe ClearBoundInputs(self.ptr.as_ptr())]; + drop(self.held_inputs.drain()); + } + /// Clears all bound outputs specified by [`IoBinding::bind_output`] or [`IoBinding::bind_output_to_device`]. + pub fn clear_outputs(&mut self) { + ortsys![unsafe ClearBoundOutputs(self.ptr.as_ptr())]; + drop(self.output_names.drain(..)); + drop(self.output_values.drain()); + } + /// Clears both the bound inputs & outputs; equivalent to [`IoBinding::clear_inputs`] followed by + /// [`IoBinding::clear_outputs`]. + pub fn clear(&mut self) { + self.clear_inputs(); + self.clear_outputs(); + } + + /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. 
+ pub fn run(&mut self) -> Result> { self.run_inner(None) } - pub fn run_with_options<'i: 's>(&'i self, run_options: Arc) -> Result> { + /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. + pub fn run_with_options(&mut self, run_options: Arc) -> Result> { self.run_inner(Some(run_options)) } - fn run_inner<'i: 's>(&'i self, run_options: Option>) -> Result> { + fn run_inner(&mut self, run_options: Option>) -> Result> { let run_options_ptr = if let Some(run_options) = run_options { run_options.run_options_ptr.as_ptr() } else { @@ -82,6 +189,7 @@ impl<'s> IoBinding<'s> { }; ortsys![unsafe RunWithBinding(self.session.inner.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr()) -> Error::SessionRunWithIoBinding]; + let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc> = self.output_values.values().map(|c| (c.ptr(), &c.inner)).collect(); let mut count = self.output_names.len() as ort_sys::size_t; if count > 0 { let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut(); @@ -91,10 +199,17 @@ impl<'s> IoBinding<'s> { let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count as _).to_vec() } .into_iter() .map(|v| unsafe { - Value::from_ptr( - NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), - Some(Arc::clone(&self.session.inner)) - ) + if let Some(inner) = owned_ptrs.get(&v) { + DynValue { + inner: Arc::clone(*inner), + _markers: PhantomData + } + } else { + DynValue::from_ptr( + NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), + Some(Arc::clone(&self.session.inner)) + ) + } }); // output values will be freed when the `Value`s in `SessionOutputs` drop diff --git a/src/lib.rs b/src/lib.rs index ed5a2ad..a071b86 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,7 +65,7 @@ pub use self::session::{ #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub use 
self::tensor::ArrayExtensions; -pub use self::tensor::{IntoTensorElementType, TensorElementType}; +pub use self::tensor::{IntoTensorElementType, Utf8Data, PrimitiveTensorElementType, TensorElementType}; pub use self::value::{ DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, @@ -143,6 +143,23 @@ pub(crate) static G_ORT_API: OnceLock> = OnceLock::ne /// May panic if: /// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime. /// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled. +/// +/// # Examples +/// The primary (public-facing) use case for this function is accessing APIs that do not have a corresponding safe +/// implementation in `ort`. For example, [`GetBuildInfoString`](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0a7dba37b0017c0ef3a0ab4e266a967d): +/// +/// ``` +/// # use std::ffi::CStr; +/// # fn main() -> ort::Result<()> { +/// let api = ort::api().as_ptr(); +/// let build_info = unsafe { CStr::from_ptr((*api).GetBuildInfoString.unwrap()()) }; +/// println!("{}", build_info.to_string_lossy()); +/// // ORT Build Info: git-branch=HEAD, git-commit-id=4573740, build type=Release, cmake cxx flags: /DWIN32 /D_WINDOWS /EHsc /EHsc /wd26812 -DEIGEN_HAS_C99_MATH -DCPUINFO_SUPPORTED +/// # Ok(()) +/// # } +/// ``` +/// +/// For the full list of ONNX Runtime APIs, consult the [`ort_sys::OrtApi`] struct and the [ONNX Runtime C API](https://onnxruntime.ai/docs/api/c/struct_ort_api.html). pub fn api() -> NonNull { unsafe { NonNull::new_unchecked( @@ -252,6 +269,26 @@ pub(crate) fn char_p_to_string(raw: *const c_char) -> Result { .map_err(Error::FfiStringConversion) } +pub(crate) struct PrivateTraitMarker; + +macro_rules! 
private_trait { + () => { + #[doc(hidden)] + #[allow(private_interfaces)] + fn _private() -> crate::PrivateTraitMarker; + }; +} +macro_rules! private_impl { + () => { + #[allow(private_interfaces)] + fn _private() -> crate::PrivateTraitMarker { + crate::PrivateTraitMarker + } + }; +} +pub(crate) use private_impl; +pub(crate) use private_trait; + #[cfg(test)] mod test { use std::ffi::CString; diff --git a/src/memory.rs b/src/memory.rs index 00464f2..bc7644c 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,20 +1,75 @@ use std::{ ffi::{c_char, c_int, CString}, - ptr::NonNull + ptr::NonNull, + sync::Arc }; use super::{ error::{Error, Result}, ortsys }; -use crate::{char_p_to_string, error::status_to_result, Session}; +use crate::{char_p_to_string, error::status_to_result, Session, SharedSessionInner}; -/// An ONNX Runtime allocator, used to manage the allocation of [`crate::Value`]s. +/// A device allocator used to manage the allocation of [`crate::Value`]s. +/// +/// # Direct allocation +/// [`Allocator`] can be used to directly allocate device memory. This can be useful if you have a +/// postprocessing step that runs on the GPU. +/// ```no_run +/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; +/// # fn main() -> ort::Result<()> { +/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let allocator = Allocator::new( +/// &session, +/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? +/// )?; +/// +/// let mut tensor = Tensor::::new(&allocator, [1, 3, 224, 224])?; +/// // Here, `data_ptr` is a pointer to **device memory** inaccessible to the CPU; we'll need another crate, like +/// // `cudarc`, to access it. 
+/// let data_ptr = tensor.data_ptr_mut()?; +/// # Ok(()) +/// # } +/// ``` +/// +/// Note that `ort` does not facilitate the transfer of data between host & device outside of session inputs & +/// outputs; you'll need to use a separate crate for that, like [`cudarc`](https://crates.io/crates/cudarc) for CUDA. +/// +/// # Pinned allocation +/// Memory allocated on the host CPU is often *pageable* and may reside on the disk (swap memory). Transferring +/// pageable memory to another device is slow because the device has to go through the CPU to access the +/// memory. Many execution providers thus provide a "pinned" allocator type, which allocates *unpaged* CPU memory +/// that the device is able to access directly, bypassing the CPU and allowing for faster host-to-device data +/// transfer. +/// +/// If you create a session with a device allocator that supports pinned memory, like CUDA or ROCm, you can create +/// an allocator for that session, and use it to allocate tensors with faster pinned memory: +/// ```no_run +/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; +/// # fn main() -> ort::Result<()> { +/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let allocator = Allocator::new( +/// &session, +/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? +/// )?; +/// +/// // Create a tensor with our pinned allocator. +/// let mut tensor = Tensor::::new(&allocator, [1, 3, 224, 224])?; +/// let data = tensor.extract_tensor_mut(); +/// // ...fill `data` with data... +/// # Ok(()) +/// # } +/// ``` #[derive(Debug)] pub struct Allocator { pub(crate) ptr: NonNull, + /// The 'default' CPU allocator, provided by `GetAllocatorWithDefaultOptions` and implemented by + /// [`Allocator::default`], should **not** be released, so this field marks whether or not we should call + /// `ReleaseAllocator` on drop. 
is_default: bool, - _info: Option + _info: Option, + /// Hold a reference to the session if this allocator is tied to one. + _session_inner: Option> } impl Allocator { @@ -22,47 +77,46 @@ impl Allocator { Allocator { ptr: NonNull::new_unchecked(ptr), is_default: false, + // currently, this function is only ever used in session creation, where we call `CreateAllocator` manually and store the allocator resulting from + // this function in the `SharedSessionInner` - we don't need to hold onto the session, because the session is holding onto us. + _session_inner: None, _info: None } } + /// Frees an object allocated by this allocator, given the object's C pointer. pub(crate) unsafe fn free(&self, ptr: *mut T) { self.ptr.as_ref().Free.unwrap_or_else(|| unreachable!("Allocator method `Free` is null"))(self.ptr.as_ptr(), ptr.cast()); } /// Creates a new [`Allocator`] for the given session, to allocate memory on the device described in the /// [`MemoryInfo`]. - /// - /// For example, to create an allocator to allocate pinned memory for CUDA: - /// ```no_run - /// # use ort::{Allocator, Session, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; - /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; - /// let allocator = Allocator::new( - /// &session, - /// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? 
- /// )?; - /// # Ok(()) - /// # } - /// ``` pub fn new(session: &Session, memory_info: MemoryInfo) -> Result { let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); ortsys![unsafe CreateAllocator(session.ptr(), memory_info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)]; Ok(Self { ptr: unsafe { NonNull::new_unchecked(allocator_ptr) }, is_default: false, + _session_inner: Some(session.inner()), _info: Some(memory_info) }) } } impl Default for Allocator { + /// Returns the default CPU allocator; equivalent to `MemoryInfo::new(AllocationDevice::CPU, 0, + /// AllocatorType::Device, MemoryType::Default)`. + /// + /// The allocator returned by this function is actually shared across all invocations (though this behavior is + /// transparent to the user). fn default() -> Self { let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); status_to_result(ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr); nonNull(allocator_ptr)]).expect("Failed to get default allocator"); Self { ptr: unsafe { NonNull::new_unchecked(allocator_ptr) }, is_default: true, + // The default allocator isn't tied to a session. + _session_inner: None, _info: None } } @@ -70,8 +124,6 @@ impl Default for Allocator { impl Drop for Allocator { fn drop(&mut self) { - // per GetAllocatorWithDefaultOptions docs: Returned value should NOT be freed - // https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a8dec797ae52ee1a681e4f88be1fb4bb3 if !self.is_default { ortsys![unsafe ReleaseAllocator(self.ptr.as_ptr())]; } @@ -81,7 +133,8 @@ impl Drop for Allocator { /// Represents possible devices that have their own device allocator. 
#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AllocationDevice { - // https://github.com/microsoft/onnxruntime/blob/v1.17.0/include/onnxruntime/core/framework/allocator.h#L43-L53 + // https://github.com/microsoft/onnxruntime/blob/v1.18.0/include/onnxruntime/core/framework/allocator.h#L43-L53 + // ort will likely never support WebGPU, so I think it's best to leave `WebGPU_Buffer` out entirely to reduce confusion CPU, CUDA, CUDAPinned, @@ -91,12 +144,10 @@ pub enum AllocationDevice { HIP, HIPPinned, OpenVINOCPU, - OpenVINOGPU, - WebGPUBuffer + OpenVINOGPU } impl AllocationDevice { - #[must_use] pub fn as_str(&self) -> &'static str { match self { Self::CPU => "Cpu", @@ -108,10 +159,15 @@ impl AllocationDevice { Self::HIP => "Hip", Self::HIPPinned => "HipPinned", Self::OpenVINOCPU => "OpenVINO_CPU", - Self::OpenVINOGPU => "OpenVINO_GPU", - Self::WebGPUBuffer => "WebGPU_Buffer" + Self::OpenVINOGPU => "OpenVINO_GPU" } } + + /// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device, + /// it could be extracted to an `ndarray` or slice. + pub fn is_cpu_accessible(&self) -> bool { + matches!(self, Self::CPU | Self::CUDAPinned | Self::CANNPinned | Self::HIPPinned | Self::OpenVINOCPU) + } } impl TryFrom for AllocationDevice { @@ -129,14 +185,13 @@ impl TryFrom for AllocationDevice { "HipPinned" => Ok(AllocationDevice::HIPPinned), "OpenVINO_CPU" => Ok(AllocationDevice::OpenVINOCPU), "OpenVINO_GPU" => Ok(AllocationDevice::OpenVINOGPU), - "WebGPUBuffer" => Ok(AllocationDevice::WebGPUBuffer), _ => Err(value) } } } /// Execution provider allocator type. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AllocatorType { /// Default device-specific allocator. Device, @@ -154,11 +209,11 @@ impl From for ort_sys::OrtAllocatorType { } /// Memory types for allocated memory. 
-#[derive(Default, Debug, Copy, Clone)] +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] pub enum MemoryType { /// Any CPU memory used by non-CPU execution provider. CPUInput, - /// CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED. + /// CPU-accessible memory output by a non-CPU execution provider, i.e. [`AllocationDevice::CUDAPinned`]. CPUOutput, /// The default allocator for an execution provider. #[default] Default } @@ -190,6 +245,12 @@ impl From for MemoryType { } } +/// Structure describing a memory location - the device on which the memory resides, the type of allocator (device +/// default, or arena) used, and the type of memory allocated (device-only, or CPU accessible). +/// +/// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which +/// device value data should reside, and how that data should be accessible with regard to the CPU (if a non-CPU device +/// is requested). #[derive(Debug)] pub struct MemoryInfo { pub(crate) ptr: NonNull, @@ -197,24 +258,24 @@ } impl MemoryInfo { - pub(crate) fn from_raw(ptr: NonNull, should_release: bool) -> Self { - MemoryInfo { ptr, should_release } - } - - #[tracing::instrument] - pub fn new_cpu(allocator: AllocatorType, memory_type: MemoryType) -> Result { - let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut(); - ortsys![ - unsafe CreateCpuMemoryInfo(allocator.into(), memory_type.into(), &mut memory_info_ptr) -> Error::CreateMemoryInfo; - nonNull(memory_info_ptr) - ]; - Ok(Self { - ptr: unsafe { NonNull::new_unchecked(memory_info_ptr) }, - should_release: true - }) - } - - #[tracing::instrument] + /// Creates a [`MemoryInfo`], describing a memory location on a device allocator. + /// + /// # Examples + /// `MemoryInfo` can be used to specify the device & memory type used by an [`Allocator`] to allocate tensors. + /// See [`Allocator`] for more information & potential applications. 
+ /// ```no_run + /// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let allocator = Allocator::new( + /// &session, + /// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? + /// )?; + /// + /// let mut tensor = Tensor::::new(&allocator, [1, 3, 224, 224])?; + /// # Ok(()) + /// # } + /// ``` pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemoryType) -> Result { let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut(); let allocator_name = CString::new(allocation_device.as_str()).unwrap_or_else(|_| unreachable!()); @@ -229,7 +290,19 @@ impl MemoryInfo { }) } + pub(crate) fn from_raw(ptr: NonNull, should_release: bool) -> Self { + MemoryInfo { ptr, should_release } + } + /// Returns the [`MemoryType`] described by this struct. + /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.memory_type()?, MemoryType::Default); + /// # Ok(()) + /// # } + /// ``` pub fn memory_type(&self) -> Result { let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault; ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetMemoryType]; @@ -237,6 +310,14 @@ impl MemoryInfo { } /// Returns the [`AllocatorType`] described by this struct. 
+ /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.allocator_type()?, AllocatorType::Device); + /// # Ok(()) + /// # } + /// ``` pub fn allocator_type(&self) -> Result { let mut raw_type: ort_sys::OrtAllocatorType = ort_sys::OrtAllocatorType::OrtInvalidAllocator; ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetAllocatorType]; @@ -248,6 +329,14 @@ impl MemoryInfo { } /// Returns the [`AllocationDevice`] this struct was created with. + /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.allocation_device()?, AllocationDevice::CPU); + /// # Ok(()) + /// # } + /// ``` pub fn allocation_device(&self) -> Result { let mut name_ptr: *const c_char = std::ptr::null_mut(); ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr) -> Error::GetAllocationDevice; nonNull(name_ptr)]; @@ -258,6 +347,14 @@ impl MemoryInfo { } /// Returns the ID of the [`AllocationDevice`] described by this struct. 
+ /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.device_id()?, 0); + /// # Ok(()) + /// # } + /// ``` pub fn device_id(&self) -> Result { let mut raw: ort_sys::c_int = 0; ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw) -> Error::GetDeviceId]; @@ -266,24 +363,9 @@ impl MemoryInfo { } impl Drop for MemoryInfo { - #[tracing::instrument] fn drop(&mut self) { if self.should_release { ortsys![unsafe ReleaseMemoryInfo(self.ptr.as_ptr())]; } } } - -#[cfg(test)] -mod tests { - use test_log::test; - - use super::*; - - #[test] - fn create_memory_info() -> crate::Result<()> { - let memory_info = MemoryInfo::new_cpu(AllocatorType::Device, MemoryType::Default)?; - std::mem::drop(memory_info); - Ok(()) - } -} diff --git a/src/operator/bound.rs b/src/operator/bound.rs index f46c6c1..58219c4 100644 --- a/src/operator/bound.rs +++ b/src/operator/bound.rs @@ -11,8 +11,7 @@ use super::{ }; use crate::error::IntoStatus; -#[repr(C)] -#[derive(Clone)] +#[repr(C)] // <- important! 
a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later pub(crate) struct BoundOperator { implementation: ort_sys::OrtCustomOp, name: CString, @@ -184,7 +183,10 @@ unsafe impl Send for ErasedBoundOperator {} impl ErasedBoundOperator { pub(crate) fn new(bound: BoundOperator) -> Self { - ErasedBoundOperator(NonNull::from(unsafe { &mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ()) })) + ErasedBoundOperator(NonNull::from(unsafe { + // horrible horrible horrible horrible horrible horrible horrible horrible horrible + &mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ()) + })) } pub(crate) fn op_ptr(&self) -> *mut ort_sys::OrtCustomOp { diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 61758f6..db09aa2 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -138,4 +138,13 @@ impl KernelContext { ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx as ort_sys::size_t, shape.as_ptr(), shape.len() as _, &mut value_ptr) -> Error::GetOperatorOutput]; Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) }))) } + + /// Returns a pointer to the GPU compute stream (i.e. `cudaStream_t`) used by the execution provider, if this + /// kernel's operator was configured to use said execution provider (see + /// [`super::Operator::execution_provider_type`]). 
+ pub fn compute_stream(&self) -> Result>> { + let mut stream_ptr: *mut ort_sys::c_void = ptr::null_mut(); + ortsys![unsafe KernelContext_GetGPUComputeStream(self.ptr.as_ptr(), &mut stream_ptr) -> Error::GetOperatorGPUComputeStream]; + Ok(NonNull::new(stream_ptr)) + } } diff --git a/src/operator/mod.rs b/src/operator/mod.rs index 7e87351..ad361f2 100644 --- a/src/operator/mod.rs +++ b/src/operator/mod.rs @@ -16,11 +16,26 @@ use crate::{operator::bound::BoundOperator, ortsys, Error, Result}; pub type InferShapeFn = dyn FnMut(*mut ort_sys::OrtShapeInferContext) -> crate::Result<()>; +/// A custom operator descriptor, which describes the expected inputs & outputs of a graph operator. +/// +/// [`Operator`]s are bound to [`OperatorDomain`]s. Multiple operators can have the same name as long as they have +/// different input/output types, in which case the exact operator will be picked depending on the input/output +/// types. If you want to, for example, define a `Sort` operator that can accept either a single `f32` or `i64` tensor +/// input, you'll need to define 2 separate operators (which can be done via a macro); but both of these +/// [`Operator`] structs can return the same name in [`Operator::name`] so that they are usable as simply +/// `my.domain:Sort` in the graph. pub trait Operator: Send { type Kernel: Kernel; + /// Returns the name of the operator. fn name() -> &'static str; + /// Returns the execution provider this operator runs on, e.g. `CUDAExecutionProvider`. + /// + /// If the returned type is not `None`, and the execution provider used by the session matches this operator's + /// EP type, the value will not be copied to the CPU and you may use functions like [`crate::Tensor::data_ptr`] to + /// access the underlying device memory, and [`super::KernelContext::compute_stream`] to access the GPU compute + /// stream. 
fn execution_provider_type() -> Option<&'static str> { None } @@ -42,6 +57,7 @@ pub trait Operator: Send { } } +/// Dummy type implementing [`Operator`] used by [`ErasedBoundOperator`] to cheat the type system. struct DummyOperator; impl Operator for DummyOperator { @@ -84,7 +100,7 @@ impl OperatorDomain { } #[allow(clippy::should_implement_trait)] - pub fn add(mut self, _operator: O) -> Result { + pub fn add(mut self) -> Result { let name = O::name(); let bound = BoundOperator::::new(CString::new(name)?, O::execution_provider_type().map(CString::new).transpose()?); diff --git a/src/session/input.rs b/src/session/input.rs index ce33f00..61d55e5 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -92,16 +92,15 @@ impl<'i, 'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs< /// # } /// ``` /// -/// Note that string tensors must be created manually with [`Value::from_string_array`]. +/// Note that string tensors must be created manually with [`crate::Tensor::from_string_array`]. 
/// /// ```no_run /// # use std::{error::Error, sync::Arc}; /// # use ndarray::Array1; -/// # use ort::{GraphOptimizationLevel, Session, Value}; +/// # use ort::{GraphOptimizationLevel, Session, Tensor}; /// # fn main() -> Result<(), Box> { /// # let mut session = Session::builder()?.commit_from_file("model.onnx")?; -/// let _ = session -/// .run(ort::inputs![Value::from_string_array(session.allocator(), Array1::from_vec(vec!["hello", "world"]))?]?); +/// let _ = session.run(ort::inputs![Tensor::from_string_array(Array1::from_vec(vec!["hello", "world"]))?]?); /// # Ok(()) /// # } /// ``` diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 98cf2ae..a1a5440 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -22,4 +22,4 @@ mod types; pub use self::ndarray::ArrayExtensions; #[cfg(feature = "ndarray")] pub(crate) use self::types::{extract_primitive_array, extract_primitive_array_mut}; -pub use self::types::{IntoTensorElementType, TensorElementType, Utf8Data}; +pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; diff --git a/src/tensor/types.rs b/src/tensor/types.rs index 08a57d8..aabe683 100644 --- a/src/tensor/types.rs +++ b/src/tensor/types.rs @@ -91,6 +91,12 @@ impl From for TensorElementType { pub trait IntoTensorElementType { /// Returns the ONNX tensor element data type corresponding to the given Rust type. fn into_tensor_element_type() -> TensorElementType; + + crate::private_trait!(); +} + +pub trait PrimitiveTensorElementType: IntoTensorElementType { + crate::private_trait!(); } macro_rules! impl_type_trait { @@ -99,6 +105,12 @@ macro_rules! 
impl_type_trait { fn into_tensor_element_type() -> TensorElementType { TensorElementType::$variant } + + crate::private_impl!(); + } + + impl PrimitiveTensorElementType for $type_ { + crate::private_impl!(); } }; } @@ -121,6 +133,14 @@ impl_type_trait!(u64, Uint64); #[cfg_attr(docsrs, doc(cfg(feature = "half")))] impl_type_trait!(half::bf16, Bfloat16); +impl IntoTensorElementType for String { + fn into_tensor_element_type() -> TensorElementType { + TensorElementType::String + } + + crate::private_impl!(); +} + /// Adapter for common Rust string types to ONNX strings. pub trait Utf8Data { /// Returns the contents of this value as a slice of UTF-8 bytes. diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index 1421e5a..f638787 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -3,25 +3,39 @@ use std::{ fmt::Debug, hash::Hash, marker::PhantomData, - ptr::{self, NonNull} + ptr::{self, NonNull}, + sync::Arc }; use super::{ValueInner, ValueTypeMarker}; use crate::{ - memory::Allocator, ortsys, value::impl_tensor::DynTensor, DynValue, Error, IntoTensorElementType, Result, Tensor, Value, ValueRef, ValueRefMut, ValueType + memory::Allocator, + ortsys, + value::impl_tensor::{calculate_tensor_size, DynTensor}, + DynValue, Error, IntoTensorElementType, PrimitiveTensorElementType, Result, Tensor, TensorElementType, Value, ValueRef, ValueRefMut, ValueType }; -pub trait MapValueTypeMarker: ValueTypeMarker {} +pub trait MapValueTypeMarker: ValueTypeMarker { + crate::private_trait!(); +} #[derive(Debug)] pub struct DynMapValueType; -impl ValueTypeMarker for DynMapValueType {} -impl MapValueTypeMarker for DynMapValueType {} +impl ValueTypeMarker for DynMapValueType { + crate::private_impl!(); +} +impl MapValueTypeMarker for DynMapValueType { + crate::private_impl!(); +} #[derive(Debug)] pub struct MapValueType(PhantomData<(K, V)>); -impl ValueTypeMarker for MapValueType {} -impl MapValueTypeMarker for MapValueType {} +impl ValueTypeMarker for MapValueType { + 
crate::private_impl!(); +} +impl MapValueTypeMarker for MapValueType { + crate::private_impl!(); +} pub type DynMap = Value; pub type Map = Value>; @@ -32,10 +46,7 @@ pub type MapRef<'v, K, V> = ValueRef<'v, MapValueType>; pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType>; impl Value { - pub fn try_extract_map( - &self, - allocator: &Allocator - ) -> Result> { + pub fn try_extract_map(&self) -> Result> { match self.dtype()? { ValueType::Map { key, value } => { let k_type = K::into_tensor_element_type(); @@ -47,47 +58,95 @@ impl Value { return Err(Error::InvalidMapValueType { expected: v_type, actual: value }); } + let allocator = Allocator::default(); + let mut key_tensor_ptr = ptr::null_mut(); ortsys![unsafe GetValue(self.ptr(), 0, allocator.ptr.as_ptr(), &mut key_tensor_ptr) -> Error::ExtractMap; nonNull(key_tensor_ptr)]; let key_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(key_tensor_ptr), None) }; - let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_tensor::()?; + if K::into_tensor_element_type() != TensorElementType::String { + let dtype = key_value.dtype()?; + let (key_tensor_shape, key_tensor) = match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = key_value.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } - let mut value_tensor_ptr = ptr::null_mut(); - ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; - let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; - let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; + if ty == K::into_tensor_element_type() { + let mut output_array_ptr: *mut K = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = 
output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - assert_eq!(key_tensor_shape.len(), 1); - assert_eq!(value_tensor_shape.len(), 1); - assert_eq!(key_tensor_shape[0], value_tensor_shape[0]); + let len = calculate_tensor_size(&dimensions); + (dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) }) + } else { + return Err(Error::DataTypeMismatch { + actual: ty, + requested: K::into_tensor_element_type() + }); + } + } + _ => unreachable!() + }; - let mut vec = Vec::with_capacity(key_tensor_shape[0] as _); - for i in 0..key_tensor_shape[0] as usize { - vec.push((key_tensor[i].clone(), value_tensor[i].clone())); + let mut value_tensor_ptr = ptr::null_mut(); + ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; + let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; + let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; + + assert_eq!(key_tensor_shape.len(), 1); + assert_eq!(value_tensor_shape.len(), 1); + assert_eq!(key_tensor_shape[0], value_tensor_shape[0]); + + let mut vec = Vec::with_capacity(key_tensor_shape[0] as _); + for i in 0..key_tensor_shape[0] as usize { + vec.push((key_tensor[i].clone(), value_tensor[i].clone())); + } + Ok(vec.into_iter().collect()) + } else { + let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_string_tensor()?; + // SAFETY: `IntoTensorElementType` is a private trait, and we only map the `String` type to `TensorElementType::String`, + // so at this point, `K` is **always** the `String` type, and this transmute really does nothing but please the type + // checker. 
+ let key_tensor: Vec = unsafe { std::mem::transmute(key_tensor) }; + + let mut value_tensor_ptr = ptr::null_mut(); + ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; + let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; + let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; + + assert_eq!(key_tensor_shape.len(), 1); + assert_eq!(value_tensor_shape.len(), 1); + assert_eq!(key_tensor_shape[0], value_tensor_shape[0]); + + let mut vec = Vec::with_capacity(key_tensor_shape[0] as _); + for i in 0..key_tensor_shape[0] as usize { + vec.push((key_tensor[i].clone(), value_tensor[i].clone())); + } + Ok(vec.into_iter().collect()) } - Ok(vec.into_iter().collect()) } t => Err(Error::NotMap(t)) } } } -impl Value> { +impl Value> { /// Creates a [`Map`] from an iterable emitting `K` and `V`. /// /// ``` /// # use std::collections::HashMap; - /// # use ort::{Allocator, Map}; + /// # use ort::Map; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let mut map = HashMap::::new(); /// map.insert(0, 1.0); /// map.insert(1, 2.0); /// map.insert(2, 3.0); /// - /// let value = Map::new(map)?; + /// let value = Map::::new(map)?; /// - /// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0); + /// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0); /// # Ok(()) /// # } /// ``` @@ -95,20 +154,45 @@ impl, Vec) = data.into_iter().unzip(); Self::new_kv(Tensor::from_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?) } +} +impl Value> { + /// Creates a [`Map`] from an iterable emitting `K` and `V`. 
+ /// + /// ``` + /// # use std::collections::HashMap; + /// # use ort::Map; + /// # fn main() -> ort::Result<()> { + /// let mut map = HashMap::::new(); + /// map.insert(0, 1.0); + /// map.insert(1, 2.0); + /// map.insert(2, 3.0); + /// + /// let value = Map::::new(map)?; + /// + /// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0); + /// # Ok(()) + /// # } + /// ``` + pub fn new(data: impl IntoIterator) -> Result { + let (keys, values): (Vec, Vec) = data.into_iter().unzip(); + Self::new_kv(Tensor::from_string_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?) + } +} + +impl Value> { /// Creates a [`Map`] from two tensors of keys & values respectively. /// /// ``` /// # use std::collections::HashMap; - /// # use ort::{Allocator, Map, Tensor}; + /// # use ort::{Map, Tensor}; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let keys = Tensor::::from_array(([4], vec![0, 1, 2, 3]))?; /// let values = Tensor::::from_array(([4], vec![1., 2., 3., 4.]))?; /// /// let value = Map::new_kv(keys, values)?; /// - /// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0); + /// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0); /// # Ok(()) /// # } /// ``` @@ -122,21 +206,23 @@ impl Value> { - pub fn extract_map(&self, allocator: &Allocator) -> HashMap { - self.try_extract_map(allocator).expect("Failed to extract map") +impl Value> { + pub fn extract_map(&self) -> HashMap { + self.try_extract_map().expect("Failed to extract map") } +} +impl Value> { /// Converts from a strongly-typed [`Map`] to a type-erased [`DynMap`]. 
#[inline] pub fn upcast(self) -> DynMap { @@ -149,7 +235,7 @@ impl(PhantomData); -impl ValueTypeMarker for SequenceValueType {} -impl SequenceValueTypeMarker for SequenceValueType {} +impl ValueTypeMarker for SequenceValueType { + crate::private_impl!(); +} +impl SequenceValueTypeMarker for SequenceValueType { + crate::private_impl!(); +} pub type DynSequence = Value; pub type Sequence = Value>; @@ -89,11 +100,11 @@ impl Value Value Value { + /// Construct a [`DynTensor`] from an array of strings. /// - /// Just like numeric tensors, string tensor `Value`s can be created from: + /// Just like numeric tensors, string tensors can be created from: /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); /// - (with feature `ndarray`) an owned [`ndarray::Array`]; @@ -36,26 +36,19 @@ impl DynTensor { /// ``` /// # use ort::{Session, Value}; /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/vectorizer.onnx")?; - /// // You'll need to obtain an `Allocator` from a session in order to create string tensors. 
- /// let allocator = session.allocator(); - /// /// // Create a string tensor from a raw data vector /// let data = vec!["hello", "world"]; - /// let value = Value::from_string_array(allocator, ([data.len()], data.into_boxed_slice()))?; + /// let value = Value::from_string_array(([data.len()], data.into_boxed_slice()))?; /// /// // Create a string tensor from an `ndarray::Array` /// #[cfg(feature = "ndarray")] - /// let value = Value::from_string_array( - /// allocator, - /// ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap() - /// )?; + /// let value = Value::from_string_array(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap())?; /// # Ok(()) /// # } /// ``` /// /// Note that string data will *always* be copied, no matter what form the data is provided in. - pub fn from_string_array(allocator: &Allocator, input: impl IntoValueTensor) -> Result { + pub fn from_string_array(input: impl IntoValueTensor) -> Result> { let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); let (shape, data) = input.ref_parts()?; @@ -64,7 +57,7 @@ impl DynTensor { // create tensor without data -- data is filled in later ortsys![ - unsafe CreateTensorAsOrtValue(allocator.ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr) + unsafe CreateTensorAsOrtValue(Allocator::default().ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr) -> Error::CreateTensor; nonNull(value_ptr) ]; @@ -84,18 +77,18 @@ impl DynTensor { ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> Error::FillStringTensor]; Ok(Value { - inner: ValueInner::RustOwned { + inner: Arc::new(ValueInner::RustOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: Box::new(()), _memory_info: None - }, + }), _markers: PhantomData }) } } -impl Tensor { - /// Construct a tensor [`Value`] in a given allocator with a given shape and datatype. 
The data contained in the +impl Tensor { + /// Construct a tensor in a given allocator with a given shape and datatype. The data contained in the /// value will be zero-allocated on the allocation device. /// /// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned @@ -132,18 +125,18 @@ impl Tensor { ]; Ok(Value { - inner: ValueInner::RustOwned { + inner: Arc::new(ValueInner::RustOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: Box::new(()), _memory_info: None - }, + }), _markers: PhantomData }) } - /// Construct a tensor [`Value`] from an array of data. + /// Construct a tensor from an array of data. /// - /// Tensor `Value`s can be created from: + /// Tensors can be created from: /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); /// - (with feature `ndarray`) an owned [`ndarray::Array`]; @@ -154,19 +147,19 @@ impl Tensor { /// * and `data` is one of `Vec`, `Box<[T]>`, `Arc>`, or `&[T]`. /// /// ``` - /// # use ort::Value; + /// # use ort::Tensor; /// # fn main() -> ort::Result<()> { /// // Create a tensor from a raw data vector - /// let value = Value::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?; + /// let tensor = Tensor::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?; /// /// // Create a tensor from an `ndarray::Array` /// #[cfg(feature = "ndarray")] - /// let value = Value::from_array(ndarray::Array4::::zeros((1, 16, 16, 3)))?; + /// let tensor = Tensor::from_array(ndarray::Array4::::zeros((1, 16, 16, 3)))?; /// # Ok(()) /// # } /// ``` /// - /// Creating string tensors requires a separate method; see [`Value::from_string_array`]. + /// Creating string tensors requires a separate method; see [`DynTensor::from_string_array`]. 
/// /// Note that data provided in an `ndarray` may be copied in some circumstances: /// - `&CowArray<'_, T, D>` will always be copied regardless of whether it is uniquely owned or borrowed. @@ -177,7 +170,7 @@ impl Tensor { /// Raw data provided as a `Arc>`, `Box<[T]>`, or `Vec` will never be copied. Raw data is expected to be /// in standard, contigous layout. pub fn from_array(input: impl IntoValueTensor) -> Result> { - let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemoryType::Default)?; + let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::Default)?; let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); @@ -203,17 +196,17 @@ impl Tensor { ]; Ok(Value { - inner: ValueInner::RustOwned { + inner: Arc::new(ValueInner::RustOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: guard, _memory_info: Some(memory_info) - }, + }), _markers: PhantomData }) } } -impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> { +impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { /// Create a mutable tensor view from a raw pointer and shape. /// /// The length of data is determined by `T` and the given shape, so the given buffer must be at least @@ -260,11 +253,11 @@ impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> { ]; Ok(TensorRefMut::new(Value { - inner: ValueInner::CppOwned { + inner: Arc::new(ValueInner::CppOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, drop: true, _session: None - }, + }), _markers: PhantomData })) } @@ -290,7 +283,7 @@ macro_rules! impl_to_dimensions { .enumerate() .map(|(i, c)| if *c >= 1 { Ok(*c as i64) } else { Err(Error::InvalidDimension(i)) }) .collect::>()?; - let sum = v.iter().product::() as usize; + let sum = calculate_tensor_size(&v); if let Some(expected_size) = expected_size { if sum != expected_size { Err(Error::TensorShapeMismatch { @@ -318,6 +311,14 @@ macro_rules! 
impl_to_dimensions { }; } +impl ToDimensions for () { + fn to_dimensions(&self, expected_size: Option) -> Result> { + match expected_size { + Some(1) | None => Ok(vec![]), + Some(x) => Err(Error::TensorShapeMismatch { input: vec![], total: 1, expected: x }) + } + } +} impl_to_dimensions!(for &[usize], for &[i32], for &[i64], for Vec, for Vec, for Vec); impl_to_dimensions!( for [usize; N], for [i32; N], for [i64; N]); @@ -500,7 +501,7 @@ impl IntoValueTensor for (D, Arc TryFrom<&'i CowArray<'v, T, D>> for Tensor +impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Tensor where 'i: 'v { @@ -512,7 +513,7 @@ where #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Tensor { +impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Tensor { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { Tensor::from_array(arr) @@ -521,7 +522,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor +impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor where 'i: 'v { @@ -533,7 +534,7 @@ where #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynTensor { +impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynTensor { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { Tensor::from_array(arr).map(|c| 
c.upcast()) @@ -542,7 +543,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue +impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue where 'i: 'v { @@ -554,7 +555,7 @@ where #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynValue { +impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynValue { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { Tensor::from_array(arr).map(|c| c.into_dyn()) @@ -564,19 +565,19 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta macro_rules! impl_try_from { (@T,I $($t:ty),+) => { $( - impl TryFrom<$t> for Tensor { + impl TryFrom<$t> for Tensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value) } } - impl TryFrom<$t> for DynTensor { + impl TryFrom<$t> for DynTensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.upcast()) } } - impl TryFrom<$t> for crate::DynValue { + impl TryFrom<$t> for crate::DynValue { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.into_dyn()) @@ -587,21 +588,21 @@ macro_rules! 
impl_try_from { (@T,D $($t:ty),+) => { $( #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for Tensor { + impl TryFrom<$t> for Tensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value) } } #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for DynTensor { + impl TryFrom<$t> for DynTensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.upcast()) } } #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for crate::DynValue { + impl TryFrom<$t> for crate::DynValue { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.into_dyn()) diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index a4859e7..22a52ed 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, os::raw::c_char, ptr, string::FromUtf8Error}; +use std::{fmt::Debug, ptr, string::FromUtf8Error}; #[cfg(feature = "ndarray")] use ndarray::IxDyn; @@ -7,9 +7,7 @@ use super::TensorValueTypeMarker; #[cfg(feature = "ndarray")] use crate::tensor::{extract_primitive_array, extract_primitive_array_mut}; use crate::{ - ortsys, - tensor::{IntoTensorElementType, TensorElementType}, - Error, Result, Tensor, Value + ortsys, tensor::TensorElementType, value::impl_tensor::calculate_tensor_size, Error, PrimitiveTensorElementType, Result, Tensor, Value, ValueType }; impl Value { @@ -38,38 +36,81 @@ impl Value { /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the /// infallible [`Tensor::extract_tensor`] instead)* /// - The provided type `T` does not match the tensor's element type. + /// - The tensor's data is not allocated in CPU memory. 
#[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - pub fn try_extract_tensor(&self) -> Result> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; + pub fn try_extract_tensor(&self) -> Result> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::>()); - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok(extract_primitive_array(shape, self.ptr())?) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + if ty == T::into_tensor_element_type() { + Ok(extract_primitive_array(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr())?) 
+ } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } + } + + /// Attempt to extract the scalar from a tensor of type `T`. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Value}; + /// # fn main() -> ort::Result<()> { + /// let value = Value::from_array(((), vec![3.14_f32]))?; + /// + /// let extracted = value.try_extract_scalar::()?; + /// assert_eq!(extracted, 3.14); + /// # Ok(()) + /// # } + /// ``` + /// + /// # Errors + /// May return an error if: + /// - The tensor is not 0-dimensional. + /// - The provided type `T` does not match the tensor's element type. + /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the + /// infallible [`Tensor::extract_tensor`] instead)* + /// - The tensor's data is not allocated in CPU memory. + pub fn try_extract_scalar(&self) -> Result { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if !dimensions.is_empty() { + return Err(Error::TensorNot0Dimensional(dimensions.len())); + } + + if ty == T::into_tensor_element_type() { + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + + Ok(unsafe { *output_array_ptr }) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } + } + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data of 
type `T` into a mutable read-only [`ndarray::ArrayViewMut`]. @@ -101,36 +142,26 @@ impl Value { /// - The provided type `T` does not match the tensor's element type. #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - pub fn try_extract_tensor_mut(&mut self) -> Result> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; + pub fn try_extract_tensor_mut(&mut self) -> Result> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::>()); - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok(extract_primitive_array_mut(shape, self.ptr())?) 
- } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + if ty == T::into_tensor_element_type() { + Ok(extract_primitive_array_mut(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr())?) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an @@ -159,40 +190,32 @@ impl Value { /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the /// infallible [`Tensor::extract_raw_tensor`] instead)* /// - The provided type `T` does not match the tensor's element type. - pub fn try_extract_raw_tensor(&self) -> Result<(Vec, &[T])> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; + pub fn try_extract_raw_tensor(&self) -> Result<(Vec, &[T])> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; + if ty == 
T::into_tensor_element_type() { + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - - let mut output_array_ptr: *mut T = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok((node_dims, unsafe { std::slice::from_raw_parts(output_array_ptr, len as _) })) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + let len = calculate_tensor_size(&dimensions); + Ok((dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) })) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a @@ -218,50 +241,41 @@ impl Value { /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the /// infallible [`Tensor::extract_raw_tensor_mut`] instead)* /// - The provided type `T` does not match the tensor's element type. 
- pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; + pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; + if ty == T::into_tensor_element_type() { + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - - let mut output_array_ptr: *mut T = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> 
Error::GetTensorMutableData; nonNull(output_array_ptr)]; - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok((node_dims, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len as _) })) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + let len = calculate_tensor_size(&dimensions); + Ok((dimensions, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len) })) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data into a Rust `ndarray`. /// /// ``` - /// # use ort::{Allocator, Session, DynTensor, TensorElementType}; + /// # use ort::{Session, Tensor, TensorElementType}; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let array = ndarray::Array1::from_vec(vec!["hello", "world"]); - /// let tensor = DynTensor::from_string_array(&allocator, array.clone())?; + /// let tensor = Tensor::from_string_array(array.clone())?; /// /// let extracted = tensor.try_extract_string_tensor()?; /// assert_eq!(array.into_dyn(), extracted); @@ -271,78 +285,68 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_string_tensor(&self) -> Result> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } - let res = { - let mut type_sys = 
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == TensorElementType::String { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; + if ty == TensorElementType::String { + let len = calculate_tensor_size(&dimensions); - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::>()); + // Total length of string data, not including \0 suffix + let mut total_length: ort_sys::size_t = 0; + ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; - let mut len: ort_sys::size_t = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; + // In the JNI impl of this, tensor_element_len was included in addition to total_length, + // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes + // don't seem to be written to in practice either. + // If the string data actually did go farther, it would panic below when using the offset + // data to get slices for each string. 
+ let mut string_contents = vec![0u8; total_length as _]; + // one extra slot so that the total length can go in the last one, making all per-string + // length calculations easy + let mut offsets = vec![0; (len + 1) as _]; - // Total length of string data, not including \0 suffix - let mut total_length: ort_sys::size_t = 0; - ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; + ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent]; - // In the JNI impl of this, tensor_element_len was included in addition to total_length, - // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes - // don't seem to be written to in practice either. - // If the string data actually did go farther, it would panic below when using the offset - // data to get slices for each string. - let mut string_contents = vec![0u8; total_length as _]; - // one extra slot so that the total length can go in the last one, making all per-string - // length calculations easy - let mut offsets = vec![0; (len + 1) as _]; + // final offset = overall length so that per-string length calculations work for the last string + debug_assert_eq!(0, offsets[len]); + offsets[len] = total_length; - ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent]; + let strings = offsets + // offsets has 1 extra offset past the end so that all windows work + .windows(2) + .map(|w| { + let slice = &string_contents[w[0] as _..w[1] as _]; + String::from_utf8(slice.into()) + }) + .collect::, FromUtf8Error>>() + .map_err(Error::StringFromUtf8Error)?; - // final offset = overall length so that per-string length calculations work for the last string - debug_assert_eq!(0, offsets[len as usize]); - offsets[len as usize] = 
total_length; - - let strings = offsets - // offsets has 1 extra offset past the end so that all windows work - .windows(2) - .map(|w| { - let slice = &string_contents[w[0] as _..w[1] as _]; - String::from_utf8(slice.into()) + Ok(ndarray::Array::from_shape_vec(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), strings) + .expect("Shape extracted from tensor didn't match tensor contents")) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: TensorElementType::String }) - .collect::, FromUtf8Error>>() - .map_err(Error::StringFromUtf8Error)?; - - Ok(ndarray::Array::from_shape_vec(shape, strings) - .expect("Shape extracted from tensor didn't match tensor contents") - .into_dyn()) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: TensorElementType::String - }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's dimensions and /// an owned `Vec` of its data. 
/// /// ``` - /// # use ort::{Allocator, Session, DynTensor, TensorElementType}; + /// # use ort::{Session, Tensor, TensorElementType}; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let array = vec!["hello", "world"]; - /// let tensor = DynTensor::from_string_array(&allocator, ([array.len()], array.clone().into_boxed_slice()))?; + /// let tensor = Tensor::from_string_array(([array.len()], array.clone().into_boxed_slice()))?; /// /// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?; /// assert_eq!(extracted_data, array); @@ -351,68 +355,57 @@ impl Value { /// # } /// ``` pub fn try_extract_raw_string_tensor(&self) -> Result<(Vec, Vec)> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == TensorElementType::String { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; + if ty == TensorElementType::String { + let len = calculate_tensor_size(&dimensions); - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; + // Total length of string data, not including \0 
suffix + let mut total_length: ort_sys::size_t = 0; + ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; - let mut output_array_ptr: *mut c_char = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut c_char = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + // In the JNI impl of this, tensor_element_len was included in addition to total_length, + // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes + // don't seem to be written to in practice either. + // If the string data actually did go farther, it would panic below when using the offset + // data to get slices for each string. + let mut string_contents = vec![0u8; total_length as _]; + // one extra slot so that the total length can go in the last one, making all per-string + // length calculations easy + let mut offsets = vec![0; (len + 1) as _]; - let mut len: ort_sys::size_t = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - // Total length of string data, not including \0 suffix - let mut total_length = 0; - ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; + ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent]; - // In the JNI impl of this, tensor_element_len was included in addition to total_length, - // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes - // don't seem to be written to in practice either. - // If the string data actually did go farther, it would panic below when using the offset - // data to get slices for each string. 
- let mut string_contents = vec![0u8; total_length as _]; - // one extra slot so that the total length can go in the last one, making all per-string - // length calculations easy - let mut offsets = vec![0; len as usize + 1]; + // final offset = overall length so that per-string length calculations work for the last string + debug_assert_eq!(0, offsets[len]); + offsets[len] = total_length; - ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length as _, offsets.as_mut_ptr(), len as _) -> Error::GetStringTensorContent]; + let strings = offsets + // offsets has 1 extra offset past the end so that all windows work + .windows(2) + .map(|w| { + let slice = &string_contents[w[0] as _..w[1] as _]; + String::from_utf8(slice.into()) + }) + .collect::, FromUtf8Error>>() + .map_err(Error::StringFromUtf8Error)?; - // final offset = overall length so that per-string length calculations work for the last string - debug_assert_eq!(0, offsets[len as usize]); - offsets[len as usize] = total_length; - - let strings = offsets - // offsets has 1 extra offset past the end so that all windows work - .windows(2) - .map(|w| { - let slice = &string_contents[w[0] as _..w[1] as _]; - String::from_utf8(slice.into()) + Ok((dimensions, strings)) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: TensorElementType::String }) - .collect::, FromUtf8Error>>() - .map_err(Error::StringFromUtf8Error)?; - - Ok((node_dims, strings)) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: TensorElementType::String - }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Returns the shape of the tensor. @@ -445,7 +438,7 @@ impl Value { } } -impl Tensor { +impl Tensor { /// Extracts the underlying data into a read-only [`ndarray::ArrayView`]. 
/// /// ``` diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index d7f1db3..a4c7ba6 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -11,41 +11,96 @@ use std::{ use super::{DowncastableTarget, Value, ValueInner, ValueTypeMarker}; use crate::{ortsys, DynValue, Error, IntoTensorElementType, MemoryInfo, Result, ValueRef, ValueRefMut, ValueType}; -pub trait TensorValueTypeMarker: ValueTypeMarker {} +pub trait TensorValueTypeMarker: ValueTypeMarker { + crate::private_trait!(); +} #[derive(Debug)] pub struct DynTensorValueType; -impl ValueTypeMarker for DynTensorValueType {} -impl TensorValueTypeMarker for DynTensorValueType {} +impl ValueTypeMarker for DynTensorValueType { + crate::private_impl!(); +} +impl TensorValueTypeMarker for DynTensorValueType { + crate::private_impl!(); +} #[derive(Debug)] pub struct TensorValueType(PhantomData); -impl ValueTypeMarker for TensorValueType {} -impl TensorValueTypeMarker for TensorValueType {} +impl ValueTypeMarker for TensorValueType { + crate::private_impl!(); +} +impl TensorValueTypeMarker for TensorValueType { + crate::private_impl!(); +} +/// A tensor [`Value`] whose data type is unknown. pub type DynTensor = Value; +/// A strongly-typed tensor [`Value`]. pub type Tensor = Value>; +/// A reference to a tensor [`Value`] whose data type is unknown. pub type DynTensorRef<'v> = ValueRef<'v, DynTensorValueType>; +/// A mutable reference to a tensor [`Value`] whose data type is unknown. pub type DynTensorRefMut<'v> = ValueRefMut<'v, DynTensorValueType>; +/// A reference to a strongly-typed tensor [`Value`]. pub type TensorRef<'v, T> = ValueRef<'v, TensorValueType>; +/// A mutable reference to a strongly-typed tensor [`Value`]. pub type TensorRefMut<'v, T> = ValueRefMut<'v, TensorValueType>; impl DowncastableTarget for DynTensorValueType { fn can_downcast(dtype: &ValueType) -> bool { matches!(dtype, ValueType::Tensor { .. 
}) } + + crate::private_impl!(); } impl Value { /// Returns a mutable pointer to the tensor's data. + /// + /// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a + /// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be + /// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access. + /// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before + /// accessing it. + /// + /// ``` + /// # use ort::{Allocator, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let mut tensor = Tensor::::from_array((vec![5], vec![0, 1, 2, 3, 4]))?; + /// let ptr = tensor.data_ptr_mut()?.cast::(); + /// unsafe { + /// *ptr.add(3) = 42; + /// }; + /// + /// let (_, extracted) = tensor.extract_raw_tensor(); + /// assert_eq!(&extracted, &[0, 1, 2, 42, 4]); + /// # Ok(()) + /// # } + /// ``` pub fn data_ptr_mut(&mut self) -> Result<*mut ort_sys::c_void> { let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)]; Ok(buffer_ptr) } - /// Returns a pointer to the tensor's data. + /// Returns an immutable pointer to the tensor's underlying data. + /// + /// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a + /// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be + /// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access. + /// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before + /// accessing it. 
+ /// + /// ``` + /// # use ort::{Allocator, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::from_array((vec![5], vec![0, 1, 2, 3, 4]))?; + /// let ptr = tensor.data_ptr()?.cast::(); + /// assert_eq!(unsafe { *ptr.add(3) }, 3); + /// # Ok(()) + /// # } + /// ``` pub fn data_ptr(&self) -> Result<*const ort_sys::c_void> { let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)]; @@ -53,6 +108,26 @@ impl Value { } /// Returns information about the device this tensor is allocated on. + /// + /// ``` + /// # use ort::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType, Session, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::new(&Allocator::default(), [1, 3, 224, 224])?; + /// // Tensors are allocated on CPU by default. + /// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CPU); + /// + /// # if false { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let cuda_allocator = Allocator::new( + /// &session, + /// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? + /// )?; + /// let tensor = Tensor::::new(&cuda_allocator, [1, 3, 224, 224])?; + /// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CUDA); + /// # } + /// # Ok(()) + /// # } + /// ``` pub fn memory_info(&self) -> Result { let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut(); ortsys![unsafe GetTensorMemoryInfo(self.ptr(), &mut memory_info_ptr) -> Error::GetTensorMemoryInfo; nonNull(memory_info_ptr)]; @@ -62,29 +137,68 @@ impl Value { impl Tensor { /// Converts from a strongly-typed [`Tensor`] to a type-erased [`DynTensor`]. 
+ /// + /// ``` + /// # use ort::{Allocator, DynTensor, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::new(&Allocator::default(), [1, 3, 224, 224])?; + /// let tensor_dyn = tensor.upcast(); + /// assert!(tensor_dyn.try_extract_raw_tensor::().is_ok()); + /// assert!(tensor_dyn.try_extract_raw_tensor::().is_err()); + /// # Ok(()) + /// # } + /// ``` #[inline] pub fn upcast(self) -> DynTensor { unsafe { std::mem::transmute(self) } } - /// Converts from a strongly-typed [`Tensor`] to a reference to a type-erased [`DynTensor`]. + /// Creates a type-erased [`DynTensorRef`] from a strongly-typed [`Tensor`]. + /// + /// ``` + /// # use ort::{Allocator, DynTensor, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::new(&Allocator::default(), [1, 3, 224, 224])?; + /// let tensor_dyn = tensor.upcast_ref(); + /// + /// let (_, original_extract) = tensor.extract_raw_tensor(); + /// let (_, ref_extract) = tensor_dyn.try_extract_raw_tensor::()?; + /// assert_eq!(original_extract, ref_extract); + /// # Ok(()) + /// # } + /// ``` #[inline] pub fn upcast_ref(&self) -> DynTensorRef { DynTensorRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } /// Converts from a strongly-typed [`Tensor`] to a mutable reference to a type-erased [`DynTensor`]. 
+ /// + /// ``` + /// # use ort::{Allocator, DynTensor, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let mut tensor = Tensor::::from_array((vec![5], vec![1, 2, 3, 4, 5]))?; + /// let mut tensor_dyn = tensor.upcast_mut(); + /// + /// let (_, mut_view) = tensor_dyn.try_extract_raw_tensor_mut::()?; + /// mut_view[3] = 0; + /// + /// let (_, original_view) = tensor.extract_raw_tensor(); + /// assert_eq!(original_view, &[1, 2, 3, 0, 5]); + /// # Ok(()) + /// # } + /// ``` #[inline] pub fn upcast_mut(&mut self) -> DynTensorRefMut { DynTensorRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } @@ -97,6 +211,8 @@ impl DowncastableTarget for TensorValueType _ => false } } + + crate::private_impl!(); } impl From>> for DynValue { @@ -113,6 +229,17 @@ impl From> for DynValue { impl Index<[i64; N]> for Tensor { type Output = T; fn index(&self, index: [i64; N]) -> &Self::Output { + // Interestingly, the `TensorAt` API doesn't check if the tensor is on CPU, so we have to perform the check ourselves. 
+ if !self + .memory_info() + .expect("could not retrieve tensor memory info") + .allocation_device() + .expect("could not retrieve tensor allocation device") + .is_cpu_accessible() + { + panic!("Cannot directly index a tensor which is not allocated on the CPU."); + } + let mut out: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")]; unsafe { &*out.cast::() } @@ -120,25 +247,46 @@ impl Index<[i64; N]> f } impl IndexMut<[i64; N]> for Tensor { fn index_mut(&mut self, index: [i64; N]) -> &mut Self::Output { + if !self + .memory_info() + .expect("could not retrieve tensor memory info") + .allocation_device() + .expect("could not retrieve tensor allocation device") + .is_cpu_accessible() + { + panic!("Cannot directly index a tensor which is not allocated on the CPU."); + } + let mut out: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")]; unsafe { &mut *out.cast::() } } } +pub(crate) fn calculate_tensor_size(shape: &[i64]) -> usize { + let mut size = 1usize; + for dim in shape { + if *dim < 0 { + return 0; + } + size *= *dim as usize; + } + size +} + #[cfg(test)] mod tests { use std::sync::Arc; use ndarray::{ArcArray1, Array1, CowArray}; - use crate::{Allocator, DynTensor, TensorElementType, Value, ValueType}; + use crate::{Tensor, TensorElementType, ValueType}; #[test] #[cfg(feature = "ndarray")] fn test_tensor_value() -> crate::Result<()> { let v: Vec = vec![1., 2., 3., 4., 5.]; - let value = Value::from_array(Array1::from_vec(v.clone()))?; + let value = Tensor::from_array(Array1::from_vec(v.clone()))?; assert!(value.is_tensor()?); assert_eq!(value.dtype()?.tensor_type(), Some(TensorElementType::Float32)); assert_eq!( @@ -163,17 +311,17 @@ mod tests { let arc1 = ArcArray1::from_vec(v.clone()); let mut arc2 = ArcArray1::clone(&arc1); - let value = Value::from_array(&mut 
arc2)?; + let value = Tensor::from_array(&mut arc2)?; drop((arc1, arc2)); assert_eq!(value.extract_raw_tensor().1, &v); let cow = CowArray::from(Array1::from_vec(v.clone())); - let value = Value::from_array(&cow)?; + let value = Tensor::from_array(&cow)?; assert_eq!(value.extract_raw_tensor().1, &v); let owned = Array1::from_vec(v.clone()); - let value = Value::from_array(owned.view())?; + let value = Tensor::from_array(owned.view())?; drop(owned); assert_eq!(value.extract_raw_tensor().1, &v); @@ -186,7 +334,7 @@ mod tests { let arc = Arc::new(v.clone().into_boxed_slice()); let shape = vec![v.len() as i64]; - let value = Value::from_array((shape, Arc::clone(&arc)))?; + let value = Tensor::from_array((shape, Arc::clone(&arc)))?; drop(arc); assert_eq!(value.try_extract_raw_tensor::()?.1, &v); @@ -196,10 +344,9 @@ mod tests { #[test] #[cfg(feature = "ndarray")] fn test_string_tensor_ndarray() -> crate::Result<()> { - let allocator = Allocator::default(); let v = Array1::from_vec(vec!["hello world".to_string(), "こんにちは世界".to_string()]); - let value = DynTensor::from_string_array(&allocator, v.view())?; + let value = Tensor::from_string_array(v.view())?; let extracted = value.try_extract_string_tensor()?; assert_eq!(extracted, v.into_dyn()); @@ -208,10 +355,9 @@ mod tests { #[test] fn test_string_tensor_raw() -> crate::Result<()> { - let allocator = Allocator::default(); let v = vec!["hello world".to_string(), "こんにちは世界".to_string()]; - let value = DynTensor::from_string_array(&allocator, (vec![v.len() as i64], v.clone().into_boxed_slice()))?; + let value = Tensor::from_string_array((vec![v.len() as i64], v.clone().into_boxed_slice()))?; let (extracted_shape, extracted_view) = value.try_extract_raw_string_tensor()?; assert_eq!(extracted_shape, [v.len() as i64]); assert_eq!(extracted_view, v); @@ -224,10 +370,10 @@ mod tests { let v: Vec = vec![1., 2., 3., 4., 5.]; let shape = [v.len()]; - let value_arc_box = Value::from_array((shape, 
Arc::new(v.clone().into_boxed_slice())))?; - let value_box = Value::from_array((shape, v.clone().into_boxed_slice()))?; - let value_vec = Value::from_array((shape, v.clone()))?; - let value_slice = Value::from_array((shape, &v[..]))?; + let value_arc_box = Tensor::from_array((shape, Arc::new(v.clone().into_boxed_slice())))?; + let value_box = Tensor::from_array((shape, v.clone().into_boxed_slice()))?; + let value_vec = Tensor::from_array((shape, v.clone()))?; + let value_slice = Tensor::from_array((shape, &v[..]))?; assert_eq!(value_arc_box.extract_raw_tensor().1, &v); assert_eq!(value_box.extract_raw_tensor().1, &v); diff --git a/src/value/mod.rs b/src/value/mod.rs index 868403f..cf50600 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -166,6 +166,14 @@ pub(crate) enum ValueInner { } } +impl ValueInner { + pub(crate) fn ptr(&self) -> *mut ort_sys::OrtValue { + match self { + ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr() + } + } +} + /// A temporary version of a [`Value`] with a lifetime specifier. #[derive(Debug)] pub struct ValueRef<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> { @@ -277,8 +285,8 @@ impl<'v, Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'v, Type> { /// - [`Tensor::extract_tensor`], [`Tensor::extract_raw_tensor`] #[derive(Debug)] pub struct Value { - inner: ValueInner, - _markers: PhantomData + pub(crate) inner: Arc, + pub(crate) _markers: PhantomData } /// A dynamic value, which could be a [`Tensor`], [`Sequence`], or [`Map`]. @@ -291,11 +299,15 @@ pub type DynValue = Value; /// /// For example, [`Tensor::try_extract_tensor`] can only be used on [`Value`]s with the [`TensorValueTypeMarker`] (which /// inherits this trait), i.e. [`Tensor`]s, [`DynTensor`]s, and [`DynValue`]s. -pub trait ValueTypeMarker: Debug {} +pub trait ValueTypeMarker: Debug { + crate::private_trait!(); +} /// Represents a type that a [`DynValue`] can be downcast to. 
pub trait DowncastableTarget: ValueTypeMarker { fn can_downcast(dtype: &ValueType) -> bool; + + crate::private_trait!(); } // this implementation is used in case we want to extract `DynValue`s from a [`Sequence`]; see `try_extract_sequence` @@ -303,15 +315,25 @@ impl DowncastableTarget for DynValueTypeMarker { fn can_downcast(_: &ValueType) -> bool { true } + + crate::private_impl!(); } /// The dynamic type marker, used for values which can be of any type. #[derive(Debug)] pub struct DynValueTypeMarker; -impl ValueTypeMarker for DynValueTypeMarker {} -impl MapValueTypeMarker for DynValueTypeMarker {} -impl SequenceValueTypeMarker for DynValueTypeMarker {} -impl TensorValueTypeMarker for DynValueTypeMarker {} +impl ValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} +impl MapValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} +impl SequenceValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} +impl TensorValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} unsafe impl Send for Value {} @@ -350,7 +372,7 @@ impl Value { /// /// If the value belongs to a session (i.e. if it is returned from [`crate::Session::run`] or /// [`crate::IoBinding::run`]), you must provide the [`SharedSessionInner`] (acquired from - [`crate::Session::inner`]). This ensures the session is not dropped until the value is. + [`crate::Session::inner`]). This ensures the session is not dropped until any values owned by it are.
/// /// # Safety /// @@ -359,7 +381,7 @@ impl Value { #[must_use] pub unsafe fn from_ptr(ptr: NonNull, session: Option>) -> Value { Value { - inner: ValueInner::CppOwned { ptr, drop: true, _session: session }, + inner: Arc::new(ValueInner::CppOwned { ptr, drop: true, _session: session }), _markers: PhantomData } } @@ -369,16 +391,14 @@ impl Value { #[must_use] pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull, session: Option>) -> Value { Value { - inner: ValueInner::CppOwned { ptr, drop: false, _session: session }, + inner: Arc::new(ValueInner::CppOwned { ptr, drop: false, _session: session }), _markers: PhantomData } } /// Returns the underlying [`ort_sys::OrtValue`] pointer. pub fn ptr(&self) -> *mut ort_sys::OrtValue { - match &self.inner { - ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr() - } + self.inner.ptr() } /// Create a view of this value's data. @@ -386,7 +406,7 @@ impl Value { ValueRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } @@ -396,7 +416,7 @@ impl Value { ValueRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } @@ -442,7 +462,7 @@ impl Value { Ok(ValueRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. 
} = &*self.inner { _session.clone() } else { None } ) })) } else { @@ -450,7 +470,7 @@ impl Value { } } - /// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed + /// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed /// mutable-reference variant, like [`TensorRefMut`]. #[inline] pub fn downcast_mut(&mut self) -> Result> { @@ -459,7 +479,7 @@ impl Value { Ok(ValueRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) })) } else { @@ -468,17 +488,17 @@ impl Value { } } -impl Drop for Value { +impl Drop for ValueInner { fn drop(&mut self) { let ptr = self.ptr(); tracing::trace!( "dropping {} value at {ptr:p}", - match &self.inner { + match self { ValueInner::RustOwned { .. } => "rust-owned", ValueInner::CppOwned { .. } => "cpp-owned" } ); - if !matches!(&self.inner, ValueInner::CppOwned { drop: false, .. }) { + if !matches!(self, ValueInner::CppOwned { drop: false, .. }) { ortsys![unsafe ReleaseValue(ptr)]; } } diff --git a/src/wasm.rs b/src/wasm.rs index 51cf82f..235a219 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -1,6 +1,6 @@ //! Utilities for using `ort` in WebAssembly. //! -//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs: +//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs in WASM: //! ``` //! # use ort::Session; //! 
# static MODEL_BYTES: &[u8] = include_bytes!("../tests/data/upsample.ort"); @@ -223,12 +223,12 @@ mod emscripten_shims { #[no_mangle] #[export_name = "_initialize"] pub fn initialize() { - // No idea what the hell this does, but the presence of an `_initialize` function prevents the linker from calling - // `__wasm_call_ctors` at the top of every function - including the functions `wasm-bindgen` interprets to generate - // JS glue code. The `__wasm_call_ctors` call was calling complex functions that the interpreter isn't equipped to - // handle, which was preventing wbg from outputting anything. I don't know what specific constructors this is calling, - // and most basic ONNX Runtime APIs *do* work without calling this, but we encourage the user to perform this - // initialization at program start anyways to be safe. + // The presence of an `_initialize` function prevents the linker from calling `__wasm_call_ctors` at the top of every + // function - including the functions `wasm-bindgen` interprets to generate JS glue code. `__wasm_call_ctors` calls + // complex functions that wbg's interpreter isn't equipped to handle, which was preventing wbg from outputting + // anything. + // I'm not entirely sure what `__wasm_call_ctors` is initializing, but it seems to have something to do with C++ + // vtables, and it's crucial for proper operation. 
extern "C" { fn __wasm_call_ctors(); } diff --git a/tests/vectorizer.rs b/tests/vectorizer.rs index bef1a57..f3af20a 100644 --- a/tests/vectorizer.rs +++ b/tests/vectorizer.rs @@ -3,7 +3,7 @@ use std::path::Path; use ndarray::{ArrayD, IxDyn}; -use ort::{inputs, DynTensor, GraphOptimizationLevel, Session}; +use ort::{inputs, GraphOptimizationLevel, Session, Tensor}; use test_log::test; #[test] @@ -22,7 +22,7 @@ fn vectorizer() -> ort::Result<()> { let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap()); // Just one input - let input_tensor_values = inputs![DynTensor::from_string_array(session.allocator(), &array)?]?; + let input_tensor_values = inputs![Tensor::from_string_array(&array)?]?; // Perform the inference let outputs = session.run(input_tensor_values)?;