examples: image chat generation with phi-3 vision (#291)
Cargo.toml
@@ -6,6 +6,7 @@ members = [
 	'examples/gpt2',
 	'examples/model-info',
 	'examples/yolov8',
+	'examples/phi-3-vision',
 	'examples/modnet',
 	'examples/sentence-transformers',
 	'examples/training'
22 examples/phi-3-vision/Cargo.toml Normal file
@@ -0,0 +1,22 @@
[package]
publish = false
name = "phi-3-vision"
version = "0.0.0"
edition = "2021"

[dependencies]
anyhow = "1.0.89"
image = "0.25"
ndarray = "0.16"
ort = { path = "../../", features = ["fetch-models"] }
tokio = { version = "1.40.0", features = ["full"] }
tokenizers = "0.20"
tracing-subscriber = { version = "0.3", default-features = false, features = [
	"env-filter",
	"fmt",
] }
tracing = "0.1"

[features]
load-dynamic = ["ort/load-dynamic"]
cuda = ["ort/cuda"]
56 examples/phi-3-vision/README.md Normal file
@@ -0,0 +1,56 @@
# Phi-3 Vision ONNX Example

This example demonstrates image chat generation with Microsoft's [Phi-3 Vision model](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu).

Phi-3 Vision ONNX is a multimodal model that combines vision and language processing. It uses three interconnected ONNX models:

- Vision model: Processes images to extract visual features
- Text embedding model: Embeds input text into a format compatible with the model
- Text generation model: Produces text outputs based on the combined visual and textual inputs

This multi-model structure requires a coordinated process (a simplified code sketch follows the list):

1. Image Processing:
   - Preprocess the input image
   - Pass it through the vision ONNX model to extract visual features

2. Text Embedding:
   - Tokenize the input text
   - Process it with the text embedding ONNX model

3. Multimodal Fusion:
   - Combine the visual features and text embeddings into a single input

4. Text Generation:
   - The combined input is fed into the text generation ONNX model.
   - The model generates text tokens one by one in an autoregressive manner.
   - For each token, the model uses past key/value states to maintain context.
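
The steps above map onto `src/main.rs`. As a rough sketch of steps 1 and 2 (condensed from the example code; preprocessing, error handling, and the generation loop are omitted, and the helper name `embed_inputs` is just for illustration, the real code uses `get_image_embedding` and `get_text_embedding`):

```rust
use anyhow::Result;
use ndarray::{Array2, Array3, Array5, Ix3};
use ort::Session;

/// Rough sketch of steps 1 and 2, condensed from `src/main.rs`.
fn embed_inputs(
    vision_model: &Session,
    text_embedding_model: &Session,
    pixel_values: Array5<f32>,
    image_sizes: Array2<i64>,
    input_ids: Array2<i64>,
) -> Result<(Array3<f32>, Array3<f32>)> {
    // 1. Vision model: preprocessed image tensors -> visual features.
    let vision_outputs = vision_model.run(ort::inputs![
        "pixel_values" => pixel_values,
        "image_sizes" => image_sizes,
    ]?)?;
    let visual_features = vision_outputs["visual_features"]
        .try_extract_tensor::<f32>()?
        .into_dimensionality::<Ix3>()?
        .to_owned();

    // 2. Text embedding model: token ids -> text embeddings.
    let embedding_outputs = text_embedding_model.run(ort::inputs!["input_ids" => input_ids]?)?;
    let inputs_embeds = embedding_outputs["inputs_embeds"]
        .try_extract_tensor::<f32>()?
        .into_dimensionality::<Ix3>()?
        .to_owned();

    // Steps 3 and 4 happen in the caller: splice `visual_features` into
    // `inputs_embeds` at the <|image_1|> position, then run the text
    // generation model in an autoregressive loop.
    Ok((visual_features, inputs_embeds))
}
```

Step 3 corresponds to `merge_text_and_image_embeddings` and step 4 to the decoding loop in `generate_text`.
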
The specific configuration for the model can be found in `data/genai_config.json`.

## Limitations and Performance

This example currently supports only a single image input.

ONNX-based LLM inference can be relatively slow, especially on CPU:

- On an Apple M1 Pro:
  - For image+text input (about 300 tokens): ~5 seconds per output token
  - For text-only input (about 10 tokens): ~200ms per output token

## Run this Example

Before running the example, you'll need to download the ONNX model files to the `data` directory. At present, the `SessionBuilder::commit_from_url` method doesn't support initialization for models split into `.onnx` and `.onnx.data` files, which is the case for the Phi-3 Vision models.

To get started, use the `data/download.sh` script to download the three model file pairs and the tokenizer:

1. `phi-3-v-128k-instruct-vision.onnx` and `phi-3-v-128k-instruct-vision.onnx.data`
2. `phi-3-v-128k-instruct-text-embedding.onnx` and `phi-3-v-128k-instruct-text-embedding.onnx.data`
3. `phi-3-v-128k-instruct-text.onnx` and `phi-3-v-128k-instruct-text.onnx.data`
4. `tokenizer.json`
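
The script simply `wget`s each file into the current working directory, so run it from `data/`, for example:

```bash
cd examples/phi-3-vision/data
bash download.sh
```
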
Once the model files are downloaded, you can run the example using Cargo:

```bash
cargo run -p phi-3-vision
```
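
Because the split `.onnx`/`.onnx.data` files must be loaded from disk, the example builds each session with `commit_from_file`; excerpted from `src/main.rs`:

```rust
let data_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("data");
let vision_model = Session::builder()?
    .with_execution_providers([ort::CPUExecutionProvider::default().build()])?
    .commit_from_file(data_dir.join("phi-3-v-128k-instruct-vision.onnx"))?;
```

The `cuda` and `load-dynamic` Cargo features forward to the corresponding `ort` features, in case you want to experiment with another execution provider or a dynamically loaded ONNX Runtime library.
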
5 examples/phi-3-vision/build.rs Normal file
@@ -0,0 +1,5 @@
fn main() {
	// Need this for CoreML. See: https://ort.pyke.io/perf/execution-providers#coreml
	#[cfg(target_os = "macos")]
	println!("cargo:rustc-link-arg=-fapple-link-rtlib");
}
3 examples/phi-3-vision/data/.gitignore vendored Normal file
@@ -0,0 +1,3 @@
/*.onnx
/*.onnx.data
/tokenizer.json
16 examples/phi-3-vision/data/download.sh Normal file
@@ -0,0 +1,16 @@
#!/bin/bash

BASE_URL="https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/resolve/main/cpu-int4-rtn-block-32-acc-level-4/"
FILES=(
	"phi-3-v-128k-instruct-text-embedding.onnx"
	"phi-3-v-128k-instruct-text-embedding.onnx.data"
	"phi-3-v-128k-instruct-text.onnx"
	"phi-3-v-128k-instruct-text.onnx.data"
	"phi-3-v-128k-instruct-vision.onnx"
	"phi-3-v-128k-instruct-vision.onnx.data"
	"tokenizer.json"
)

for FILE in "${FILES[@]}"; do
	wget "${BASE_URL}${FILE}"
done
BIN examples/phi-3-vision/data/example.jpg Normal file (binary image, 107 KiB, not shown)
68 examples/phi-3-vision/data/genai_config.json Normal file
@@ -0,0 +1,68 @@
{
    "model": {
        "bos_token_id": 1,
        "context_length": 131072,
        "decoder": {
            "session_options": {
                "log_id": "onnxruntime-genai",
                "provider_options": []
            },
            "filename": "phi-3-v-128k-instruct-text.onnx",
            "head_size": 96,
            "hidden_size": 3072,
            "inputs": {
                "inputs_embeds": "inputs_embeds",
                "attention_mask": "attention_mask",
                "past_key_names": "past_key_values.%d.key",
                "past_value_names": "past_key_values.%d.value"
            },
            "outputs": {
                "logits": "logits",
                "present_key_names": "present.%d.key",
                "present_value_names": "present.%d.value"
            },
            "num_attention_heads": 32,
            "num_hidden_layers": 32,
            "num_key_value_heads": 32
        },
        "embedding": {
            "filename": "phi-3-v-128k-instruct-text-embedding.onnx",
            "inputs": {
                "input_ids": "input_ids"
            },
            "outputs": {
                "inputs_embeds": "inputs_embeds"
            }
        },
        "vision": {
            "filename": "phi-3-v-128k-instruct-vision.onnx",
            "inputs": {
                "pixel_values": "pixel_values",
                "image_sizes": "image_sizes"
            },
            "outputs": {
                "visual_features": "visual_features"
            }
        },
        "eos_token_id": 32007,
        "pad_token_id": 32000,
        "type": "phi3v",
        "vocab_size": 32064
    },
    "search": {
        "diversity_penalty": 0.0,
        "do_sample": false,
        "early_stopping": true,
        "length_penalty": 1.0,
        "max_length": 131072,
        "min_length": 0,
        "no_repeat_ngram_size": 0,
        "num_beams": 1,
        "num_return_sequences": 1,
        "past_present_share_buffer": true,
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_k": 1,
        "top_p": 1.0
    }
}
192 examples/phi-3-vision/src/image_process.rs Normal file
@@ -0,0 +1,192 @@
//! This file is a Rust implementation of the image processing code for the Phi-3-vision-128k-instruct model.
//! The original Python version can be found at:
//! https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
//!
//! The image transformation is configured as Phi3ImageTransform in the processor config:
//! https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/blob/main/cpu-int4-rtn-block-32-acc-level-4/processor_config.json
//!
//! This Rust implementation aims to provide similar functionality for preprocessing images
//! to be used with the Phi-3 vision model, adapting the original Python code to Rust.
use anyhow::Result;
use image::{DynamicImage, GenericImageView, ImageBuffer};
use ndarray::{s, Array2, Array4, Array5, Axis};

/// see https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/blob/main/cpu-int4-rtn-block-32-acc-level-4/processor_config.json
/// NOTE: The default setting in processor_config.json is num_crops = 16,
/// but this is too slow for practical use. We use 1 here for better performance.
pub const NUM_CROPS: usize = 1;
pub const _NUM_IMG_TOKENS: usize = 144;

const OPENAI_CLIP_MEAN: [f32; 3] = [0.48145466, 0.4578275, 0.40821073];
const OPENAI_CLIP_STD: [f32; 3] = [0.26862954, 0.26130258, 0.27577711];

pub struct Phi3VImageProcessor {
	num_crops: usize,
	image_mean: Vec<f32>,
	image_std: Vec<f32>,
	do_convert_rgb: bool,
}

impl Phi3VImageProcessor {
	pub fn new() -> Self {
		Self {
			num_crops: NUM_CROPS,
			image_mean: OPENAI_CLIP_MEAN.to_vec(),
			image_std: OPENAI_CLIP_STD.to_vec(),
			do_convert_rgb: true,
		}
	}

	pub fn _calc_num_image_tokens(&self, image: &DynamicImage) -> usize {
		let transformed = self.hd_transform(image);
		let (width, height) = transformed.dimensions();
		self.calc_num_image_tokens_from_image_size(width, height)
	}

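	/// For example, with the default `num_crops = 1` the HD transform always yields a
	/// padded 336x336 image, so this evaluates to (1 * 1 + 1) * 144 + 1 + (1 + 1) * 12 = 313 image tokens.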
	pub fn calc_num_image_tokens_from_image_size(&self, width: u32, height: u32) -> usize {
		let (new_width, new_height) = self.calc_hd_transform_size(width, height);
		((new_height / 336 * new_width / 336 + 1) * 144 + 1 + (new_height / 336 + 1) * 12) as usize
	}

	pub fn preprocess(&self, image: &DynamicImage) -> Result<BatchFeature> {
		// NOTE: both branches currently convert to RGB8; `do_convert_rgb` is always true for this processor.
		let rgb_image = if self.do_convert_rgb { image.to_rgb8() } else { image.to_rgb8() };
		let rgb_image = DynamicImage::ImageRgb8(rgb_image);

		let transformed = self.hd_transform(&rgb_image);
		let (width, height) = transformed.dimensions();
		let shapes = vec![height as i64, width as i64];
		let image_sizes = Array2::from_shape_vec((1, 2), shapes)?;

		let num_img_tokens = self.calc_num_image_tokens_from_image_size(width, height);

		let normalized = self.normalize_image(&transformed);
		let global_image = self.create_global_image(&normalized);
		let local_patches = self.create_local_patches(&normalized);

		let mut all_patches = vec![global_image];
		all_patches.extend(local_patches);

		let padded_images = self.pad_to_max_num_crops_tensor(&all_patches, self.num_crops + 1);
		let pixel_values = padded_images.insert_axis(Axis(0));

		Ok(BatchFeature {
			pixel_values,
			image_sizes,
			num_img_tokens: vec![num_img_tokens as i64],
		})
	}

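	/// Mirrors the HD transform from the Python reference: orient the image to landscape,
	/// scale its width to the largest multiple of 336 permitted by `num_crops`, pad the
	/// height up to a multiple of 336, then restore the original orientation.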
	fn hd_transform(&self, image: &DynamicImage) -> DynamicImage {
		let (width, height) = image.dimensions();
		let mut transposed = false;
		let (width, height) = if width < height {
			transposed = true;
			(height, width)
		} else {
			(width, height)
		};

		let ratio = width as f32 / height as f32;
		let mut scale = 1;
		while (scale as f32 * (scale as f32 / ratio).ceil()) <= self.num_crops as f32 {
			scale += 1;
		}
		scale -= 1;

		let new_width = scale * 336;
		let new_height = (new_width as f32 / ratio) as u32;

		let resized = image.resize_exact(new_width, new_height, image::imageops::FilterType::Lanczos3);
		let padded = self.padding_336(&resized);

		if transposed {
			padded.rotate90()
		} else {
			padded
		}
	}

	fn padding_336(&self, image: &DynamicImage) -> DynamicImage {
		let (width, height) = image.dimensions();
		let tar = ((height as f32 / 336.0).ceil() * 336.0) as u32;
		let top_padding = (tar - height) / 2;
		let mut padded = ImageBuffer::from_pixel(width, tar, image::Rgba([255, 255, 255, 255]));
		image::imageops::overlay(&mut padded, image, 0, top_padding as i64);
		DynamicImage::ImageRgba8(padded)
	}

	fn calc_hd_transform_size(&self, width: u32, height: u32) -> (u32, u32) {
		let (width, height) = if width < height { (height, width) } else { (width, height) };

		let ratio = width as f32 / height as f32;
		let mut scale = 1;
		while (scale as f32 * (scale as f32 / ratio).ceil()) <= self.num_crops as f32 {
			scale += 1;
		}
		scale -= 1;

		let new_width = scale * 336;
		let new_height = (new_width as f32 / ratio) as u32;

		self.calc_padded_size(new_width, new_height)
	}

	fn calc_padded_size(&self, width: u32, height: u32) -> (u32, u32) {
		let target_height = ((height as f32 / 336.0).ceil() * 336.0) as u32;
		(width, target_height)
	}

	fn normalize_image(&self, image: &DynamicImage) -> Array4<f32> {
		let (width, height) = image.dimensions();
		let mut normalized = Array4::<f32>::zeros((1, 3, height as usize, width as usize));

		for (x, y, pixel) in image.pixels() {
			for c in 0..3 {
				normalized[[0, c, y as usize, x as usize]] = (pixel[c] as f32 / 255.0 - self.image_mean[c]) / self.image_std[c];
			}
		}

		normalized
	}

	fn create_global_image(&self, _image: &Array4<f32>) -> Array4<f32> {
		// Simplified: this example uses an all-zero 336x336 tensor for the global crop
		// (the reference Python processor uses a downscaled copy of the image here).
		Array4::<f32>::zeros((1, 3, 336, 336))
	}

	fn create_local_patches(&self, image: &Array4<f32>) -> Vec<Array4<f32>> {
		let (_, _, height, width) = image.dim();
		let mut patches = Vec::new();

		for h in (0..height).step_by(336) {
			for w in (0..width).step_by(336) {
				let patch = image
					.slice(s![.., .., h..std::cmp::min(h + 336, height), w..std::cmp::min(w + 336, width)])
					.to_owned();
				patches.push(patch);
			}
		}

		patches
	}

	fn pad_to_max_num_crops_tensor(&self, patches: &[Array4<f32>], max_crops: usize) -> Array4<f32> {
		let (_, channels, height, width) = patches[0].dim();
		let mut padded = Array4::<f32>::zeros((max_crops, channels, height, width));

		for (i, patch) in patches.iter().enumerate() {
			if i >= max_crops {
				break;
			}
			// Remove the extra dimension when assigning
			padded.slice_mut(s![i, .., .., ..]).assign(&patch.slice(s![0, .., .., ..]));
		}

		padded
	}
}

pub struct BatchFeature {
	pub pixel_values: Array5<f32>,
	pub image_sizes: Array2<i64>,
	pub num_img_tokens: Vec<i64>,
}

240 examples/phi-3-vision/src/main.rs Normal file
@@ -0,0 +1,240 @@
mod image_process;
use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Array, Array2, Array3, Array4, ArrayView, Ix3, Ix4};
use ort::{Session, Tensor};
use std::path::Path;
use std::time::Instant;
use tokenizers::Tokenizer;

const VISION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-vision.onnx";
const TEXT_EMBEDDING_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text-embedding.onnx";
const GENERATION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text.onnx";

const MAX_LENGTH: usize = 100; // max length of the generated text
const EOS_TOKEN_ID: i64 = 32007; // <|end|>
const USER_TOKEN_ID: i64 = 32010; // <|user|>
const VOCAB_SIZE: usize = 32064;

#[allow(dead_code)]
fn get_current_time() -> Instant {
	Instant::now()
}

fn get_image_embedding(vision_model: &Session, img: &Option<DynamicImage>) -> Result<Array3<f32>> {
	let visual_features = if let Some(img) = img {
		let image_processor = image_process::Phi3VImageProcessor::new();
		let result = image_processor.preprocess(img)?;
		tracing::debug!(
			"image process result, num_img_tokens: {num_img_tokens:?}, pixel_values: {pixel_values:?}, image_sizes: {image_sizes:?}",
			num_img_tokens = result.num_img_tokens,
			pixel_values = result.pixel_values.shape(),
			image_sizes = result.image_sizes.shape(),
		);
		let model_inputs = ort::inputs![
			"pixel_values" => result.pixel_values,
			"image_sizes" => result.image_sizes,
		]?;
		let outputs = vision_model.run(model_inputs)?;
		let predictions_view: ArrayView<f32, _> = outputs["visual_features"].try_extract_tensor::<f32>()?;
		let predictions = predictions_view.into_dimensionality::<Ix3>()?.to_owned();
		predictions
	} else {
		// No image: return an empty (1, 0, 0) array so the caller can skip the merge step.
		Array::zeros((1, 0, 0))
	};
	Ok(visual_features)
}

fn get_text_embedding(text_embedding_model: &Session, input_ids: &Array2<i64>) -> Result<Array3<f32>> {
	let model_inputs = ort::inputs![
		"input_ids" => input_ids.to_owned(),
	]?;
	let outputs = text_embedding_model.run(model_inputs)?;
	let inputs_embeds_view: ArrayView<f32, _> = outputs["inputs_embeds"].try_extract_tensor::<f32>()?;
	let inputs_embeds = inputs_embeds_view.into_dimensionality::<Ix3>()?.to_owned();
	Ok(inputs_embeds)
}

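/// Splices the visual features into the text embedding sequence, producing
/// [text up to the image token][visual features][remaining text], and extends
/// the attention mask to cover the inserted visual positions.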
fn merge_text_and_image_embeddings(
	inputs_embeds: &Array3<f32>,
	attention_mask: &Array2<i64>,
	visual_features: &Array3<f32>,
	image_token_position: usize,
) -> (Array3<f32>, Array2<i64>) {
	let mut combined_embeds = Array3::zeros((1, inputs_embeds.shape()[1] + visual_features.shape()[1], inputs_embeds.shape()[2]));

	// Copy text embeddings up to the <|image_1|> token
	combined_embeds
		.slice_mut(s![.., ..image_token_position, ..])
		.assign(&inputs_embeds.slice(s![.., ..image_token_position, ..]));

	// Insert visual features
	combined_embeds
		.slice_mut(s![.., image_token_position..(image_token_position + visual_features.shape()[1]), ..])
		.assign(&visual_features);

	// Copy the remaining text embeddings
	combined_embeds
		.slice_mut(s![.., (image_token_position + visual_features.shape()[1]).., ..])
		.assign(&inputs_embeds.slice(s![.., image_token_position.., ..]));

	// Update attention_mask
	let mut new_attention_mask = Array2::ones((1, attention_mask.shape()[1] + visual_features.shape()[1]));
	new_attention_mask
		.slice_mut(s![.., ..image_token_position])
		.assign(&attention_mask.slice(s![.., ..image_token_position]));
	new_attention_mask
		.slice_mut(s![.., (image_token_position + visual_features.shape()[1])..])
		.assign(&attention_mask.slice(s![.., image_token_position..]));

	(combined_embeds, new_attention_mask)
}

/// see https://github.com/microsoft/onnxruntime-genai/blob/main/examples/python/phi3v.py
/// <|user|><|image_1|>{text}<|end|><|assistant|>
/// Includes the `<s>` token, which is typically used as the BOS (Beginning of Sequence) token by LlamaTokenizer
fn format_chat_template(img: &Option<DynamicImage>, txt: &str) -> String {
	match img {
		Some(_) => format!("<s><|user|>\n<|image_1|>\n{txt}<|end|>\n<|assistant|>\n", txt = txt),
		None => format!("<s><|user|>\n{txt}<|end|>\n<|assistant|>\n", txt = txt),
	}
}

pub async fn generate_text(
	tokenizer: &Tokenizer,
	vision_model: &Session,
	text_embedding_model: &Session,
	generation_model: &Session,
	image: &Option<DynamicImage>,
	text: &str,
) -> Result<()> {
	let (mut inputs_embeds, mut attention_mask) = {
		let visual_features = get_image_embedding(&vision_model, &image)?;
		let prompt = format_chat_template(&image, text);
		let encoding = tokenizer.encode(prompt, true).map_err(|e| anyhow::anyhow!("Error encoding: {:?}", e))?;

		let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
		let input_ids: Array2<i64> = Array2::from_shape_vec((1, input_ids.len()), input_ids)?;
		let mut inputs_embeds: Array3<f32> = get_text_embedding(&text_embedding_model, &input_ids)?;

		let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&mask| mask as i64).collect();
		let mut attention_mask: Array2<i64> = Array2::from_shape_vec((1, attention_mask.len()), attention_mask)?;

		if image.is_some() {
			// Use the position of the <|user|> token as the insertion point; <|image_1|> immediately follows it in the prompt
			let image_token_position = input_ids.iter().position(|&id| id == USER_TOKEN_ID).unwrap_or(0);
			(inputs_embeds, attention_mask) = merge_text_and_image_embeddings(&inputs_embeds, &attention_mask, &visual_features, image_token_position);
		};
		(inputs_embeds, attention_mask)
	};

	// Initialize past_key_values for the transformer model
	// This is used to store the attention mechanism's state across multiple inference steps
	// The structure is:
	// - 64 elements (32 layers, each with a key and value)
	// - Each element is a 4D array with dimensions:
	//   1. Batch size (1)
	//   2. Number of attention heads (32)
	//   3. Sequence length (0 initially, will grow with each token generated)
	//   4. Head size (96)
	let mut past_key_values: Vec<Array4<f32>> = vec![Array4::zeros((1, 32, 0, 96)); 64];
	let mut generated_tokens: Vec<i64> = Vec::new();
	// Loop until <|end|> token is generated or max length is reached
	for _ in 0..MAX_LENGTH {
		// Prepare model inputs
		let model_inputs = {
			let mut model_inputs = ort::inputs![
				"inputs_embeds" => inputs_embeds.clone(),
				"attention_mask" => attention_mask.clone(),
			]?;
			for i in 0..32 {
				model_inputs.push((format!("past_key_values.{}.key", i).into(), Tensor::from_array(past_key_values[i * 2].view())?.into()));
				model_inputs.push((format!("past_key_values.{}.value", i).into(), Tensor::from_array(past_key_values[i * 2 + 1].view())?.into()));
			}
			model_inputs
		};

		// Run the model
		let model_outputs = generation_model.run(model_inputs)?;
		// Get the logits for the last token. Logits are unnormalized log probabilities, with a shape of [1, 1, VOCAB_SIZE],
		// where VOCAB_SIZE is the total number of unique tokens in the model's vocabulary.
		//
		// The current implementation uses a simple greedy decoding strategy:
		// - We select the token with the highest probability (argmax) from the logits.
		// - This approach always chooses the most likely next token, which can lead to deterministic and potentially repetitive outputs.
		//
		// Note: More advanced sampling strategies (e.g., temperature scaling, top-k, top-p sampling) are not implemented in the current version.
		//
		// The selected token ID will be in the range [0, VOCAB_SIZE - 1].
		let logits: ArrayView<f32, _> = model_outputs["logits"].try_extract_tensor::<f32>()?.into_dimensionality::<Ix3>()?;
		let next_token_id = logits
			.slice(s![0, -1, ..VOCAB_SIZE])
			.iter()
			.enumerate()
			.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
			.unwrap()
			.0 as i64;
		generated_tokens.push(next_token_id);
		// Log the generated text
		let output_ids: Vec<u32> = generated_tokens.iter().map(|&id| id as u32).collect();
		let generated_text = tokenizer.decode(&output_ids, false).unwrap();
		tracing::info!("Generated text: {}", generated_text);

		if next_token_id == EOS_TOKEN_ID {
			break;
		}

		// Update inputs_embeds, attention_mask, and past_key_values for the next iteration
		(inputs_embeds, attention_mask) = {
			let new_token_id = Array2::from_elem((1, 1), next_token_id);
			let new_token_embed = get_text_embedding(&text_embedding_model, &new_token_id)?;
			// Merge the new token embedding with the previous embeddings
			let mut combined_embeds = Array3::zeros((inputs_embeds.shape()[0], inputs_embeds.shape()[1] + 1, inputs_embeds.shape()[2]));
			combined_embeds.slice_mut(s![.., ..inputs_embeds.shape()[1], ..]).assign(&inputs_embeds);
			combined_embeds.slice_mut(s![.., inputs_embeds.shape()[1].., ..]).assign(&new_token_embed);
			let new_attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
			(combined_embeds, new_attention_mask)
		};
		// Feed the present key/value outputs back as the next iteration's past_key_values (KV cache)
		for i in 0..32 {
			past_key_values[i * 2] = model_outputs[format!("present.{}.key", i)]
				.try_extract_tensor::<f32>()?
				.into_dimensionality::<Ix4>()?
				.to_owned();
			past_key_values[i * 2 + 1] = model_outputs[format!("present.{}.value", i)]
				.try_extract_tensor::<f32>()?
				.into_dimensionality::<Ix4>()?
				.to_owned();
		}
	}

	Ok(())
}

#[tokio::main]
async fn main() -> Result<()> {
	tracing_subscriber::fmt().init(); // set up default subscriber with log level `INFO`

	let data_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("data");
	let tokenizer = Tokenizer::from_file(data_dir.join("tokenizer.json")).map_err(|e| anyhow::anyhow!("Error loading tokenizer: {:?}", e))?;
	let vision_model = Session::builder()?
		.with_execution_providers([ort::CPUExecutionProvider::default().build()])?
		.commit_from_file(data_dir.join(VISION_MODEL_NAME))?;
	let text_embedding_model = Session::builder()?
		.with_execution_providers([ort::CPUExecutionProvider::default().build()])?
		.commit_from_file(data_dir.join(TEXT_EMBEDDING_MODEL_NAME))?;
	let generation_model = Session::builder()?
		.with_execution_providers([ort::CPUExecutionProvider::default().build()])?
		.commit_from_file(data_dir.join(GENERATION_MODEL_NAME))?;

	// Generate text from text
	let image: Option<DynamicImage> = None;
	let text = "Who are you?".to_string();
	generate_text(&tokenizer, &vision_model, &text_embedding_model, &generation_model, &image, &text).await?;

	// Generate text from image and text
	let image: Option<DynamicImage> = Some(image::open(data_dir.join("example.jpg"))?);
	let text = "What is shown in this image?".to_string();
	generate_text(&tokenizer, &vision_model, &text_embedding_model, &generation_model, &image, &text).await?;

	Ok(())
}