mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor!: rename extract_tensor -> extract_array, extract_raw_tensor -> extract_tensor
...and `extract_raw_map` to `extract_key_values`, but nobody cares about that
This commit is contained in:
@@ -52,7 +52,7 @@ fn mnist_5() -> ort::Result<()> {
|
||||
let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?;
|
||||
|
||||
let mut probabilities: Vec<(usize, f32)> = outputs[0]
|
||||
.try_extract_tensor()?
|
||||
.try_extract_array()?
|
||||
.softmax(ndarray::Axis(1))
|
||||
.iter()
|
||||
.copied()
|
||||
|
||||
@@ -78,7 +78,7 @@ fn squeezenet_mushroom() -> ort::Result<()> {
|
||||
// Downloaded model does not have a softmax as final layer; call softmax on second axis
|
||||
// and iterate on resulting probabilities, creating an index to later access labels.
|
||||
let mut probabilities: Vec<(usize, f32)> = outputs[0]
|
||||
.try_extract_tensor()?
|
||||
.try_extract_array()?
|
||||
.softmax(ndarray::Axis(1))
|
||||
.iter()
|
||||
.copied()
|
||||
|
||||
@@ -76,7 +76,7 @@ fn upsample() -> ort::Result<()> {
|
||||
let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?;
|
||||
|
||||
assert_eq!(outputs.len(), 1);
|
||||
let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;
|
||||
let output: ArrayViewD<f32> = outputs[0].try_extract_array()?;
|
||||
|
||||
// The image should have doubled in size
|
||||
assert_eq!(output.shape(), [1, 448, 448, 3]);
|
||||
@@ -113,7 +113,7 @@ fn upsample_with_ort_model() -> ort::Result<()> {
|
||||
let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?;
|
||||
|
||||
assert_eq!(outputs.len(), 1);
|
||||
let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;
|
||||
let output: ArrayViewD<f32> = outputs[0].try_extract_array()?;
|
||||
|
||||
// The image should have doubled in size
|
||||
assert_eq!(output.shape(), [1, 448, 448, 3]);
|
||||
|
||||
@@ -28,7 +28,7 @@ fn vectorizer() -> ort::Result<()> {
|
||||
let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap());
|
||||
|
||||
let outputs = session.run(inputs![Tensor::from_string_array(&array)?])?;
|
||||
assert_eq!(outputs[0].try_extract_tensor::<f32>()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap());
|
||||
assert_eq!(outputs[0].try_extract_array::<f32>()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user