feat: training (#202)

Carson M
2024-07-06 11:07:27 -05:00
committed by GitHub
parent 0407adb5cc
commit 0a43482d03
21 changed files with 1294 additions and 33 deletions

.gitignore vendored

@@ -186,6 +186,7 @@ WixTools/
# ONNX Runtime downloaded models
**/*.onnx
**/*.ort
**/*.pbseq
!examples/webassembly/**/*.ort
!tests/data/*.onnx
!tests/data/*.ort
@@ -196,4 +197,8 @@ WixTools/
# Glassbench results
/glassbench*.db
# Python virtual environment
.venv*
# Training checkpoints
tools/train-data/**/checkpoint


@@ -7,6 +7,7 @@ members = [
'examples/model-info',
'examples/yolov8',
'examples/modnet',
'examples/training',
'examples/webassembly'
]
default-members = [
@@ -45,13 +46,15 @@ strip = true
codegen-units = 1
[package.metadata.docs.rs]
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
features = [ "ndarray", "half", "training", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
rustdoc-args = [ "--cfg", "docsrs" ]
[features]
default = [ "ndarray", "half", "download-binaries", "copy-dylibs" ]
training = [ "ort-sys/training" ]
operator-libraries = [ "libc", "winapi" ]
fetch-models = [ "ureq" ]


@@ -0,0 +1,18 @@
[package]
publish = false
name = "example-training"
version = "0.0.0"
edition = "2021"
[dependencies]
ort = { path = "../../", features = [ "training" ] }
ndarray = "0.15"
tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
rand = "0.8"
simd-json = "0.13"
kdam = "0.5"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
[features]
load-dynamic = [ "ort/load-dynamic" ]
cuda = [ "ort/cuda" ]


@@ -0,0 +1,26 @@
# Training Examples
## `train-clm`
This example trains a tiny causal language model on a small subset of pyke's [**OshiChats v2**](https://huggingface.co/datasets/pykeio/oshichats-v2), a dataset of live text chat messages collected from various [VTuber](https://en.wikipedia.org/wiki/VTuber) live streams. The model is not particularly useful or interesting (due to both the low-quality dataset and the small model size), but it showcases that language models can be trained from scratch, entirely in Rust, on (almost) any device.
To get started, create a Python virtual environment and install the following packages:
```
pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 onnx~=1.16 torch~=2.3
```
We're installing the CPU versions of the `onnxruntime-training` & `torch` packages because we only need to use Python to *create* the initial graph that will be used for training. Run `python tools/train-data/mini-clm.py` from the root directory of the `ort` repo to create the training artifacts.
Next, we need to convert our dataset into tokens to feed the model. This can be achieved by downloading the `oshichats-v2.jsonl` file from the OshiChats v2 dataset and running `cargo run -p example-training --example pretokenize -- ~/oshichats-v2.jsonl`, or, if you (rightfully) don't wish to waste 30 GB of disk space and bandwidth on brainrot, you may download a [1 MB pre-tokenized subset of the dataset](https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_data/dataset.bin). Make sure `dataset.bin` is in the root of the `ort` repo.
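The resulting `dataset.bin` is simply a flat stream of 16-bit little-endian GPT-2 token IDs, which is what the `pretokenize` example emits. A minimal sketch of inspecting it (just for illustration, not part of the example):
```
use std::fs;

fn main() -> std::io::Result<()> {
    // `dataset.bin` is a flat stream of u16 little-endian token IDs.
    let bytes = fs::read("dataset.bin")?;
    let tokens: Vec<u16> = bytes
        .chunks_exact(2)
        .map(|c| u16::from_le_bytes([c[0], c[1]]))
        .collect();
    println!("{} tokens", tokens.len());
    Ok(())
}
```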
Finally, we can train our model! Run `cargo run -p example-training --example train-clm` to start training. If you have an NVIDIA GPU, add `--features cuda` to enable CUDA, though it's not required and you can train directly on CPU instead. **This will use ~8 GB of (V)RAM!** You can lower the memory usage by adjusting the `BATCH_SIZE` and `SEQUENCE_LENGTH` constants in `train-clm.rs`, though note that changing the batch size may require adjustments to the learning rate.
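For reference, these are the defaults at the top of `train-clm.rs`:
```
const BATCH_SIZE: usize = 16;
const SEQUENCE_LENGTH: usize = 256;
```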
While training, the progress bar will show the cross-entropy loss at each training step. At the end of training, the final trained model will be saved to `trained-clm.onnx`, and the program will use the model to generate a small snippet of text:
```
100%|██████████████████████████████████████| 5000/5000 [06:29<00:00, 12.83it/s, loss=3.611]
I'm so much better than the game<|endoftext|>I think you can't see it<|endoftext|>I think you can't see it<|endoftext|>I think so it's a new game<|endoftext|>I think I'm sure you can't see what you can't see it<|endoftext|>
```
Not bad, considering the model & dataset size! This example can easily be scaled up to pre-train or fine-tune (both full-parameter and PEFT) larger language models like Llama/Phi, so long as you have enough compute.
## `train-clm-simple`
This example is functionally identical to `train-clm`, except it uses ort's "simple" Trainer API instead of implementing a manual training loop. The simple API is more akin to 🤗 Transformers' [`Trainer`](https://huggingface.co/docs/transformers/en/main_classes/trainer) API or [PyTorch Lightning](https://lightning.ai/pytorch-lightning). With the simple API, all you have to do is pass a data loader & parameters, and let `ort` handle training for you!
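As a rough sketch (assuming the artifacts generated by `mini-clm.py`, and a hypothetical `make_batch` helper that assembles one batch of token IDs as `ndarray` arrays), a training run with the simple API looks something like this:
```
use ort::{Allocator, CheckpointStrategy, SessionBuilder, Trainer, TrainingArguments};

fn main() -> ort::Result<()> {
    ort::init().commit()?;

    // Load the checkpoint and training/eval/optimizer graphs generated by `mini-clm.py`.
    let trainer = Trainer::new_from_artifacts(SessionBuilder::new()?, Allocator::default(), "tools/train-data/mini-clm", None)?;

    // A data loader is just a closure (or a `DataLoader` impl) mapping a step index to (inputs, labels).
    let dataloader = |step: usize| {
        let (inputs, labels) = make_batch(step); // hypothetical batch-assembly helper
        Ok((ort::inputs![inputs]?, ort::inputs![labels]?))
    };

    trainer.train(
        TrainingArguments::new(dataloader)
            .with_lr(7e-5)
            .with_max_steps(5000)
            .with_ckpt_strategy(CheckpointStrategy::Steps(500))
    )?;

    trainer.export("trained-clm.onnx", ["probs"])?;
    Ok(())
}
```
The full `train-clm-simple` example builds the actual batches from `dataset.bin`.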


@@ -0,0 +1,5 @@
fn main() {
// Need this for CoreML. See: https://ort.pyke.io/perf/execution-providers#coreml
#[cfg(target_os = "macos")]
println!("cargo:rustc-link-arg=-fapple-link-rtlib");
}


@@ -0,0 +1,44 @@
use std::{
env,
fs::File,
io::{BufRead, BufReader, BufWriter, Write},
path::Path
};
use simd_json::derived::ValueObjectAccessAsScalar;
use tokenizers::Tokenizer;
const MAX_TOKENS: usize = 500_000;
fn main() {
let input = env::args().nth(1).expect("provide input jsonl");
let output = env::args().nth(2).unwrap_or_else(|| "dataset.bin".into());
let input = BufReader::new(File::open(input).unwrap());
let mut output = BufWriter::new(File::create(output).unwrap());
let tokenizer = Tokenizer::from_file(
Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.join("gpt2")
.join("data")
.join("tokenizer.json")
)
.unwrap();
let mut bytes_written = 0;
for line in input.lines() {
let line: simd_json::OwnedValue = unsafe { simd_json::from_str(&mut line.unwrap()).unwrap() };
let tokenized = tokenizer
.encode(format!("<|endoftext|>{}", line.get_str("message").unwrap()), false)
.unwrap();
let id_bytes: Vec<u8> = tokenized.get_ids().iter().flat_map(|c| (*c as u16).to_le_bytes()).collect();
output.write_all(&id_bytes).unwrap();
bytes_written += id_bytes.len();
if bytes_written >= MAX_TOKENS * 2 {
output.flush().unwrap();
break;
}
}
}


@@ -0,0 +1,118 @@
use std::{
fs::File,
io::{Read, Seek, SeekFrom, Write},
path::Path
};
use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis};
use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments};
use rand::RngCore;
use tokenizers::Tokenizer;
const BATCH_SIZE: usize = 16;
const SEQUENCE_LENGTH: usize = 256;
fn main() -> ort::Result<()> {
tracing_subscriber::fmt::init();
ort::init().commit()?;
let trainer = Trainer::new_from_artifacts(
SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
Allocator::default(),
"tools/train-data/mini-clm",
None
)?;
let tokenizer = Tokenizer::from_file(
Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.join("gpt2")
.join("data")
.join("tokenizer.json")
)
.unwrap();
let mut dataset = File::open("dataset.bin").unwrap();
let file_size = dataset.metadata().unwrap().len();
let num_tokens = (file_size / 2) as usize; // 16-bit tokens
let mut rng = rand::thread_rng();
let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
let dataloader = move |_: usize| {
for batch in 0..BATCH_SIZE {
let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64;
dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap();
dataset
.read_exact(unsafe {
std::slice::from_raw_parts_mut(
input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
.as_mut_ptr()
.cast::<u8>(),
SEQUENCE_LENGTH * 2
)
})
.unwrap();
dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap();
dataset
.read_exact(unsafe {
std::slice::from_raw_parts_mut(
label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
.as_mut_ptr()
.cast::<u8>(),
SEQUENCE_LENGTH * 2
)
})
.unwrap();
}
Ok((
ort::inputs![Array2::<i64>::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?,
ort::inputs![Array1::<i64>::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?
))
};
trainer.train(
TrainingArguments::new(dataloader)
.with_lr(7e-5)
.with_max_steps(5000)
.with_ckpt_strategy(CheckpointStrategy::Steps(500))
)?;
trainer.export("trained-clm.onnx", ["probs"])?;
let session = Session::builder()?.commit_from_file("trained-clm.onnx")?;
let mut stdout = std::io::stdout();
let tokens = tokenizer.encode("<|endoftext|>", false).unwrap();
let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();
let mut tokens = Array1::from_iter(tokens.iter().cloned());
for _ in 0..50 {
let array = tokens.view().insert_axis(Axis(0));
let outputs = session.run(ort::inputs![array]?)?;
let generated_tokens: ArrayViewD<f32> = outputs["probs"].try_extract_tensor()?;
let probabilities = &mut generated_tokens
.slice(s![-1, ..])
.to_owned()
.iter()
.cloned()
.enumerate()
.collect::<Vec<_>>();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
let token = probabilities[0].0;
tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]];
let token_str = tokenizer.decode(&[token as _], false).unwrap();
print!("{}", token_str);
stdout.flush().unwrap();
}
println!();
Ok(())
}


@@ -0,0 +1,133 @@
use std::{
fs::File,
io::{Read, Seek, SeekFrom, Write},
path::Path
};
use kdam::BarExt;
use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis};
use ort::{Allocator, CUDAExecutionProvider, Checkpoint, Session, SessionBuilder, Trainer};
use rand::RngCore;
use tokenizers::Tokenizer;
const BATCH_SIZE: usize = 16;
const SEQUENCE_LENGTH: usize = 256;
fn main() -> ort::Result<()> {
tracing_subscriber::fmt::init();
ort::init().commit()?;
kdam::term::init(true);
let _ = kdam::term::hide_cursor();
let trainer = Trainer::new(
SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
Allocator::default(),
Checkpoint::load("tools/train-data/mini-clm/checkpoint")?,
"tools/train-data/mini-clm/training_model.onnx",
"tools/train-data/mini-clm/eval_model.onnx",
"tools/train-data/mini-clm/optimizer_model.onnx"
)?;
let tokenizer = Tokenizer::from_file(
Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.join("gpt2")
.join("data")
.join("tokenizer.json")
)
.unwrap();
let optimizer = trainer.optimizer();
optimizer.set_lr(7e-5)?;
let mut dataset = File::open("dataset.bin").unwrap();
let file_size = dataset.metadata().unwrap().len();
let num_tokens = (file_size / 2) as usize; // 16-bit tokens
let mut rng = rand::thread_rng();
let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
let mut pb = kdam::tqdm!(total = 5000);
for _ in 0..5000 {
for batch in 0..BATCH_SIZE {
let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64;
dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap();
dataset
.read_exact(unsafe {
std::slice::from_raw_parts_mut(
input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
.as_mut_ptr()
.cast::<u8>(),
SEQUENCE_LENGTH * 2
)
})
.unwrap();
dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap();
dataset
.read_exact(unsafe {
std::slice::from_raw_parts_mut(
label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
.as_mut_ptr()
.cast::<u8>(),
SEQUENCE_LENGTH * 2
)
})
.unwrap();
}
let inputs = Array2::<i64>::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap();
let labels = Array1::<i64>::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap();
let outputs = trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?;
let loss = outputs[0].try_extract_scalar::<f32>()?;
pb.set_postfix(format!("loss={loss:.3}"));
pb.update(1).unwrap();
if loss.is_nan() {
return Ok(());
}
optimizer.step()?;
optimizer.reset_grad()?;
}
eprintln!();
let _ = kdam::term::show_cursor();
trainer.export("trained-clm.onnx", ["probs"])?;
let session = Session::builder()?.commit_from_file("trained-clm.onnx")?;
let mut stdout = std::io::stdout();
let tokens = tokenizer.encode("<|endoftext|>", false).unwrap();
let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();
let mut tokens = Array1::from_iter(tokens.iter().cloned());
for _ in 0..50 {
let array = tokens.view().insert_axis(Axis(0));
let outputs = session.run(ort::inputs![array]?)?;
let generated_tokens: ArrayViewD<f32> = outputs["probs"].try_extract_tensor()?;
let probabilities = &mut generated_tokens
.slice(s![-1, ..])
.to_owned()
.iter()
.cloned()
.enumerate()
.collect::<Vec<_>>();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
let token = probabilities[0].0;
tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]];
let token_str = tokenizer.decode(&[token as _], false).unwrap();
print!("{}", token_str);
stdout.flush().unwrap();
}
println!();
Ok(())
}


@@ -16,6 +16,7 @@ include = [ "src/", "dist.txt", "build.rs", "LICENSE-APACHE", "LICENSE-MIT" ]
[features]
default = []
training = []
download-binaries = [ "ureq", "tar", "flate2", "sha2" ]
load-dynamic = []
copy-dylibs = []


@@ -37,12 +37,12 @@ fn fetch_file(source_url: &str) -> Vec<u8> {
buffer
}
fn find_dist(target: &str, designator: &str) -> Option<(&'static str, &'static str)> {
fn find_dist(target: &str, feature_set: &str) -> Option<(&'static str, &'static str)> {
DIST_TABLE
.split('\n')
.filter(|c| !c.is_empty() && !c.starts_with('#'))
.map(|c| c.split('\t').collect::<Vec<_>>())
.find(|c| c[0] == designator && c[1] == target)
.find(|c| c[0] == feature_set && c[1] == target)
.map(|c| (c[2], c[3]))
}
@@ -341,23 +341,31 @@ fn prepare_libort_dir() -> (PathBuf, bool) {
#[cfg(feature = "download-binaries")]
{
let target = env::var("TARGET").unwrap().to_string();
let designator = if cfg!(any(feature = "cuda", feature = "tensorrt")) {
if lib_exists("cudart64_12.dll") || lib_exists("libcudart.so.12") { "cu12" } else { "cu11" }
let mut feature_set = Vec::new();
if cfg!(feature = "training") {
feature_set.push("train");
}
if cfg!(any(feature = "cuda", feature = "tensorrt")) {
if lib_exists("cudart64_11.dll") || lib_exists("libcudart.so.11") || env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() == Ok("11") {
feature_set.push("cu11");
} else {
feature_set.push("cu12");
}
} else if cfg!(feature = "rocm") {
"rocm"
} else {
"none"
};
let mut dist = find_dist(&target, designator);
if dist.is_none() && designator != "none" {
feature_set.push("rocm");
}
let feature_set = if !feature_set.is_empty() { feature_set.join(",") } else { "none".to_owned() };
let mut dist = find_dist(&target, &feature_set);
if dist.is_none() && feature_set != "none" {
dist = find_dist(&target, "none");
}
if dist.is_none() {
panic!(
"downloaded binaries not available for target {target}{}\nyou may have to compile ONNX Runtime from source",
if designator != "none" {
format!(" (note: also requested `{designator}`)")
if feature_set != "none" {
format!(" (note: also requested features `{feature_set}`)")
} else {
String::new()
}


@@ -4,12 +4,26 @@ cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/
rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_rocm-v1.18.1-x86_64-unknown-linux-gnu.tgz 84F74428E0BEC68C55B8E1E91B9282E984CD2866148A2584382B8CB3284214A3
none x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-unknown-linux-gnu.tgz 0A193706A95286853D792D7D9B2271CBEA35C57F249943FE811CED97E0E24862
train aarch64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-unknown-linux-gnu.tgz C04DBEAF19F2BCD3643F8F7D7FA01110A1AF429DFDD1C1DC7C5EDA2B1A8AA324
train,cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12-v1.18.1-x86_64-unknown-linux-gnu.tgz A139D8AD8930930F5A61DF112C8275AAD1F0415FAFD08CE3031CEFFFC30F2445
train,cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu11-v1.18.1-x86_64-unknown-linux-gnu.tgz 2DAA2E2CF44E9B9A96AB2E9C4271C35189C96BF264D1797DABCF1D6711730DE7
train,rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,rocm-v1.18.1-x86_64-unknown-linux-gnu.tgz DD373BA6B251D21953223B2FBB64F4DF34CFE98A63C26D16607BEAC6BC788466
train x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-unknown-linux-gnu.tgz 0E617970AE83ABE5FB9A3D5D69AAC9A67ED4C9494AD527B14A84FDC98CA9B924
none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-pc-windows-msvc.tgz B2F962F0E75F17F3D657B3504CE891BAA6461B26AF65FBD9244B3CCA17FD79D4
cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12-v1.18.1-x86_64-pc-windows-msvc.tgz CDBC2D87B202E1847900E94796D102EE4D5C19A9568BBD014838ECD1F5D5350B
cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu11-v1.18.1-x86_64-pc-windows-msvc.tgz B514FC25453F955F8592100448B27F5E1762A344E8C2D57D41B908978EF2A126
none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-pc-windows-msvc.tgz EB2BCD1778C5934437D4C5B17F67DEAF5F67D2E3C18C7298973EACD41113DC01
train aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-pc-windows-msvc.tgz 8CC1FFFD8AB5E526A076C29A767A650C436E31179D0C6E52C2EA936067B72566
train,cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12-v1.18.1-x86_64-pc-windows-msvc.tgz 6AF64567E25B59AD1196D4953EF8C6A65795E8A4B864E10D8303A027AC50B2D0
train,cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu11-v1.18.1-x86_64-pc-windows-msvc.tgz E14AA0F4FBBBCAF925AD4DB4F76B06402F654B36C5F221E00010D1005F47AE56
train x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-pc-windows-msvc.tgz 84728438E5A950027EBBDC51463F4E5B99B4979087F0F127EA18BC604507E979
none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-apple-darwin.tgz B42BE76AFB9495983A6D5D498D56D5E685B018F1011EF4C5B8C56124B192FD37
none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-apple-darwin.tgz 247F73A5B3665A6660DFB35213E6FEAAC6ED6CAC5816DD85A348DF790F60A30B
train aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-apple-darwin.tgz 29DC09AFA5C3619CF3125F3D55DD64E5EE64451D6BD0044527776849AADEE344
train x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-apple-darwin.tgz 898EC9E3F852843ECDB618CF8E317F4C92BDEB33FC773038960857BCB37CB347
none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-pkort_static-v1.18.1-wasm32-unknown-unknown.tgz D1BF756F02A53C3BC254E3C2048BE617082905A89182A6B1BD18C229920228EF


@@ -823,9 +823,117 @@ fn bindgen_test_layout_OrtOpenVINOProviderOptions() {
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct OrtTrainingApi {
pub struct OrtTrainingSession {
_unused: [u8; 0]
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct OrtCheckpointState {
_unused: [u8; 0]
}
#[repr(i32)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub enum OrtPropertyType {
OrtIntProperty = 0,
OrtFloatProperty = 1,
OrtStringProperty = 2
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct OrtTrainingApi {
pub LoadCheckpoint:
::std::option::Option<_system!(unsafe fn(checkpoint_path: *const ortchar, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr)>,
pub SaveCheckpoint: ::std::option::Option<
_system!(unsafe fn(checkpoint_state: *mut OrtCheckpointState, checkpoint_path: *const ortchar, include_optimizer_state: bool) -> OrtStatusPtr)
>,
pub CreateTrainingSession: ::std::option::Option<
_system!(
unsafe fn(
env: *const OrtEnv,
options: *const OrtSessionOptions,
checkpoint_state: *mut OrtCheckpointState,
train_model_path: *const ortchar,
eval_model_path: *const ortchar,
optimizer_model_path: *const ortchar,
out: *mut *mut OrtTrainingSession
) -> OrtStatusPtr
)
>,
pub CreateTrainingSessionFromBuffer: ::std::option::Option<
_system!(
unsafe fn(
env: *const OrtEnv,
options: *const OrtSessionOptions,
checkpoint_state: *mut OrtCheckpointState,
train_model_data: *const (),
train_data_length: size_t,
eval_model_data: *const (),
eval_data_length: size_t,
optimizer_model_data: *const (),
optimizer_data_length: size_t,
out: *mut *mut OrtTrainingSession
) -> OrtStatusPtr
)
>,
pub TrainingSessionGetTrainingModelOutputCount:
::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>,
pub TrainingSessionGetEvalModelOutputCount: ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>,
pub TrainingSessionGetTrainingModelOutputName: ::std::option::Option<
_system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr)
>,
pub TrainingSessionGetEvalModelOutputName: ::std::option::Option<
_system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr)
>,
pub LazyResetGrad: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>,
pub TrainStep: ::std::option::Option<
_system!(
unsafe fn(
session: *mut OrtTrainingSession,
run_options: *const OrtRunOptions,
inputs_len: size_t,
inputs: *const *const OrtValue,
outputs_len: size_t,
outputs: *mut *mut OrtValue
) -> OrtStatusPtr
)
>,
pub EvalStep: ::std::option::Option<
_system!(
unsafe fn(
session: *mut OrtTrainingSession,
run_options: *const OrtRunOptions,
inputs_len: size_t,
inputs: *const *const OrtValue,
outputs_len: size_t,
outputs: *mut *mut OrtValue
) -> OrtStatusPtr
)
>,
pub SetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: f32) -> OrtStatusPtr)>,
pub GetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: *mut f32) -> OrtStatusPtr)>,
pub OptimizerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, run_options: *const OrtRunOptions) -> OrtStatusPtr)>,
pub RegisterLinearLRScheduler: ::std::option::Option<
_system!(unsafe fn(session: *mut OrtTrainingSession, warmup_step_count: i64, total_step_count: i64, initial_lr: f32) -> OrtStatusPtr)
>,
pub SchedulerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>,
pub GetParametersSize: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, out: *mut size_t, trainable_only: bool) -> OrtStatusPtr)>,
pub CopyParametersToBuffer:
::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>,
pub CopyBufferToParameters:
::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>,
pub ReleaseTrainingSession: ::std::option::Option<_system!(unsafe fn(input: *mut OrtTrainingSession))>,
pub ReleaseCheckpointState: ::std::option::Option<_system!(unsafe fn(input: *mut OrtCheckpointState))>,
pub ExportModelForInferencing: ::std::option::Option<
_system!(
unsafe fn(
session: *mut OrtTrainingSession,
inference_model_path: *const ortchar,
graph_outputs_len: usize,
graph_output_names: *const *const c_char
) -> OrtStatusPtr
)
>
}
#[doc = " \\brief The helper interface to get the right version of OrtApi\n\n Get a pointer to this structure through ::OrtGetApiBase"]
#[repr(C)]
#[derive(Debug, Copy, Clone)]


@@ -261,7 +261,9 @@ pub enum Error {
#[error("Could't get `AllocatorType` from memory info: {0}")]
GetAllocatorType(ErrorInternal),
#[error("Could't get device ID from memory info: {0}")]
GetDeviceId(ErrorInternal)
GetDeviceId(ErrorInternal),
#[error("Training API is not enabled in this build of ONNX Runtime.")]
TrainingNotEnabled
}
impl Error {


@@ -23,6 +23,9 @@ pub(crate) mod metadata;
pub(crate) mod operator;
pub(crate) mod session;
pub(crate) mod tensor;
#[cfg(feature = "training")]
pub(crate) mod training;
pub(crate) mod util;
pub(crate) mod value;
#[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
#[cfg(target_arch = "wasm32")]
@@ -66,6 +69,9 @@ pub use self::session::{
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub use self::tensor::ArrayExtensions;
pub use self::tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data};
#[cfg(feature = "training")]
#[cfg_attr(docsrs, doc(cfg(feature = "training")))]
pub use self::training::*;
pub use self::value::{
DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor,
DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence,


@@ -1,9 +1,5 @@
#[cfg(any(feature = "operator-libraries", not(windows)))]
use std::ffi::CString;
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(target_family = "windows")]
use std::os::windows::ffi::OsStrExt;
#[cfg(not(target_arch = "wasm32"))]
use std::path::Path;
#[cfg(feature = "fetch-models")]
@@ -316,20 +312,7 @@ impl SessionBuilder {
});
}
// Build an OsString, then a vector of bytes to pass to C
let model_path = std::ffi::OsString::from(model_filepath);
#[cfg(target_family = "windows")]
let model_path: Vec<u16> = model_path
.encode_wide()
.chain(std::iter::once(0)) // Make sure we have a null terminated string
.collect();
#[cfg(not(target_family = "windows"))]
let model_path: Vec<std::os::raw::c_char> = model_path
.as_encoded_bytes()
.iter()
.chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
.map(|b| *b as std::os::raw::c_char)
.collect();
let model_path = crate::util::path_to_os_char(model_filepath);
let env = get_environment()?;
apply_execution_providers(&self, env.execution_providers.iter().cloned())?;

src/training/mod.rs Normal file

@@ -0,0 +1,142 @@
use std::{
path::Path,
ptr::{self, NonNull},
sync::{
atomic::{AtomicPtr, Ordering},
OnceLock
}
};
use crate::{ortsys, Error, Result, RunOptions};
mod simple;
mod trainer;
pub use self::{
simple::{iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainingArguments},
trainer::Trainer
};
pub(crate) static TRAINING_API: OnceLock<AtomicPtr<ort_sys::OrtTrainingApi>> = OnceLock::new();
/// Returns a pointer to the global [`ort_sys::OrtTrainingApi`] object, or errors if the Training API is not enabled.
///
/// # Panics
/// May panic if:
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
pub fn training_api() -> Result<NonNull<ort_sys::OrtTrainingApi>> {
NonNull::new(
TRAINING_API
.get_or_init(|| {
let training_api = ortsys![unsafe GetTrainingApi(ort_sys::ORT_API_VERSION)];
AtomicPtr::new(training_api.cast_mut())
})
.load(Ordering::Relaxed)
)
.ok_or(Error::TrainingNotEnabled)
}
macro_rules! trainsys {
($method:ident) => {
$crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))
};
(unsafe $method:ident) => {
unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) }
};
($method:ident($($n:expr),+ $(,)?)) => {
$crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)
};
(unsafe $method:ident($($n:expr),+ $(,)?)) => {
unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }
};
($method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
$crate::error::status_to_result($crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e)
};
(unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
$crate::error::status_to_result(unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e)
};
($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+);
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{
let _x = unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) };
$($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+
_x
}};
($method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
$crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?;
};
(unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
$crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?;
};
($method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {{
$crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
}};
}
pub(crate) use trainsys;
#[derive(Debug)]
pub struct Checkpoint {
pub(crate) ptr: NonNull<ort_sys::OrtCheckpointState>
}
impl Checkpoint {
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let path = crate::util::path_to_os_char(path);
let mut ptr: *mut ort_sys::OrtCheckpointState = ptr::null_mut();
trainsys![unsafe LoadCheckpoint(path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)];
Ok(Checkpoint {
ptr: unsafe { NonNull::new_unchecked(ptr) }
})
}
pub fn save(&self, path: impl AsRef<Path>, include_optimizer_state: bool) -> Result<()> {
let path = crate::util::path_to_os_char(path);
trainsys![unsafe SaveCheckpoint(self.ptr.as_ptr(), path.as_ptr(), include_optimizer_state) -> Error::CreateSession];
Ok(())
}
}
impl Drop for Checkpoint {
fn drop(&mut self) {
tracing::trace!("dropping checkpoint");
trainsys![unsafe ReleaseCheckpointState(self.ptr.as_ptr())];
}
}
#[derive(Debug)]
pub struct Optimizer(NonNull<ort_sys::OrtTrainingSession>);
impl Optimizer {
pub fn reset_grad(&self) -> Result<()> {
trainsys![unsafe LazyResetGrad(self.0.as_ptr()) -> Error::CreateSession];
Ok(())
}
pub fn lr(&self) -> Result<f32> {
let mut lr = f32::NAN;
trainsys![unsafe GetLearningRate(self.0.as_ptr(), &mut lr) -> Error::CreateSession];
Ok(lr)
}
pub fn set_lr(&self, lr: f32) -> Result<()> {
trainsys![unsafe SetLearningRate(self.0.as_ptr(), lr) -> Error::CreateSession];
Ok(())
}
pub fn step(&self) -> Result<()> {
self.step_with_options(RunOptions::new()?)
}
pub fn step_with_options(&self, options: RunOptions) -> Result<()> {
trainsys![unsafe OptimizerStep(self.0.as_ptr(), options.run_options_ptr.as_ptr()) -> Error::CreateSession];
Ok(())
}
}

src/training/simple.rs Normal file

@@ -0,0 +1,240 @@
use std::{collections::VecDeque, fs, path::PathBuf};
use crate::{Result, SessionInputs};
#[allow(clippy::len_without_is_empty)]
pub trait DataLoader<I, L> {
fn load(&mut self, idx: usize) -> Result<(I, L)>;
fn len(&self) -> Option<usize> {
None
}
}
pub struct IterableDataLoader<T, I, L, C: Fn(&T) -> Result<(I, L)>> {
items: Box<[T]>,
collator: C
}
impl<T, I, L, C: Fn(&T) -> Result<(I, L)>> DataLoader<I, L> for IterableDataLoader<T, I, L, C> {
fn load(&mut self, idx: usize) -> Result<(I, L)> {
(self.collator)(&self.items[idx])
}
fn len(&self) -> Option<usize> {
Some(self.items.len())
}
}
pub fn iterable_data_loader<T, I, L, C: Fn(&T) -> Result<(I, L)>>(iterable: impl Iterator<Item = T>, collator: C) -> IterableDataLoader<T, I, L, C> {
IterableDataLoader {
items: iterable.collect::<Vec<T>>().into_boxed_slice(),
collator
}
}
impl<I, L, F: FnMut(usize) -> Result<(I, L)>> DataLoader<I, L> for F {
fn load(&mut self, idx: usize) -> Result<(I, L)> {
(self)(idx)
}
fn len(&self) -> Option<usize> {
None
}
}
pub enum EvaluationStrategy {
None,
Steps(usize),
Epochs(usize)
}
impl EvaluationStrategy {
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
match self {
Self::None => false,
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
Self::Epochs(epochs) => {
if let Some(dataloader_size) = dataloader_size {
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
} else {
false
}
}
}
}
}
pub enum CheckpointStrategy {
None,
Steps(usize),
Epochs(usize)
}
impl CheckpointStrategy {
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
match self {
Self::None => false,
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
Self::Epochs(epochs) => {
if let Some(dataloader_size) = dataloader_size {
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
} else {
false
}
}
}
}
}
pub struct TrainingArguments<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize> {
loader: Box<dyn DataLoader<I, L>>,
eval_loader: Option<Box<dyn DataLoader<I, L>>>,
eval_strategy: EvaluationStrategy,
ckpt_strategy: CheckpointStrategy,
ckpt_path: PathBuf,
lr: f32,
max_saved_ckpts: usize,
gradient_accumulation_steps: usize,
max_steps: usize,
max_eval_steps: usize
}
impl<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>
TrainingArguments<I, L, NI, NL>
{
pub fn new<D: DataLoader<I, L> + 'static>(train_loader: D) -> Self {
Self {
loader: Box::new(train_loader),
eval_loader: None,
eval_strategy: EvaluationStrategy::None,
ckpt_strategy: CheckpointStrategy::Epochs(1),
ckpt_path: PathBuf::from("checkpoints"),
lr: 1e-4,
gradient_accumulation_steps: 1,
max_saved_ckpts: 1,
max_steps: usize::MAX,
max_eval_steps: usize::MAX
}
}
pub fn with_lr(mut self, lr: f32) -> Self {
self.lr = lr;
self
}
pub fn with_max_steps(mut self, steps: usize) -> Self {
self.max_steps = steps;
self
}
pub fn with_max_eval_steps(mut self, steps: usize) -> Self {
self.max_eval_steps = steps;
self
}
pub fn with_gradient_accumulation(mut self, steps: usize) -> Self {
self.gradient_accumulation_steps = steps;
self
}
pub fn with_ckpt_path(mut self, path: impl Into<PathBuf>) -> Self {
self.ckpt_path = path.into();
self
}
pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self {
self.ckpt_strategy = strategy;
self
}
pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self {
self.max_saved_ckpts = max_ckpts;
self
}
pub fn with_eval_loader<D: DataLoader<I, L> + 'static>(mut self, eval_loader: D) -> Self {
self.eval_loader = Some(Box::new(eval_loader));
self
}
pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self {
self.eval_strategy = strategy;
self
}
}
impl super::Trainer {
pub fn train<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
&self,
mut args: TrainingArguments<I, L, NI, NL>
) -> crate::Result<()> {
let optimizer = self.optimizer();
optimizer.set_lr(args.lr)?;
let mut saved_ckpts = VecDeque::new();
let mut global_step = 0;
for (iter_step, _) in (0..args.max_steps).enumerate() {
let epoch = iter_step / args.loader.len().unwrap_or(usize::MAX);
let (inputs, labels) = args.loader.load(iter_step)?;
let (inputs, labels) = (inputs.into(), labels.into());
let outputs = self.step(inputs, labels)?;
let loss = outputs[0].try_extract_scalar::<f32>()?;
println!("epoch={epoch} step={global_step} loss={loss}");
if iter_step % args.gradient_accumulation_steps == 0 {
optimizer.step()?;
optimizer.reset_grad()?;
global_step += 1;
}
if args.ckpt_strategy.should_fire(global_step, iter_step, args.loader.len()) {
if !args.ckpt_path.exists() {
let _ = fs::create_dir_all(&args.ckpt_path);
}
let ckpt_path = args.ckpt_path.join(format!("epoch={epoch},step={global_step}.ortckpt"));
self.checkpoint().save(&ckpt_path, true)?;
saved_ckpts.push_front(ckpt_path.clone());
while saved_ckpts.len() > args.max_saved_ckpts {
let Some(old_ckpt) = saved_ckpts.pop_back() else {
break;
};
let _ = fs::remove_file(old_ckpt);
}
}
if args
.eval_strategy
.should_fire(global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len()))
{
let eval_loss = self.eval_inner(&mut args)?;
println!("eval_loss={eval_loss}");
}
}
Ok(())
}
pub(crate) fn eval_inner<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
&self,
args: &mut TrainingArguments<I, L, NI, NL>
) -> crate::Result<f32> {
let Some(eval_loader) = &mut args.eval_loader else {
return Ok(0.0);
};
let mut total_loss = 0.0;
for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) {
let (inputs, labels) = eval_loader.load(step)?;
let (inputs, labels) = (inputs.into(), labels.into());
let outputs = self.eval_step(inputs, labels)?;
let loss = outputs[0].try_extract_scalar::<f32>()?;
total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.);
}
Ok(total_loss)
}
}

src/training/trainer.rs Normal file

@@ -0,0 +1,235 @@
use std::{
ffi::CString,
path::Path,
ptr::{self, NonNull},
sync::Arc
};
use ort_sys::c_char;
use super::{trainsys, Checkpoint, Optimizer};
use crate::{
char_p_to_string,
error::{assert_non_null_pointer, status_to_result},
Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value
};
#[derive(Debug)]
pub struct Trainer {
pub(crate) ptr: NonNull<ort_sys::OrtTrainingSession>,
train_output_names: Vec<String>,
optimizer: Optimizer,
ckpt: Checkpoint,
_allocator: Allocator
}
impl Trainer {
pub fn new(
session_options: SessionBuilder,
allocator: Allocator,
ckpt: Checkpoint,
training_model_path: impl AsRef<Path>,
eval_model_path: impl AsRef<Path>,
optimizer_model_path: impl AsRef<Path>
) -> Result<Self> {
let training_model_path = crate::util::path_to_os_char(training_model_path);
let eval_model_path = crate::util::path_to_os_char(eval_model_path);
let optimizer_model_path = crate::util::path_to_os_char(optimizer_model_path);
let env = crate::get_environment()?;
let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut();
trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.session_options_ptr.as_ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)];
let ptr = unsafe { NonNull::new_unchecked(ptr) };
let mut train_output_len = 0;
trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len) -> Error::CreateSession];
let train_output_names = (0..train_output_len)
.map(|i| {
let mut name_bytes: *mut c_char = std::ptr::null_mut();
trainsys![unsafe TrainingSessionGetTrainingModelOutputName(ptr.as_ptr(), i, allocator.ptr.as_ptr(), &mut name_bytes) -> Error::CreateSession];
let name = match char_p_to_string(name_bytes) {
Ok(name) => name,
Err(e) => {
unsafe { allocator.free(name_bytes) };
return Err(e);
}
};
unsafe { allocator.free(name_bytes) };
Ok(name)
})
.collect::<Result<Vec<String>>>()?;
Ok(Self {
ptr,
_allocator: allocator,
train_output_names,
optimizer: Optimizer(ptr),
ckpt
})
}
pub fn new_from_artifacts(
session_options: SessionBuilder,
allocator: Allocator,
base_dir: impl AsRef<Path>,
override_ckpt: Option<Checkpoint>
) -> Result<Self> {
let base_dir = base_dir.as_ref();
let ckpt = if let Some(ckpt) = override_ckpt {
ckpt
} else {
Checkpoint::load(base_dir.join("checkpoint"))?
};
Self::new(
session_options,
allocator,
ckpt,
base_dir.join("training_model.onnx"),
base_dir.join("eval_model.onnx"),
base_dir.join("optimizer_model.onnx")
)
}
pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>(
&'s self,
inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
labels: impl Into<SessionInputs<'i2, 'v2, N2>>
) -> Result<SessionOutputs<'_, 's>> {
match inputs.into() {
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None),
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
},
SessionInputs::ValueArray(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None),
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
},
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
}
}
fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
&'s self,
input_values: impl Iterator<Item = &'i1 SessionInputValue<'v1>>,
run_options: Option<&'r RunOptions>
) -> Result<SessionOutputs<'r, 's>> {
let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()];
let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect();
let run_options_ptr = if let Some(run_options) = &run_options {
run_options.run_options_ptr.as_ptr()
} else {
std::ptr::null_mut()
};
trainsys![unsafe TrainStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun];
let outputs: Vec<Value> = output_tensor_ptrs
.into_iter()
.map(|tensor_ptr| unsafe {
// TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`.
// but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣
Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None)
})
.collect();
Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs))
}
pub fn eval_step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>(
&'s self,
inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
labels: impl Into<SessionInputs<'i2, 'v2, N2>>
) -> Result<SessionOutputs<'_, 's>> {
match inputs.into() {
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None),
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
},
SessionInputs::ValueArray(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None),
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
},
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
}
}
fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
&'s self,
input_values: impl Iterator<Item = &'i1 SessionInputValue<'v1>>,
run_options: Option<&'r RunOptions>
) -> Result<SessionOutputs<'r, 's>> {
let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()];
let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect();
let run_options_ptr = if let Some(run_options) = &run_options {
run_options.run_options_ptr.as_ptr()
} else {
std::ptr::null_mut()
};
trainsys![unsafe EvalStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun];
let outputs: Vec<Value> = output_tensor_ptrs
.into_iter()
.map(|tensor_ptr| unsafe {
// TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`.
// but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣
Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None)
})
.collect();
Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs))
}
pub fn export<O: AsRef<str>>(&self, out_path: impl AsRef<Path>, output_names: impl AsRef<[O]>) -> Result<()> {
let out_path = crate::util::path_to_os_char(out_path);
let output_names_ptr: Vec<*const c_char> = output_names
.as_ref()
.iter()
.map(|output| CString::new(output.as_ref()).unwrap_or_else(|_| unreachable!()))
.map(|n| n.into_raw().cast_const())
.collect();
let res = trainsys![unsafe ExportModelForInferencing(self.ptr.as_ptr(), out_path.as_ptr(), output_names_ptr.len(), output_names_ptr.as_ptr())];
// Reconvert name ptrs to CString so drop impl is called and memory is freed
drop(
output_names_ptr
.into_iter()
.map(|p| {
assert_non_null_pointer(p, "c_char for CString")?;
unsafe { Ok(CString::from_raw(p.cast_mut().cast())) }
})
.collect::<Result<Vec<_>>>()?
);
status_to_result(res).map_err(Error::CreateSession)?;
Ok(())
}
pub fn optimizer(&self) -> &Optimizer {
&self.optimizer
}
pub fn checkpoint(&self) -> &Checkpoint {
&self.ckpt
}
}
impl Drop for Trainer {
fn drop(&mut self) {
tracing::trace!("dropping trainer");
trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())];
}
}

src/util.rs Normal file

@@ -0,0 +1,26 @@
#[cfg(not(target_family = "windows"))]
use std::os::raw::c_char;
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(target_family = "windows")]
use std::os::windows::ffi::OsStrExt;
use std::{ffi::OsString, path::Path};
#[cfg(target_family = "windows")]
type OsCharArray = Vec<u16>;
#[cfg(not(target_family = "windows"))]
type OsCharArray = Vec<c_char>;
pub fn path_to_os_char(path: impl AsRef<Path>) -> OsCharArray {
let model_path = OsString::from(path.as_ref());
#[cfg(target_family = "windows")]
let model_path: Vec<u16> = model_path.encode_wide().chain(std::iter::once(0)).collect();
#[cfg(not(target_family = "windows"))]
let model_path: Vec<c_char> = model_path
.as_encoded_bytes()
.iter()
.chain(std::iter::once(&b'\0'))
.map(|b| *b as c_char)
.collect();
model_path
}

tools/requirements.txt Normal file

@@ -0,0 +1,4 @@
torch~=2.3
torch-ort~=1.17
onnx~=1.16
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0


@@ -0,0 +1,140 @@
import math
import onnx
from onnxruntime.training import artifacts
import torch
from torch import nn, Tensor
from torch.nn import functional as F
class RMSNorm(nn.Module):
    def __init__(self, dim: int, *, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if x.dtype != torch.float32:
            xf = x.to(dtype=torch.float32)
        else:
            xf = x
        # divide by the root-mean-square (torch.rsqrt = 1/sqrt) to normalize
        output = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        if x.dtype != torch.float32:
            output = output.to(dtype=x.dtype)
        return output * self.weight

class RoPE(nn.Module):
    def __init__(self, embedding_dim: int, *, max_seq_length: int = 2048, base: float = 10000.0):
        super().__init__()
        pe = torch.zeros(max_seq_length, embedding_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, step=2).float() * (-math.log(base) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe, persistent=False)

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        return x + self.pe[:, :x.shape[1], :]

class Attention(nn.Module):
    def __init__(self, embedding_dim: int, *, rope: RoPE, max_seq_length: int = 2048, n_heads: int = 4):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_heads = n_heads
        self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
        self.proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.rope = rope
        self.register_buffer('bias', torch.tril(torch.ones(max_seq_length, max_seq_length))[None, None, :, :], persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        b, t, c = x.size()
        x = self.rope(x)
        q, k, v = self.qkv(x).split(self.embedding_dim, dim=2)
        q = q.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2)
        k = k.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2)
        v = v.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)

class FFN(nn.Module):
    def __init__(self, embedding_dim: int, intermediate_dim: int | None = None):
        super().__init__()
        intermediate_dim = intermediate_dim or embedding_dim * 4
        self.w1 = nn.Linear(embedding_dim, intermediate_dim * 2, bias=False)
        self.w2 = nn.Linear(intermediate_dim, embedding_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        x, gate = self.w1(x).chunk(2, dim=-1)
        return self.w2(F.gelu(gate) * x)

class Layer(nn.Module):
    def __init__(self, embedding_dim: int, rope: RoPE):
        super().__init__()
        self.attn = Attention(embedding_dim, rope=rope)
        self.norm1 = RMSNorm(embedding_dim)
        self.ffn = FFN(embedding_dim)
        self.norm2 = RMSNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

class CLM(nn.Module):
    def __init__(self, embedding_dim: int, n_layers: int, *, vocab_size: int):
        super().__init__()
        rope = RoPE(embedding_dim)
        self.layers = nn.ModuleList([Layer(embedding_dim, rope=rope) for _ in range(n_layers)])
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.norm = RMSNorm(embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        x = self.word_embeddings(x)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(self.norm(x))
        return logits.view(-1, logits.size(-1))

lm = CLM(256, 4, vocab_size=50257)
torch.onnx.export(
    lm,
    torch.randint(0, 50256, (1, 64)),
    'tools/train-data/mini-clm/model.onnx',
    input_names=['input_ids'],
    output_names=['probs'],
    dynamic_axes={
        'input_ids': {0: 'batch', 1: 'seq'},
        'probs': {0: 'batch_seq'}
    },
    opset_version=14
)

onnx_model = onnx.load('tools/train-data/mini-clm/model.onnx')
requires_grad = [param.name for param in onnx_model.graph.initializer]

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=[],
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory='tools/train-data/mini-clm'
)