mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
docs: explain up/downcasting
This commit is contained in:
@@ -25,6 +25,9 @@
|
||||
"setup": {
|
||||
"title": "Setup"
|
||||
},
|
||||
"fundamentals": {
|
||||
"title": "Fundamentals"
|
||||
},
|
||||
"perf": {
|
||||
"title": "Performance"
|
||||
},
|
||||
|
||||
104
docs/pages/fundamentals/value.mdx
Normal file
104
docs/pages/fundamentals/value.mdx
Normal file
@@ -0,0 +1,104 @@
|
||||
---
|
||||
title: 'Value'
|
||||
---
|
||||
|
||||
import { Callout } from 'nextra/components';
|
||||
|
||||
For ONNX Runtime, a **value** represents any type that can be given to/returned from a session or operator. Values come in three main types:
|
||||
- **Tensors** (multi-dimensional arrays). This is the most common type of `Value`.
|
||||
- **Maps** map a key type to a value type, similar to Rust's `HashMap<K, V>`.
|
||||
- **Sequences** are homogenously-typed dynamically-sized lists, similar to Rust's `Vec<T>`. The only values allowed in sequences are tensors, or maps of tensors.
|
||||
|
||||
In order to actually use the data in these containers, you can use the `.try_extract_*` methods. `try_extract_tensor(_mut)` extracts an `ndarray::ArrayView(Mut)` from the value if it is a tensor. `try_extract_sequence` returns a `Vec` of values, and `try_extract_map` returns a `HashMap`.
|
||||
|
||||
Sessions in `ort` return a map of `DynValue`s. You can determine a value's type via its `.dtype()` method. You can also use fallible methods to extract data from this value - for example, [`DynValue::try_extract_tensor`](https://ort.pyke.io/rustdoc/ort/type.DynValue.html#method.try_extract_tensor), which fails if the value is not a tensor. Often times though, you'll want to reuse the same value which you are certain is a tensor - in which case, you can **downcast** the value.
|
||||
|
||||
## Downcasting
|
||||
**Downcasting** means to convert a `Dyn` type like `DynValue` to stronger type like `DynTensor`. Downcasting can be performed using the `.downcast()` function on `DynValue`:
|
||||
```rs
|
||||
let value: ort::DynValue = outputs.remove("output0").unwrap();
|
||||
|
||||
let dyn_tensor: ort::DynTensor = value.downcast()?;
|
||||
```
|
||||
|
||||
If `value` is not actually a tensor, the `downcast()` call will fail.
|
||||
|
||||
`DynTensor` allows you to use
|
||||
|
||||
### Stronger types
|
||||
`DynTensor` means that the type **is** a tensor, but the *element type is unknown*. There are also `DynSequence`s and `DynMap`s, which have the same meaning - the element/key/value types are unknown.
|
||||
|
||||
The strongly typed variants of these types - `Tensor<T>`, `Sequence<T>`, and `Map<K, V>`, can be directly downcasted to, too:
|
||||
```rs
|
||||
let dyn_value: ort::DynValue = outputs.remove("output0").unwrap();
|
||||
|
||||
let f32_tensor: ort::Tensor<f32> = dyn_value.downcast()?;
|
||||
```
|
||||
|
||||
If `value` is not a tensor, **or** if the element type of the value does not match what was requested (`f32`), the `downcast()` call will fail.
|
||||
|
||||
Stronger typed values have infallible variants of the `.try_extract_*` methods:
|
||||
```rs
|
||||
// We could try to extract a tensor directly from a `DynValue`...
|
||||
let f32_array: ArrayViewD<f32> = dyn_value.try_extract_tensor()?;
|
||||
|
||||
// Or, we can first onvert it to a tensor, and then extract afterwards:
|
||||
let tensor: ort::Tensor<f32> = dyn_value.downcast()?;
|
||||
let f32_array = tensor.extract_tensor(); // no `?` required, this will never fail!
|
||||
```
|
||||
|
||||
## Upcasting
|
||||
**Upcasting** means to convert a strongly-typed value type like `Tensor<f32>` to a weaker type like `DynTensor` or `DynValue`. This can be useful if you have code that stores values of different types, e.g. in a `HashMap<String, DynValue>`.
|
||||
|
||||
Strongly-typed value types like `Tensor<f32>` can be converted into a `DynTensor` using `.upcast()`:
|
||||
```rs
|
||||
let dyn_tensor = f32_tensor.upcast();
|
||||
// type is DynTensor
|
||||
```
|
||||
|
||||
`Tensor<f32>` or `DynTensor` can be cast to a `DynValue` by using `.into_dyn()`:
|
||||
```rs
|
||||
let dyn_value = f32_tensor.into_dyn();
|
||||
// type is DynValue
|
||||
```
|
||||
|
||||
Upcasting a value doesn't change its underlying type; it just removes the specialization. You cannot, for example, upcast a `Tensor<f32>` to a `DynValue` and then downcast it to a `Sequence`; it's still a `Tensor<f32>`, just contained in a different type.
|
||||
|
||||
## Conversion recap
|
||||
- `DynValue` represents a value that can be any type - tensor, sequence, or map. The type can be retrieved with `.dtype()`.
|
||||
- `DynTensor`, `DynMap`, and `DynSequence` are values with known container types, but unknown element types.
|
||||
- `Tensor<T>`, `Map<K, V>`, and `Sequence<T>` are values with known container and element types.
|
||||
- `Tensor<T>` and co. can be converted from/to their dyn types using `.downcast()`/`.upcast()`, respectively.
|
||||
- `Tensor<T>`/`DynTensor` and co. can be converted to `DynValue`s using `.into_dyn()`.
|
||||
|
||||
<img width="100%" src="/assets/casting-map.png" alt="An illustration of the relationship between value types as described above, used for visualization purposes." />
|
||||
|
||||
<Callout type='info'>
|
||||
Note that `DynTensor` cannot be downcast to `Tensor<T>`, but `DynTensor` can be upcast to `DynValue` with `.into_dyn()`, and then downcast to `Tensor<T>` with `.downcast()`.
|
||||
|
||||
Downcasts are cheap, as they only check the value's type. Upcasts compile to a no-op.
|
||||
</Callout>
|
||||
|
||||
## Views
|
||||
A view (also called a ref) is functionally a borrowed variant of a value. There are also mutable views, which are equivalent to mutably borrowed values. Views are represented as separate structs so that they can be down/upcasted.
|
||||
|
||||
View types are suffixed with `Ref` or `RefMut` for shared/mutable variants respectively:
|
||||
- Tensors have `DynTensorRef(Mut)` and `TensorRef(Mut)`.
|
||||
- Maps have `DynMapRef(Mut)` and `MapRef(Mut)`.
|
||||
- Sequences have `DynSequenceRef(Mut)` and `SequenceRef(Mut)`.
|
||||
|
||||
These views can be acquired with `.view()` or `.view_mut()` on a value type:
|
||||
```rs
|
||||
let my_tensor: ort::Tensor<f32> = Tensor::new(...)?;
|
||||
|
||||
let tensor_view: ort::TensorRef<'_, f32> = my_tensor.view();
|
||||
```
|
||||
|
||||
Views act identically to a borrow of their type - `TensorRef` supports `extract_tensor`, `TensorRefMut` supports `extract_tensor_mut`. The same is true for sequences & maps. Views also support down/upcasting via `.downcast()` & `.into_dyn()` (but not `.upcast()` at the moment).
|
||||
|
||||
You can also directly downcast a value to a stronger-typed view using `.downcast_ref()` and `.downcast_mut()`:
|
||||
```rs
|
||||
let tensor_view: ort::TensorRef<'_, f32> = dyn_value.downcast_ref()?;
|
||||
// is equivalent to
|
||||
let tensor_view: ort::TensorRef<'_, f32> = dyn_value.view().downcast()?;
|
||||
```
|
||||
BIN
docs/public/assets/casting-map.png
Normal file
BIN
docs/public/assets/casting-map.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 130 KiB |
Reference in New Issue
Block a user