examples: add all-mini-lm-l6 example (#243)

This commit is contained in:
Nigel Christian
2024-07-21 16:47:24 -04:00
committed by GitHub
parent 9282c3b3fd
commit 1a10b11b10
5 changed files with 64 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
# Manifest for the all-MiniLM-L6 sentence-embedding example crate.
[package]
# Example only — never publish this crate to crates.io.
publish = false
name = "example-all-mini-lm-l6"
version = "0.0.0"
edition = "2021"
[dependencies]
# `ort` from the repository root; `fetch-models` lets a session download an ONNX model from a URL.
ort = { path = "../../", features = [ "fetch-models" ] }
ndarray = "0.15"
# Minimal tokenizers build: default features off, `onig` regex backend only.
tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
# Console logging for `ort`'s tracing output; `env-filter` allows RUST_LOG-style filtering.
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
[features]
# Pass-through feature flags so the example can opt into ort's dynamic loading / CUDA support.
load-dynamic = [ "ort/load-dynamic" ]
cuda = [ "ort/cuda" ]

View File

@@ -0,0 +1,5 @@
/// Build script: emits an extra linker argument required on macOS.
///
/// CoreML needs the Apple runtime library linked explicitly.
/// See: https://ort.pyke.io/perf/execution-providers#coreml
fn main() {
    if cfg!(target_os = "macos") {
        println!("cargo:rustc-link-arg=-fapple-link-rtlib");
    }
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,41 @@
use std::path::Path;
use ndarray::{Array1, Axis};
use ort::{CUDAExecutionProvider, GraphOptimizationLevel, Session};
use tokenizers::Tokenizer;
/// `all-MiniLM-L6-v2` embeddings generation.
///
/// This is a sentence-transformers model: it maps sentences & paragraphs to a
/// 384-dimensional dense vector space and can be used for tasks like clustering
/// or semantic search.
fn main() -> ort::Result<()> {
    // Forward `ort`'s tracing/debug messages to the console.
    tracing_subscriber::fmt::init();

    // Set up the process-wide ONNX Runtime environment; every session created
    // afterwards may use the CUDA execution provider.
    ort::init()
        .with_name("all-Mini-LM-L6")
        .with_execution_providers([CUDAExecutionProvider::default().build()])
        .commit()?;

    // Build the session, fetching the model from Hugging Face on first use.
    let session = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level1)?
        .with_intra_threads(1)?
        .commit_from_url("https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx")?;

    // Load the tokenizer shipped next to this example and encode the input text.
    let tokenizer_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json");
    let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
    let encoding = tokenizer.encode("test", false)?;

    // Widen token ids and attention mask to i64, then add a leading batch axis.
    let ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
    let mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&m| m as i64).collect();
    let ids_array = Array1::from_vec(ids);
    let mask_array = Array1::from_vec(mask);
    let input_ids = ids_array.view().insert_axis(Axis(0));
    let input_mask = mask_array.view().insert_axis(Axis(0));

    // Run inference and print the second output as an f32 tensor (Debug form).
    let outputs = session.run(ort::inputs![input_ids, input_mask]?)?;
    let tensor = outputs[1].try_extract_tensor::<f32>();
    println!("{:?}", tensor);

    Ok(())
}