mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
examples(phi-3-vision): Simplify input processing with KV cache (#296)
With the KV cache, reconstructing the full input sequence is unnecessary. Each iteration only needs to process the newly generated token.
This commit is contained in:
@@ -35,8 +35,8 @@ This example currently only supports single image input.
|
||||
The performance of ONNX-based LLM inference can be relatively slow, especially on CPU:
|
||||
|
||||
- On an Apple M1 Pro:
|
||||
- For image+text input (about 300 tokens): ~5 seconds per output token
|
||||
- For text-only input (about 10 tokens): ~200ms per output token
|
||||
- For image+text input (about 300 tokens): ~7 tokens/s
|
||||
- For text-only input (about 10 tokens): ~5 tokens/s
|
||||
|
||||
## Run this Example
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ const VISION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-vision.onnx";
|
||||
const TEXT_EMBEDDING_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text-embedding.onnx";
|
||||
const GENERATION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text.onnx";
|
||||
|
||||
const MAX_LENGTH: usize = 100; // max length of the generated text
|
||||
const MAX_LENGTH: usize = 1000; // max length of the generated text
|
||||
const EOS_TOKEN_ID: i64 = 32007; // <|end|>
|
||||
const USER_TOKEN_ID: i64 = 32010; // <|user|>
|
||||
const VOCAB_SIZE: usize = 32064;
|
||||
@@ -108,7 +108,7 @@ pub async fn generate_text(
|
||||
image: &Option<DynamicImage>,
|
||||
text: &str
|
||||
) -> Result<()> {
|
||||
let (mut inputs_embeds, mut attention_mask) = {
|
||||
let (inputs_embeds, mut attention_mask) = {
|
||||
let visual_features = get_image_embedding(&vision_model, &image)?;
|
||||
let prompt = format_chat_template(&image, text);
|
||||
let encoding = tokenizer.encode(prompt, true).map_err(|e| anyhow::anyhow!("Error encoding: {:?}", e))?;
|
||||
@@ -139,12 +139,13 @@ pub async fn generate_text(
|
||||
// 4. Head size (96)
|
||||
let mut past_key_values: Vec<Array4<f32>> = vec![Array4::zeros((1, 32, 0, 96)); 64];
|
||||
let mut generated_tokens: Vec<i64> = Vec::new();
|
||||
let mut next_inputs_embeds = inputs_embeds.clone();
|
||||
// Loop until <|end|> token is generated or max length is reached
|
||||
for _ in 0..MAX_LENGTH {
|
||||
// Prepare model inputs
|
||||
let model_inputs = {
|
||||
let mut model_inputs = ort::inputs![
|
||||
"inputs_embeds" => inputs_embeds.clone(),
|
||||
"inputs_embeds" => next_inputs_embeds.clone(),
|
||||
"attention_mask" => attention_mask.clone(),
|
||||
]?;
|
||||
for i in 0..32 {
|
||||
@@ -176,27 +177,21 @@ pub async fn generate_text(
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
.0 as i64;
|
||||
|
||||
if next_token_id == EOS_TOKEN_ID {
|
||||
break;
|
||||
}
|
||||
|
||||
generated_tokens.push(next_token_id);
|
||||
// Log the generated text
|
||||
let output_ids: Vec<u32> = generated_tokens.iter().map(|&id| id as u32).collect();
|
||||
let generated_text = tokenizer.decode(&output_ids, false).unwrap();
|
||||
tracing::info!("Generated text: {}", generated_text);
|
||||
|
||||
if next_token_id == EOS_TOKEN_ID {
|
||||
break;
|
||||
}
|
||||
|
||||
// Update inputs_embeds, attention_mask, and past_key_values for the next iteration
|
||||
(inputs_embeds, attention_mask) = {
|
||||
let new_token_id = Array2::from_elem((1, 1), next_token_id);
|
||||
let new_token_embed = get_text_embedding(&text_embedding_model, &new_token_id)?;
|
||||
// Merge the new token embedding with the previous embeddings
|
||||
let mut combined_embeds = Array3::zeros((inputs_embeds.shape()[0], inputs_embeds.shape()[1] + 1, inputs_embeds.shape()[2]));
|
||||
combined_embeds.slice_mut(s![.., ..inputs_embeds.shape()[1], ..]).assign(&inputs_embeds);
|
||||
combined_embeds.slice_mut(s![.., inputs_embeds.shape()[1].., ..]).assign(&new_token_embed);
|
||||
let new_attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
|
||||
(combined_embeds, new_attention_mask)
|
||||
};
|
||||
// Update next_inputs_embeds, attention_mask, and past_key_values for the next iteration
|
||||
let new_token_id = Array2::from_elem((1, 1), next_token_id);
|
||||
next_inputs_embeds = get_text_embedding(&text_embedding_model, &new_token_id)?;
|
||||
attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
|
||||
for i in 0..32 {
|
||||
past_key_values[i * 2] = model_outputs[format!("present.{}.key", i)]
|
||||
.try_extract_tensor::<f32>()?
|
||||
|
||||
Reference in New Issue
Block a user