examples(phi-3-vision): Simplify input processing with KV cache (#296)

With a KV cache, rebuilding the full input sequence on every iteration is unnecessary:
only the newly generated token needs to be embedded and processed at each step.
Author: web3nomad.eth
Date: 2024-10-16 00:02:09 +08:00 (committed by GitHub)
Parent: cdd6be7a66
Commit: 87dc4f21fa

2 changed files with 15 additions and 20 deletions
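
To make the change concrete, here is a minimal, self-contained sketch of the decoding pattern this commit moves to. It uses plain `ndarray` with the ONNX session mocked out; `mock_forward`, `HIDDEN`, and the tensor dtypes are assumptions for illustration, not the example's actual API. The point is the bookkeeping: the prompt is embedded once, and every later step feeds a single token's embedding while the attention mask and the per-layer key/value cache keep covering the full sequence.

```rust
use ndarray::{concatenate, Array2, Array3, Array4, Axis};

const HIDDEN: usize = 3072; // assumed hidden size, for illustration only
const HEADS: usize = 32;
const HEAD_SIZE: usize = 96;
const LAYERS: usize = 32;

/// Mock forward pass. The real example gets logits plus
/// `present.{i}.key` / `present.{i}.value` tensors from the ONNX session;
/// here we just grow the cache by the number of new positions and return
/// a fixed "next token" id.
fn mock_forward(
    inputs_embeds: &Array3<f32>,       // (batch, new_tokens, hidden)
    _attention_mask: &Array2<i64>,     // (batch, cache_len + new_tokens)
    past_key_values: Vec<Array4<f32>>, // (batch, heads, cache_len, head_size) each
) -> (i64, Vec<Array4<f32>>) {
    let new_len = inputs_embeds.shape()[1];
    let grown = past_key_values
        .into_iter()
        .map(|kv| {
            let fresh = Array4::<f32>::zeros((1, HEADS, new_len, HEAD_SIZE));
            concatenate(Axis(2), &[kv.view(), fresh.view()]).unwrap()
        })
        .collect();
    (42, grown)
}

fn main() {
    let prompt_len = 300;
    // 32 layers x (key, value); the sequence axis starts empty, exactly as in
    // the example's `vec![Array4::zeros((1, 32, 0, 96)); 64]`.
    let mut past_key_values: Vec<Array4<f32>> =
        vec![Array4::zeros((1, HEADS, 0, HEAD_SIZE)); 2 * LAYERS];
    let mut attention_mask = Array2::<i64>::ones((1, prompt_len));

    // The first step feeds the full prompt embeddings...
    let mut next_inputs_embeds = Array3::<f32>::zeros((1, prompt_len, HIDDEN));
    for _ in 0..3 {
        let (next_token_id, grown) =
            mock_forward(&next_inputs_embeds, &attention_mask, past_key_values);
        past_key_values = grown;
        // ...every later step feeds ONE token's embedding: the cache already
        // holds keys/values for everything before it, so nothing is re-embedded.
        next_inputs_embeds = Array3::<f32>::zeros((1, 1, HIDDEN)); // = embed(next_token_id)
        // The mask must still cover cache + new token, so it grows by one.
        attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
        println!("step produced token id {next_token_id}");
    }
}
```

The diff below implements exactly this shift: `next_inputs_embeds` carries the loop state, the old `combined_embeds` splice is deleted, and the model's `present.*` outputs are copied back into `past_key_values` each iteration.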

@@ -35,8 +35,8 @@ This example currently only supports single image input.
 The performance of ONNX-based LLM inference can be relatively slow, especially on CPU:
 - On an Apple M1 Pro:
-  - For image+text input (about 300 tokens): ~5 seconds per output token
-  - For text-only input (about 10 tokens): ~200ms per output token
+  - For image+text input (about 300 tokens): ~7 tokens/s
+  - For text-only input (about 10 tokens): ~5 tokens/s
 ## Run this Example
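
(Assuming the old and new README figures measure the same workloads, they are consistent once converted to one unit: ~5 s/token is ~0.2 tokens/s, so the image+text case is roughly 35x faster with the KV cache, while the text-only case, ~200 ms/token = ~5 tokens/s, is essentially unchanged. That is the expected pattern: cached decoding saves recomputation proportional to sequence length, so the ~300-token prompt benefits far more than the ~10-token one.)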

@@ -11,7 +11,7 @@ const VISION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-vision.onnx";
 const TEXT_EMBEDDING_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text-embedding.onnx";
 const GENERATION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text.onnx";
-const MAX_LENGTH: usize = 100; // max length of the generated text
+const MAX_LENGTH: usize = 1000; // max length of the generated text
 const EOS_TOKEN_ID: i64 = 32007; // <|end|>
 const USER_TOKEN_ID: i64 = 32010; // <|user|>
 const VOCAB_SIZE: usize = 32064;
@@ -108,7 +108,7 @@ pub async fn generate_text(
     image: &Option<DynamicImage>,
     text: &str
 ) -> Result<()> {
-    let (mut inputs_embeds, mut attention_mask) = {
+    let (inputs_embeds, mut attention_mask) = {
         let visual_features = get_image_embedding(&vision_model, &image)?;
         let prompt = format_chat_template(&image, text);
         let encoding = tokenizer.encode(prompt, true).map_err(|e| anyhow::anyhow!("Error encoding: {:?}", e))?;
@@ -139,12 +139,13 @@ pub async fn generate_text(
     // 4. Head size (96)
     let mut past_key_values: Vec<Array4<f32>> = vec![Array4::zeros((1, 32, 0, 96)); 64];
     let mut generated_tokens: Vec<i64> = Vec::new();
+    let mut next_inputs_embeds = inputs_embeds.clone();
     // Loop until <|end|> token is generated or max length is reached
     for _ in 0..MAX_LENGTH {
         // Prepare model inputs
         let model_inputs = {
             let mut model_inputs = ort::inputs![
-                "inputs_embeds" => inputs_embeds.clone(),
+                "inputs_embeds" => next_inputs_embeds.clone(),
                 "attention_mask" => attention_mask.clone(),
             ]?;
             for i in 0..32 {
@@ -176,27 +177,21 @@ pub async fn generate_text(
             .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
             .unwrap()
             .0 as i64;
-        if next_token_id == EOS_TOKEN_ID {
-            break;
-        }
         generated_tokens.push(next_token_id);
         // Log the generated text
         let output_ids: Vec<u32> = generated_tokens.iter().map(|&id| id as u32).collect();
         let generated_text = tokenizer.decode(&output_ids, false).unwrap();
         tracing::info!("Generated text: {}", generated_text);
+        if next_token_id == EOS_TOKEN_ID {
+            break;
+        }
-        // Update inputs_embeds, attention_mask, and past_key_values for the next iteration
-        (inputs_embeds, attention_mask) = {
-            let new_token_id = Array2::from_elem((1, 1), next_token_id);
-            let new_token_embed = get_text_embedding(&text_embedding_model, &new_token_id)?;
-            // Merge the new token embedding with the previous embeddings
-            let mut combined_embeds = Array3::zeros((inputs_embeds.shape()[0], inputs_embeds.shape()[1] + 1, inputs_embeds.shape()[2]));
-            combined_embeds.slice_mut(s![.., ..inputs_embeds.shape()[1], ..]).assign(&inputs_embeds);
-            combined_embeds.slice_mut(s![.., inputs_embeds.shape()[1].., ..]).assign(&new_token_embed);
-            let new_attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
-            (combined_embeds, new_attention_mask)
-        };
+        // Update next_inputs_embeds, attention_mask, and past_key_values for the next iteration
+        let new_token_id = Array2::from_elem((1, 1), next_token_id);
+        next_inputs_embeds = get_text_embedding(&text_embedding_model, &new_token_id)?;
+        attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
         for i in 0..32 {
             past_key_values[i * 2] = model_outputs[format!("present.{}.key", i)]
                 .try_extract_tensor::<f32>()?