diff --git a/docs/pages/_meta.json b/docs/pages/_meta.json index 7eb8b61..5cc642d 100644 --- a/docs/pages/_meta.json +++ b/docs/pages/_meta.json @@ -25,6 +25,9 @@ "setup": { "title": "Setup" }, + "fundamentals": { + "title": "Fundamentals" + }, "perf": { "title": "Performance" }, diff --git a/docs/pages/fundamentals/value.mdx b/docs/pages/fundamentals/value.mdx new file mode 100644 index 0000000..1f4812e --- /dev/null +++ b/docs/pages/fundamentals/value.mdx @@ -0,0 +1,104 @@ +--- +title: 'Value' +--- + +import { Callout } from 'nextra/components'; + +For ONNX Runtime, a **value** represents any type that can be given to/returned from a session or operator. Values come in three main types: +- **Tensors** (multi-dimensional arrays). This is the most common type of `Value`. +- **Maps** map a key type to a value type, similar to Rust's `HashMap`. +- **Sequences** are homogenously-typed dynamically-sized lists, similar to Rust's `Vec`. The only values allowed in sequences are tensors, or maps of tensors. + +In order to actually use the data in these containers, you can use the `.try_extract_*` methods. `try_extract_tensor(_mut)` extracts an `ndarray::ArrayView(Mut)` from the value if it is a tensor. `try_extract_sequence` returns a `Vec` of values, and `try_extract_map` returns a `HashMap`. + +Sessions in `ort` return a map of `DynValue`s. You can determine a value's type via its `.dtype()` method. You can also use fallible methods to extract data from this value - for example, [`DynValue::try_extract_tensor`](https://ort.pyke.io/rustdoc/ort/type.DynValue.html#method.try_extract_tensor), which fails if the value is not a tensor. Often times though, you'll want to reuse the same value which you are certain is a tensor - in which case, you can **downcast** the value. + +## Downcasting +**Downcasting** means to convert a `Dyn` type like `DynValue` to stronger type like `DynTensor`. Downcasting can be performed using the `.downcast()` function on `DynValue`: +```rs +let value: ort::DynValue = outputs.remove("output0").unwrap(); + +let dyn_tensor: ort::DynTensor = value.downcast()?; +``` + +If `value` is not actually a tensor, the `downcast()` call will fail. + +`DynTensor` allows you to use + +### Stronger types +`DynTensor` means that the type **is** a tensor, but the *element type is unknown*. There are also `DynSequence`s and `DynMap`s, which have the same meaning - the element/key/value types are unknown. + +The strongly typed variants of these types - `Tensor`, `Sequence`, and `Map`, can be directly downcasted to, too: +```rs +let dyn_value: ort::DynValue = outputs.remove("output0").unwrap(); + +let f32_tensor: ort::Tensor = dyn_value.downcast()?; +``` + +If `value` is not a tensor, **or** if the element type of the value does not match what was requested (`f32`), the `downcast()` call will fail. + +Stronger typed values have infallible variants of the `.try_extract_*` methods: +```rs +// We could try to extract a tensor directly from a `DynValue`... +let f32_array: ArrayViewD = dyn_value.try_extract_tensor()?; + +// Or, we can first onvert it to a tensor, and then extract afterwards: +let tensor: ort::Tensor = dyn_value.downcast()?; +let f32_array = tensor.extract_tensor(); // no `?` required, this will never fail! +``` + +## Upcasting +**Upcasting** means to convert a strongly-typed value type like `Tensor` to a weaker type like `DynTensor` or `DynValue`. This can be useful if you have code that stores values of different types, e.g. in a `HashMap`. + +Strongly-typed value types like `Tensor` can be converted into a `DynTensor` using `.upcast()`: +```rs +let dyn_tensor = f32_tensor.upcast(); +// type is DynTensor +``` + +`Tensor` or `DynTensor` can be cast to a `DynValue` by using `.into_dyn()`: +```rs +let dyn_value = f32_tensor.into_dyn(); +// type is DynValue +``` + +Upcasting a value doesn't change its underlying type; it just removes the specialization. You cannot, for example, upcast a `Tensor` to a `DynValue` and then downcast it to a `Sequence`; it's still a `Tensor`, just contained in a different type. + +## Conversion recap +- `DynValue` represents a value that can be any type - tensor, sequence, or map. The type can be retrieved with `.dtype()`. +- `DynTensor`, `DynMap`, and `DynSequence` are values with known container types, but unknown element types. +- `Tensor`, `Map`, and `Sequence` are values with known container and element types. +- `Tensor` and co. can be converted from/to their dyn types using `.downcast()`/`.upcast()`, respectively. +- `Tensor`/`DynTensor` and co. can be converted to `DynValue`s using `.into_dyn()`. + +An illustration of the relationship between value types as described above, used for visualization purposes. + + + Note that `DynTensor` cannot be downcast to `Tensor`, but `DynTensor` can be upcast to `DynValue` with `.into_dyn()`, and then downcast to `Tensor` with `.downcast()`. + + Downcasts are cheap, as they only check the value's type. Upcasts compile to a no-op. + + +## Views +A view (also called a ref) is functionally a borrowed variant of a value. There are also mutable views, which are equivalent to mutably borrowed values. Views are represented as separate structs so that they can be down/upcasted. + +View types are suffixed with `Ref` or `RefMut` for shared/mutable variants respectively: +- Tensors have `DynTensorRef(Mut)` and `TensorRef(Mut)`. +- Maps have `DynMapRef(Mut)` and `MapRef(Mut)`. +- Sequences have `DynSequenceRef(Mut)` and `SequenceRef(Mut)`. + +These views can be acquired with `.view()` or `.view_mut()` on a value type: +```rs +let my_tensor: ort::Tensor = Tensor::new(...)?; + +let tensor_view: ort::TensorRef<'_, f32> = my_tensor.view(); +``` + +Views act identically to a borrow of their type - `TensorRef` supports `extract_tensor`, `TensorRefMut` supports `extract_tensor_mut`. The same is true for sequences & maps. Views also support down/upcasting via `.downcast()` & `.into_dyn()` (but not `.upcast()` at the moment). + +You can also directly downcast a value to a stronger-typed view using `.downcast_ref()` and `.downcast_mut()`: +```rs +let tensor_view: ort::TensorRef<'_, f32> = dyn_value.downcast_ref()?; +// is equivalent to +let tensor_view: ort::TensorRef<'_, f32> = dyn_value.view().downcast()?; +``` diff --git a/docs/public/assets/casting-map.png b/docs/public/assets/casting-map.png new file mode 100644 index 0000000..c7e2b06 Binary files /dev/null and b/docs/public/assets/casting-map.png differ