mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: training (#202)
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -186,6 +186,7 @@ WixTools/
|
||||
# ONNX Runtime downloaded models
|
||||
**/*.onnx
|
||||
**/*.ort
|
||||
**/*.pbseq
|
||||
!examples/webassembly/**/*.ort
|
||||
!tests/data/*.onnx
|
||||
!tests/data/*.ort
|
||||
@@ -196,4 +197,8 @@ WixTools/
|
||||
# Glassbench results
|
||||
/glassbench*.db
|
||||
|
||||
# Python virtual environment
|
||||
.venv*
|
||||
|
||||
# Training checkpoints
|
||||
tools/train-data/**/checkpoint
|
||||
|
||||
@@ -7,6 +7,7 @@ members = [
|
||||
'examples/model-info',
|
||||
'examples/yolov8',
|
||||
'examples/modnet',
|
||||
'examples/training',
|
||||
'examples/webassembly'
|
||||
]
|
||||
default-members = [
|
||||
@@ -45,13 +46,15 @@ strip = true
|
||||
codegen-units = 1
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
|
||||
features = [ "ndarray", "half", "training", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
|
||||
targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
|
||||
rustdoc-args = [ "--cfg", "docsrs" ]
|
||||
|
||||
[features]
|
||||
default = [ "ndarray", "half", "download-binaries", "copy-dylibs" ]
|
||||
|
||||
training = [ "ort-sys/training" ]
|
||||
|
||||
operator-libraries = [ "libc", "winapi" ]
|
||||
|
||||
fetch-models = [ "ureq" ]
|
||||
|
||||
18
examples/training/Cargo.toml
Normal file
18
examples/training/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
publish = false
|
||||
name = "example-training"
|
||||
version = "0.0.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
ort = { path = "../../", features = [ "training" ] }
|
||||
ndarray = "0.15"
|
||||
tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
|
||||
rand = "0.8"
|
||||
simd-json = "0.13"
|
||||
kdam = "0.5"
|
||||
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
|
||||
|
||||
[features]
|
||||
load-dynamic = [ "ort/load-dynamic" ]
|
||||
cuda = [ "ort/cuda" ]
|
||||
26
examples/training/README.md
Normal file
26
examples/training/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# Training Examples
|
||||
|
||||
## `train-clm`
|
||||
This example trains a tiny causal language model on a small subset of pyke's [**OshiChats v2**](https://huggingface.co/datasets/pykeio/oshichats-v2), a dataset of live text chat messages collected from various [VTuber](https://en.wikipedia.org/wiki/VTuber) live streams. The model is not particularly useful or interesting (due to both the low-quality dataset and small model size), but it showcases that entire language models can be trained from scratch entirely in Rust on (almost) any device.
|
||||
|
||||
To get started, create a Python virtual environment and install the following packages:
|
||||
```
|
||||
pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 onnx~=1.17 torch~=2.3
|
||||
```
|
||||
|
||||
We're installing the CPU version of the `onnxruntime-training` & `torch` packages because we only need to use Python to *create* the initial graph which will be used for training. Run `python tools/train-data/mini-clm.py` from the root directory of the `ort` repo to create the training artifacts.
|
||||
|
||||
Next, we need to convert our dataset into tokens to feed the model. This can be achieved by downloading the `oshicats-v2.jsonl` file from the OshiChats v2 dataset and running `cargo run -p example-training --example pretokenize -- ~/oshichats-v2.jsonl`, or if you (rightfully) don't wish to waste 30 GB worth of disk space and bandwidth on brainrot, you may download a [1 MB pre-tokenized subset of the dataset](https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_data/dataset.bin). Make sure `dataset.bin` is in the root of the `ort` repo.
|
||||
|
||||
Finally, we can train our model! Run `cargo run -p example-training --example train-clm` to start training. If you have an NVIDIA GPU, add `--features cuda` to enable CUDA, though it's not required and you can train directly on CPU instead. **This will use ~8 GB of (V)RAM!** You can lower the memory usage by adjusting the `BATCH_SIZE` and `SEQUENCE_LENGTH` constants in `train-clm.rs`, though note that changing the batch size may require adjustments to the learning rate.
|
||||
|
||||
While training, the progress bar will show the cross-entropy loss at each training step. At the end of training, the final trained model will be saved to `trained-clm.onnx`, and the program will use the model to generate a small snippet of text:
|
||||
```
|
||||
100%|██████████████████████████████████████| 5000/5000 [06:29<00:00, 12.83it/s, loss=3.611]
|
||||
I'm so much better than the game<|endoftext|>I think you can't see it<|endoftext|>I think you can't see it<|endoftext|>I think so it's a new game<|endoftext|>I think I'm sure you can't see what you can't see it<|endoftext|>
|
||||
```
|
||||
|
||||
Not bad, considering the model & dataset size! This example can easily be scaled up to pre-train or fine-tune (both full-parameter and PEFT) larger language models like Llama/Phi, so long as you have enough compute.
|
||||
|
||||
## `train-clm-simple`
|
||||
This example is functionally identical to `train-clm`, except it uses ort's "simple" Trainer API instead of implementing a manual training loop. The simple API is more akin to 🤗 Transformer's [`Trainer`](https://huggingface.co/docs/transformers/en/main_classes/trainer) API or [PyTorch Lightning](https://lightning.ai/pytorch-lightning). With the simple API, all you have to do is pass a data loader & parameters, and let `ort` handle training for you!
|
||||
5
examples/training/build.rs
Normal file
5
examples/training/build.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
fn main() {
|
||||
// Need this for CoreML. See: https://ort.pyke.io/perf/execution-providers#coreml
|
||||
#[cfg(target_os = "macos")]
|
||||
println!("cargo:rustc-link-arg=-fapple-link-rtlib");
|
||||
}
|
||||
44
examples/training/examples/pretokenize.rs
Normal file
44
examples/training/examples/pretokenize.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use std::{
|
||||
env,
|
||||
fs::File,
|
||||
io::{BufRead, BufReader, BufWriter, Write},
|
||||
path::Path
|
||||
};
|
||||
|
||||
use simd_json::derived::ValueObjectAccessAsScalar;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const MAX_TOKENS: usize = 500_000;
|
||||
|
||||
fn main() {
|
||||
let input = env::args().nth(1).expect("provide input jsonl");
|
||||
let output = env::args().nth(2).unwrap_or_else(|| "dataset.bin".into());
|
||||
|
||||
let input = BufReader::new(File::open(input).unwrap());
|
||||
let mut output = BufWriter::new(File::create(output).unwrap());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(
|
||||
Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("gpt2")
|
||||
.join("data")
|
||||
.join("tokenizer.json")
|
||||
)
|
||||
.unwrap();
|
||||
let mut bytes_written = 0;
|
||||
|
||||
for line in input.lines() {
|
||||
let line: simd_json::OwnedValue = unsafe { simd_json::from_str(&mut line.unwrap()).unwrap() };
|
||||
let tokenized = tokenizer
|
||||
.encode(format!("<|endoftext|>{}", line.get_str("message").unwrap()), false)
|
||||
.unwrap();
|
||||
let id_bytes: Vec<u8> = tokenized.get_ids().iter().flat_map(|c| (*c as u16).to_le_bytes()).collect();
|
||||
output.write_all(&id_bytes).unwrap();
|
||||
bytes_written += id_bytes.len();
|
||||
if bytes_written >= MAX_TOKENS * 2 {
|
||||
output.flush().unwrap();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
118
examples/training/examples/train-clm-simple.rs
Normal file
118
examples/training/examples/train-clm-simple.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use std::{
|
||||
fs::File,
|
||||
io::{Read, Seek, SeekFrom, Write},
|
||||
path::Path
|
||||
};
|
||||
|
||||
use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis};
|
||||
use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments};
|
||||
use rand::RngCore;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const BATCH_SIZE: usize = 16;
|
||||
const SEQUENCE_LENGTH: usize = 256;
|
||||
|
||||
fn main() -> ort::Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
ort::init().commit()?;
|
||||
|
||||
let trainer = Trainer::new_from_artifacts(
|
||||
SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
|
||||
Allocator::default(),
|
||||
"tools/train-data/mini-clm",
|
||||
None
|
||||
)?;
|
||||
|
||||
let tokenizer = Tokenizer::from_file(
|
||||
Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("gpt2")
|
||||
.join("data")
|
||||
.join("tokenizer.json")
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut dataset = File::open("dataset.bin").unwrap();
|
||||
let file_size = dataset.metadata().unwrap().len();
|
||||
let num_tokens = (file_size / 2) as usize; // 16-bit tokens
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
|
||||
let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
|
||||
let dataloader = move |_: usize| {
|
||||
for batch in 0..BATCH_SIZE {
|
||||
let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64;
|
||||
dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap();
|
||||
dataset
|
||||
.read_exact(unsafe {
|
||||
std::slice::from_raw_parts_mut(
|
||||
input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
|
||||
.as_mut_ptr()
|
||||
.cast::<u8>(),
|
||||
SEQUENCE_LENGTH * 2
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap();
|
||||
dataset
|
||||
.read_exact(unsafe {
|
||||
std::slice::from_raw_parts_mut(
|
||||
label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
|
||||
.as_mut_ptr()
|
||||
.cast::<u8>(),
|
||||
SEQUENCE_LENGTH * 2
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Ok((
|
||||
ort::inputs![Array2::<i64>::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?,
|
||||
ort::inputs![Array1::<i64>::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?
|
||||
))
|
||||
};
|
||||
|
||||
trainer.train(
|
||||
TrainingArguments::new(dataloader)
|
||||
.with_lr(7e-5)
|
||||
.with_max_steps(5000)
|
||||
.with_ckpt_strategy(CheckpointStrategy::Steps(500))
|
||||
)?;
|
||||
|
||||
trainer.export("trained-clm.onnx", ["probs"])?;
|
||||
|
||||
let session = Session::builder()?.commit_from_file("trained-clm.onnx")?;
|
||||
|
||||
let mut stdout = std::io::stdout();
|
||||
|
||||
let tokens = tokenizer.encode("<|endoftext|>", false).unwrap();
|
||||
let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();
|
||||
|
||||
let mut tokens = Array1::from_iter(tokens.iter().cloned());
|
||||
|
||||
for _ in 0..50 {
|
||||
let array = tokens.view().insert_axis(Axis(0));
|
||||
let outputs = session.run(ort::inputs![array]?)?;
|
||||
let generated_tokens: ArrayViewD<f32> = outputs["probs"].try_extract_tensor()?;
|
||||
|
||||
let probabilities = &mut generated_tokens
|
||||
.slice(s![-1, ..])
|
||||
.to_owned()
|
||||
.iter()
|
||||
.cloned()
|
||||
.enumerate()
|
||||
.collect::<Vec<_>>();
|
||||
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
|
||||
|
||||
let token = probabilities[0].0;
|
||||
tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]];
|
||||
|
||||
let token_str = tokenizer.decode(&[token as _], false).unwrap();
|
||||
print!("{}", token_str);
|
||||
stdout.flush().unwrap();
|
||||
}
|
||||
|
||||
println!();
|
||||
Ok(())
|
||||
}
|
||||
133
examples/training/examples/train-clm.rs
Normal file
133
examples/training/examples/train-clm.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
use std::{
|
||||
fs::File,
|
||||
io::{Read, Seek, SeekFrom, Write},
|
||||
path::Path
|
||||
};
|
||||
|
||||
use kdam::BarExt;
|
||||
use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis};
|
||||
use ort::{Allocator, CUDAExecutionProvider, Checkpoint, Session, SessionBuilder, Trainer};
|
||||
use rand::RngCore;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const BATCH_SIZE: usize = 16;
|
||||
const SEQUENCE_LENGTH: usize = 256;
|
||||
|
||||
fn main() -> ort::Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
ort::init().commit()?;
|
||||
|
||||
kdam::term::init(true);
|
||||
let _ = kdam::term::hide_cursor();
|
||||
|
||||
let trainer = Trainer::new(
|
||||
SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
|
||||
Allocator::default(),
|
||||
Checkpoint::load("tools/train-data/mini-clm/checkpoint")?,
|
||||
"tools/train-data/mini-clm/training_model.onnx",
|
||||
"tools/train-data/mini-clm/eval_model.onnx",
|
||||
"tools/train-data/mini-clm/optimizer_model.onnx"
|
||||
)?;
|
||||
|
||||
let tokenizer = Tokenizer::from_file(
|
||||
Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("gpt2")
|
||||
.join("data")
|
||||
.join("tokenizer.json")
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let optimizer = trainer.optimizer();
|
||||
optimizer.set_lr(7e-5)?;
|
||||
|
||||
let mut dataset = File::open("dataset.bin").unwrap();
|
||||
let file_size = dataset.metadata().unwrap().len();
|
||||
let num_tokens = (file_size / 2) as usize; // 16-bit tokens
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
|
||||
let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
|
||||
let mut pb = kdam::tqdm!(total = 5000);
|
||||
for _ in 0..5000 {
|
||||
for batch in 0..BATCH_SIZE {
|
||||
let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64;
|
||||
dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap();
|
||||
dataset
|
||||
.read_exact(unsafe {
|
||||
std::slice::from_raw_parts_mut(
|
||||
input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
|
||||
.as_mut_ptr()
|
||||
.cast::<u8>(),
|
||||
SEQUENCE_LENGTH * 2
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap();
|
||||
dataset
|
||||
.read_exact(unsafe {
|
||||
std::slice::from_raw_parts_mut(
|
||||
label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
|
||||
.as_mut_ptr()
|
||||
.cast::<u8>(),
|
||||
SEQUENCE_LENGTH * 2
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let inputs = Array2::<i64>::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap();
|
||||
let labels = Array1::<i64>::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap();
|
||||
|
||||
let outputs = trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?;
|
||||
let loss = outputs[0].try_extract_scalar::<f32>()?;
|
||||
pb.set_postfix(format!("loss={loss:.3}"));
|
||||
pb.update(1).unwrap();
|
||||
if loss.is_nan() {
|
||||
return Ok(());
|
||||
}
|
||||
optimizer.step()?;
|
||||
optimizer.reset_grad()?;
|
||||
}
|
||||
|
||||
eprintln!();
|
||||
let _ = kdam::term::show_cursor();
|
||||
|
||||
trainer.export("trained-clm.onnx", ["probs"])?;
|
||||
|
||||
let session = Session::builder()?.commit_from_file("trained-clm.onnx")?;
|
||||
|
||||
let mut stdout = std::io::stdout();
|
||||
|
||||
let tokens = tokenizer.encode("<|endoftext|>", false).unwrap();
|
||||
let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();
|
||||
|
||||
let mut tokens = Array1::from_iter(tokens.iter().cloned());
|
||||
|
||||
for _ in 0..50 {
|
||||
let array = tokens.view().insert_axis(Axis(0));
|
||||
let outputs = session.run(ort::inputs![array]?)?;
|
||||
let generated_tokens: ArrayViewD<f32> = outputs["probs"].try_extract_tensor()?;
|
||||
|
||||
let probabilities = &mut generated_tokens
|
||||
.slice(s![-1, ..])
|
||||
.to_owned()
|
||||
.iter()
|
||||
.cloned()
|
||||
.enumerate()
|
||||
.collect::<Vec<_>>();
|
||||
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
|
||||
|
||||
let token = probabilities[0].0;
|
||||
tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]];
|
||||
|
||||
let token_str = tokenizer.decode(&[token as _], false).unwrap();
|
||||
print!("{}", token_str);
|
||||
stdout.flush().unwrap();
|
||||
}
|
||||
|
||||
println!();
|
||||
Ok(())
|
||||
}
|
||||
@@ -16,6 +16,7 @@ include = [ "src/", "dist.txt", "build.rs", "LICENSE-APACHE", "LICENSE-MIT" ]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
training = []
|
||||
download-binaries = [ "ureq", "tar", "flate2", "sha2" ]
|
||||
load-dynamic = []
|
||||
copy-dylibs = []
|
||||
|
||||
@@ -37,12 +37,12 @@ fn fetch_file(source_url: &str) -> Vec<u8> {
|
||||
buffer
|
||||
}
|
||||
|
||||
fn find_dist(target: &str, designator: &str) -> Option<(&'static str, &'static str)> {
|
||||
fn find_dist(target: &str, feature_set: &str) -> Option<(&'static str, &'static str)> {
|
||||
DIST_TABLE
|
||||
.split('\n')
|
||||
.filter(|c| !c.is_empty() && !c.starts_with('#'))
|
||||
.map(|c| c.split('\t').collect::<Vec<_>>())
|
||||
.find(|c| c[0] == designator && c[1] == target)
|
||||
.find(|c| c[0] == feature_set && c[1] == target)
|
||||
.map(|c| (c[2], c[3]))
|
||||
}
|
||||
|
||||
@@ -341,23 +341,31 @@ fn prepare_libort_dir() -> (PathBuf, bool) {
|
||||
#[cfg(feature = "download-binaries")]
|
||||
{
|
||||
let target = env::var("TARGET").unwrap().to_string();
|
||||
let designator = if cfg!(any(feature = "cuda", feature = "tensorrt")) {
|
||||
if lib_exists("cudart64_12.dll") || lib_exists("libcudart.so.12") { "cu12" } else { "cu11" }
|
||||
|
||||
let mut feature_set = Vec::new();
|
||||
if cfg!(feature = "training") {
|
||||
feature_set.push("train");
|
||||
}
|
||||
if cfg!(any(feature = "cuda", feature = "tensorrt")) {
|
||||
if lib_exists("cudart64_11.dll") || lib_exists("libcudart.so.11") || env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() == Ok("11") {
|
||||
feature_set.push("cu11");
|
||||
} else {
|
||||
feature_set.push("cu12");
|
||||
}
|
||||
} else if cfg!(feature = "rocm") {
|
||||
"rocm"
|
||||
} else {
|
||||
"none"
|
||||
};
|
||||
let mut dist = find_dist(&target, designator);
|
||||
if dist.is_none() && designator != "none" {
|
||||
feature_set.push("rocm");
|
||||
}
|
||||
let feature_set = if !feature_set.is_empty() { feature_set.join(",") } else { "none".to_owned() };
|
||||
let mut dist = find_dist(&target, &feature_set);
|
||||
if dist.is_none() && feature_set != "none" {
|
||||
dist = find_dist(&target, "none");
|
||||
}
|
||||
|
||||
if dist.is_none() {
|
||||
panic!(
|
||||
"downloaded binaries not available for target {target}{}\nyou may have to compile ONNX Runtime from source",
|
||||
if designator != "none" {
|
||||
format!(" (note: also requested `{designator}`)")
|
||||
if feature_set != "none" {
|
||||
format!(" (note: also requested features `{feature_set}`)")
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
|
||||
@@ -4,12 +4,26 @@ cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/
|
||||
rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_rocm-v1.18.1-x86_64-unknown-linux-gnu.tgz 84F74428E0BEC68C55B8E1E91B9282E984CD2866148A2584382B8CB3284214A3
|
||||
none x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-unknown-linux-gnu.tgz 0A193706A95286853D792D7D9B2271CBEA35C57F249943FE811CED97E0E24862
|
||||
|
||||
train aarch64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-unknown-linux-gnu.tgz C04DBEAF19F2BCD3643F8F7D7FA01110A1AF429DFDD1C1DC7C5EDA2B1A8AA324
|
||||
train,cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12-v1.18.1-x86_64-unknown-linux-gnu.tgz A139D8AD8930930F5A61DF112C8275AAD1F0415FAFD08CE3031CEFFFC30F2445
|
||||
train,cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu11-v1.18.1-x86_64-unknown-linux-gnu.tgz 2DAA2E2CF44E9B9A96AB2E9C4271C35189C96BF264D1797DABCF1D6711730DE7
|
||||
train,rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,rocm-v1.18.1-x86_64-unknown-linux-gnu.tgz DD373BA6B251D21953223B2FBB64F4DF34CFE98A63C26D16607BEAC6BC788466
|
||||
train x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-unknown-linux-gnu.tgz 0E617970AE83ABE5FB9A3D5D69AAC9A67ED4C9494AD527B14A84FDC98CA9B924
|
||||
|
||||
none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-pc-windows-msvc.tgz B2F962F0E75F17F3D657B3504CE891BAA6461B26AF65FBD9244B3CCA17FD79D4
|
||||
cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12-v1.18.1-x86_64-pc-windows-msvc.tgz CDBC2D87B202E1847900E94796D102EE4D5C19A9568BBD014838ECD1F5D5350B
|
||||
cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu11-v1.18.1-x86_64-pc-windows-msvc.tgz B514FC25453F955F8592100448B27F5E1762A344E8C2D57D41B908978EF2A126
|
||||
none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-pc-windows-msvc.tgz EB2BCD1778C5934437D4C5B17F67DEAF5F67D2E3C18C7298973EACD41113DC01
|
||||
|
||||
train aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-pc-windows-msvc.tgz 8CC1FFFD8AB5E526A076C29A767A650C436E31179D0C6E52C2EA936067B72566
|
||||
train,cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12-v1.18.1-x86_64-pc-windows-msvc.tgz 6AF64567E25B59AD1196D4953EF8C6A65795E8A4B864E10D8303A027AC50B2D0
|
||||
train,cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu11-v1.18.1-x86_64-pc-windows-msvc.tgz E14AA0F4FBBBCAF925AD4DB4F76B06402F654B36C5F221E00010D1005F47AE56
|
||||
train x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-pc-windows-msvc.tgz 84728438E5A950027EBBDC51463F4E5B99B4979087F0F127EA18BC604507E979
|
||||
|
||||
none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-apple-darwin.tgz B42BE76AFB9495983A6D5D498D56D5E685B018F1011EF4C5B8C56124B192FD37
|
||||
none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-apple-darwin.tgz 247F73A5B3665A6660DFB35213E6FEAAC6ED6CAC5816DD85A348DF790F60A30B
|
||||
|
||||
train aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-apple-darwin.tgz 29DC09AFA5C3619CF3125F3D55DD64E5EE64451D6BD0044527776849AADEE344
|
||||
train x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-apple-darwin.tgz 898EC9E3F852843ECDB618CF8E317F4C92BDEB33FC773038960857BCB37CB347
|
||||
|
||||
none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-pkort_static-v1.18.1-wasm32-unknown-unknown.tgz D1BF756F02A53C3BC254E3C2048BE617082905A89182A6B1BD18C229920228EF
|
||||
|
||||
@@ -823,9 +823,117 @@ fn bindgen_test_layout_OrtOpenVINOProviderOptions() {
|
||||
}
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct OrtTrainingApi {
|
||||
pub struct OrtTrainingSession {
|
||||
_unused: [u8; 0]
|
||||
}
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct OrtCheckpointState {
|
||||
_unused: [u8; 0]
|
||||
}
|
||||
#[repr(i32)]
|
||||
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
|
||||
pub enum OrtPropertyType {
|
||||
OrtIntProperty = 0,
|
||||
OrtFloatProperty = 1,
|
||||
OrtStringProperty = 2
|
||||
}
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct OrtTrainingApi {
|
||||
pub LoadCheckpoint:
|
||||
::std::option::Option<_system!(unsafe fn(checkpoint_path: *const ortchar, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr)>,
|
||||
pub SaveCheckpoint: ::std::option::Option<
|
||||
_system!(unsafe fn(checkpoint_state: *mut OrtCheckpointState, checkpoint_path: *const ortchar, include_optimizer_state: bool) -> OrtStatusPtr)
|
||||
>,
|
||||
pub CreateTrainingSession: ::std::option::Option<
|
||||
_system!(
|
||||
unsafe fn(
|
||||
env: *const OrtEnv,
|
||||
options: *const OrtSessionOptions,
|
||||
checkpoint_state: *mut OrtCheckpointState,
|
||||
train_model_path: *const ortchar,
|
||||
eval_model_path: *const ortchar,
|
||||
optimizer_model_path: *const ortchar,
|
||||
out: *mut *mut OrtTrainingSession
|
||||
) -> OrtStatusPtr
|
||||
)
|
||||
>,
|
||||
pub CreateTrainingSessionFromBuffer: ::std::option::Option<
|
||||
_system!(
|
||||
unsafe fn(
|
||||
env: *const OrtEnv,
|
||||
options: *const OrtSessionOptions,
|
||||
checkpoint_state: *mut OrtCheckpointState,
|
||||
train_model_data: *const (),
|
||||
train_data_length: size_t,
|
||||
eval_model_data: *const (),
|
||||
eval_data_length: size_t,
|
||||
optimizer_model_data: *const (),
|
||||
optimizer_data_length: size_t,
|
||||
out: *mut *mut OrtTrainingSession
|
||||
) -> OrtStatusPtr
|
||||
)
|
||||
>,
|
||||
pub TrainingSessionGetTrainingModelOutputCount:
|
||||
::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>,
|
||||
pub TrainingSessionGetEvalModelOutputCount: ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>,
|
||||
pub TrainingSessionGetTrainingModelOutputName: ::std::option::Option<
|
||||
_system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr)
|
||||
>,
|
||||
pub TrainingSessionGetEvalModelOutputName: ::std::option::Option<
|
||||
_system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr)
|
||||
>,
|
||||
pub LazyResetGrad: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>,
|
||||
pub TrainStep: ::std::option::Option<
|
||||
_system!(
|
||||
unsafe fn(
|
||||
session: *mut OrtTrainingSession,
|
||||
run_options: *const OrtRunOptions,
|
||||
inputs_len: size_t,
|
||||
inputs: *const *const OrtValue,
|
||||
outputs_len: size_t,
|
||||
outputs: *mut *mut OrtValue
|
||||
) -> OrtStatusPtr
|
||||
)
|
||||
>,
|
||||
pub EvalStep: ::std::option::Option<
|
||||
_system!(
|
||||
unsafe fn(
|
||||
session: *mut OrtTrainingSession,
|
||||
run_options: *const OrtRunOptions,
|
||||
inputs_len: size_t,
|
||||
inputs: *const *const OrtValue,
|
||||
outputs_len: size_t,
|
||||
outputs: *mut *mut OrtValue
|
||||
) -> OrtStatusPtr
|
||||
)
|
||||
>,
|
||||
pub SetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: f32) -> OrtStatusPtr)>,
|
||||
pub GetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: *mut f32) -> OrtStatusPtr)>,
|
||||
pub OptimizerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, run_options: *const OrtRunOptions) -> OrtStatusPtr)>,
|
||||
pub RegisterLinearLRScheduler: ::std::option::Option<
|
||||
_system!(unsafe fn(session: *mut OrtTrainingSession, warmup_step_count: i64, total_step_count: i64, initial_lr: f32) -> OrtStatusPtr)
|
||||
>,
|
||||
pub SchedulerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>,
|
||||
pub GetParametersSize: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, out: *mut size_t, trainable_only: bool) -> OrtStatusPtr)>,
|
||||
pub CopyParametersToBuffer:
|
||||
::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>,
|
||||
pub CopyBufferToParameters:
|
||||
::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>,
|
||||
pub ReleaseTrainingSession: ::std::option::Option<_system!(unsafe fn(input: *mut OrtTrainingSession))>,
|
||||
pub ReleaseCheckpointState: ::std::option::Option<_system!(unsafe fn(input: *mut OrtCheckpointState))>,
|
||||
pub ExportModelForInferencing: ::std::option::Option<
|
||||
_system!(
|
||||
unsafe fn(
|
||||
session: *mut OrtTrainingSession,
|
||||
inference_model_path: *const ortchar,
|
||||
graph_outputs_len: usize,
|
||||
graph_output_names: *const *const c_char
|
||||
) -> OrtStatusPtr
|
||||
)
|
||||
>
|
||||
}
|
||||
#[doc = " \\brief The helper interface to get the right version of OrtApi\n\n Get a pointer to this structure through ::OrtGetApiBase"]
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
|
||||
@@ -261,7 +261,9 @@ pub enum Error {
|
||||
#[error("Could't get `AllocatorType` from memory info: {0}")]
|
||||
GetAllocatorType(ErrorInternal),
|
||||
#[error("Could't get device ID from memory info: {0}")]
|
||||
GetDeviceId(ErrorInternal)
|
||||
GetDeviceId(ErrorInternal),
|
||||
#[error("Training API is not enabled in this build of ONNX Runtime.")]
|
||||
TrainingNotEnabled
|
||||
}
|
||||
|
||||
impl Error {
|
||||
|
||||
@@ -23,6 +23,9 @@ pub(crate) mod metadata;
|
||||
pub(crate) mod operator;
|
||||
pub(crate) mod session;
|
||||
pub(crate) mod tensor;
|
||||
#[cfg(feature = "training")]
|
||||
pub(crate) mod training;
|
||||
pub(crate) mod util;
|
||||
pub(crate) mod value;
|
||||
#[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
@@ -66,6 +69,9 @@ pub use self::session::{
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
|
||||
pub use self::tensor::ArrayExtensions;
|
||||
pub use self::tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data};
|
||||
#[cfg(feature = "training")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "training")))]
|
||||
pub use self::training::*;
|
||||
pub use self::value::{
|
||||
DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor,
|
||||
DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence,
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
#[cfg(any(feature = "operator-libraries", not(windows)))]
|
||||
use std::ffi::CString;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
#[cfg(target_family = "windows")]
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::path::Path;
|
||||
#[cfg(feature = "fetch-models")]
|
||||
@@ -316,20 +312,7 @@ impl SessionBuilder {
|
||||
});
|
||||
}
|
||||
|
||||
// Build an OsString, then a vector of bytes to pass to C
|
||||
let model_path = std::ffi::OsString::from(model_filepath);
|
||||
#[cfg(target_family = "windows")]
|
||||
let model_path: Vec<u16> = model_path
|
||||
.encode_wide()
|
||||
.chain(std::iter::once(0)) // Make sure we have a null terminated string
|
||||
.collect();
|
||||
#[cfg(not(target_family = "windows"))]
|
||||
let model_path: Vec<std::os::raw::c_char> = model_path
|
||||
.as_encoded_bytes()
|
||||
.iter()
|
||||
.chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
|
||||
.map(|b| *b as std::os::raw::c_char)
|
||||
.collect();
|
||||
let model_path = crate::util::path_to_os_char(model_filepath);
|
||||
|
||||
let env = get_environment()?;
|
||||
apply_execution_providers(&self, env.execution_providers.iter().cloned())?;
|
||||
|
||||
142
src/training/mod.rs
Normal file
142
src/training/mod.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use std::{
|
||||
path::Path,
|
||||
ptr::{self, NonNull},
|
||||
sync::{
|
||||
atomic::{AtomicPtr, Ordering},
|
||||
OnceLock
|
||||
}
|
||||
};
|
||||
|
||||
use crate::{ortsys, Error, Result, RunOptions};
|
||||
|
||||
mod simple;
|
||||
mod trainer;
|
||||
|
||||
pub use self::{
|
||||
simple::{iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainingArguments},
|
||||
trainer::Trainer
|
||||
};
|
||||
|
||||
/// Lazily-initialized pointer to the global `OrtTrainingApi` struct; populated on first call to
/// [`training_api`]. Stored as an `AtomicPtr` so it can be shared across threads.
pub(crate) static TRAINING_API: OnceLock<AtomicPtr<ort_sys::OrtTrainingApi>> = OnceLock::new();

/// Returns a pointer to the global [`ort_sys::OrtTrainingApi`] object, or errors if the Training API is not enabled.
///
/// # Panics
/// May panic if:
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
pub fn training_api() -> Result<NonNull<ort_sys::OrtTrainingApi>> {
	NonNull::new(
		TRAINING_API
			.get_or_init(|| {
				// `GetTrainingApi` returns null when this build of ONNX Runtime was compiled
				// without training support; `NonNull::new` below turns that into an error.
				let training_api = ortsys![unsafe GetTrainingApi(ort_sys::ORT_API_VERSION)];
				AtomicPtr::new(training_api.cast_mut())
			})
			.load(Ordering::Relaxed)
	)
	.ok_or(Error::TrainingNotEnabled)
}
|
||||
|
||||
/// Dispatch macro for the training API, mirroring `ortsys!`: looks up `$method` on the global
/// [`ort_sys::OrtTrainingApi`] struct and invokes it, panicking if the method pointer is null.
///
/// Arm variants:
/// - `unsafe` prefix wraps the call in an `unsafe` block;
/// - `.expect($e)` converts the returned status and panics with `$e` on error;
/// - `-> $err` converts the returned status into a `Result`, mapping through `$err` and
///   propagating with `?` — note these arms use `training_api()?` (not `.unwrap()`) so a
///   missing Training API surfaces as an error rather than a panic;
/// - `nonNull(...)` additionally asserts the given out-pointers are non-null after the call.
macro_rules! trainsys {
	($method:ident) => {
		$crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))
	};
	(unsafe $method:ident) => {
		unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) }
	};
	($method:ident($($n:expr),+ $(,)?)) => {
		$crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)
	};
	(unsafe $method:ident($($n:expr),+ $(,)?)) => {
		unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }
	};
	($method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
		$crate::error::status_to_result($crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e)
	};
	(unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
		$crate::error::status_to_result(unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e)
	};
	($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
		$crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+);
		$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
	};
	(unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{
		let _x = unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) };
		$($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+
		_x
	}};
	($method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
		$crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?;
	};
	(unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
		$crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?;
	};
	($method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
		$crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?;
		$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
	};
	(unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {{
		$crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?;
		$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
	}};
}
pub(crate) use trainsys;
|
||||
|
||||
/// Model parameter state (and optionally optimizer state) loaded from an ONNX Runtime training
/// checkpoint file.
#[derive(Debug)]
pub struct Checkpoint {
	// Owned pointer to the native `OrtCheckpointState`; released in `Drop`.
	pub(crate) ptr: NonNull<ort_sys::OrtCheckpointState>
}

impl Checkpoint {
	/// Loads a checkpoint from the file at `path`.
	///
	/// # Errors
	/// Returns an error if the checkpoint could not be loaded.
	pub fn load(path: impl AsRef<Path>) -> Result<Self> {
		let path = crate::util::path_to_os_char(path);
		let mut ptr: *mut ort_sys::OrtCheckpointState = ptr::null_mut();
		// NOTE(review): failures surface as `Error::CreateSession`; a checkpoint-specific
		// error variant may be clearer — confirm intended.
		trainsys![unsafe LoadCheckpoint(path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)];
		Ok(Checkpoint {
			// SAFETY: `nonNull(ptr)` above guarantees `ptr` is not null.
			ptr: unsafe { NonNull::new_unchecked(ptr) }
		})
	}

	/// Saves this checkpoint to the file at `path`. Pass `include_optimizer_state = true` to
	/// also persist optimizer state so training can later resume from this checkpoint.
	pub fn save(&self, path: impl AsRef<Path>, include_optimizer_state: bool) -> Result<()> {
		let path = crate::util::path_to_os_char(path);
		trainsys![unsafe SaveCheckpoint(self.ptr.as_ptr(), path.as_ptr(), include_optimizer_state) -> Error::CreateSession];
		Ok(())
	}
}
|
||||
|
||||
impl Drop for Checkpoint {
	fn drop(&mut self) {
		tracing::trace!("dropping checkpoint");
		// Releases the native `OrtCheckpointState`; the pointer must not be used afterwards.
		trainsys![unsafe ReleaseCheckpointState(self.ptr.as_ptr())];
	}
}
|
||||
|
||||
/// Handle to the optimizer of a training session. Wraps the same underlying
/// `OrtTrainingSession` pointer as the owning `Trainer`.
#[derive(Debug)]
pub struct Optimizer(NonNull<ort_sys::OrtTrainingSession>);

impl Optimizer {
	/// Lazily resets the gradients of all trainable parameters to zero (ONNX Runtime applies
	/// the reset on the next train step).
	pub fn reset_grad(&self) -> Result<()> {
		trainsys![unsafe LazyResetGrad(self.0.as_ptr()) -> Error::CreateSession];
		Ok(())
	}

	/// Returns the optimizer's current learning rate.
	pub fn lr(&self) -> Result<f32> {
		// Initialized to NaN so an unwritten value is obvious, then filled in by the C API.
		let mut lr = f32::NAN;
		trainsys![unsafe GetLearningRate(self.0.as_ptr(), &mut lr) -> Error::CreateSession];
		Ok(lr)
	}

	/// Sets the learning rate used by subsequent optimizer steps.
	pub fn set_lr(&self, lr: f32) -> Result<()> {
		trainsys![unsafe SetLearningRate(self.0.as_ptr(), lr) -> Error::CreateSession];
		Ok(())
	}

	/// Performs one optimizer step with default run options; see [`Optimizer::step_with_options`].
	pub fn step(&self) -> Result<()> {
		self.step_with_options(RunOptions::new()?)
	}

	/// Performs one optimizer step, applying the accumulated gradients to the model parameters,
	/// with the given run options.
	pub fn step_with_options(&self, options: RunOptions) -> Result<()> {
		trainsys![unsafe OptimizerStep(self.0.as_ptr(), options.run_options_ptr.as_ptr()) -> Error::CreateSession];
		Ok(())
	}
}
|
||||
240
src/training/simple.rs
Normal file
240
src/training/simple.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
use std::{collections::VecDeque, fs, path::PathBuf};
|
||||
|
||||
use crate::{Result, SessionInputs};
|
||||
|
||||
/// A source of `(input, label)` training examples, fetched by index.
///
/// `len` may be `None` for loaders of unknown or unbounded size, so an `is_empty` counterpart
/// would be meaningless — hence the lint allowance.
#[allow(clippy::len_without_is_empty)]
pub trait DataLoader<I, L> {
	/// Returns the `(input, label)` example at index `idx`.
	fn load(&mut self, idx: usize) -> Result<(I, L)>;

	/// The number of examples, if known. Defaults to `None` (unknown).
	fn len(&self) -> Option<usize> {
		None
	}
}
|
||||
|
||||
/// A [`DataLoader`] over a fixed, in-memory collection; each item is converted to an
/// `(input, label)` pair on demand by a user-provided collator function.
pub struct IterableDataLoader<T, I, L, C: Fn(&T) -> Result<(I, L)>> {
	// All items, collected up front.
	items: Box<[T]>,
	// Converts a borrowed item into an `(input, label)` pair.
	collator: C
}

impl<T, I, L, C: Fn(&T) -> Result<(I, L)>> DataLoader<I, L> for IterableDataLoader<T, I, L, C> {
	fn load(&mut self, idx: usize) -> Result<(I, L)> {
		// Panics if `idx` is out of bounds, like slice indexing.
		(self.collator)(&self.items[idx])
	}

	fn len(&self) -> Option<usize> {
		Some(self.items.len())
	}
}

/// Creates an [`IterableDataLoader`] by collecting `iterable` into memory; items are converted
/// lazily by `collator` as they are requested.
pub fn iterable_data_loader<T, I, L, C: Fn(&T) -> Result<(I, L)>>(iterable: impl Iterator<Item = T>, collator: C) -> IterableDataLoader<T, I, L, C> {
	IterableDataLoader {
		items: iterable.collect::<Vec<T>>().into_boxed_slice(),
		collator
	}
}
|
||||
|
||||
// Any `FnMut(usize) -> Result<(I, L)>` closure can be used directly as a `DataLoader` of
// unknown length.
impl<I, L, F: FnMut(usize) -> Result<(I, L)>> DataLoader<I, L> for F {
	fn load(&mut self, idx: usize) -> Result<(I, L)> {
		(self)(idx)
	}

	fn len(&self) -> Option<usize> {
		None
	}
}
|
||||
|
||||
/// When to run evaluation during training.
pub enum EvaluationStrategy {
	/// Never evaluate.
	None,
	/// Evaluate every `n` iteration steps.
	Steps(usize),
	/// Evaluate every `n` epochs. Requires the data loader to report a length; otherwise never
	/// fires.
	Epochs(usize)
}

impl EvaluationStrategy {
	/// Returns whether evaluation should fire at this point in the training loop.
	// NOTE(review): `_global_step` is currently unused — the strategies key off `iter_step`
	// (raw loop iterations), not optimizer steps; confirm that is intended.
	pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
		match self {
			Self::None => false,
			Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
			Self::Epochs(epochs) => {
				if let Some(dataloader_size) = dataloader_size {
					iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
				} else {
					false
				}
			}
		}
	}
}
|
||||
|
||||
/// When to save checkpoints during training. (Structurally identical to `EvaluationStrategy`,
/// kept separate so the two schedules can diverge independently.)
pub enum CheckpointStrategy {
	/// Never checkpoint.
	None,
	/// Checkpoint every `n` iteration steps.
	Steps(usize),
	/// Checkpoint every `n` epochs. Requires the data loader to report a length; otherwise
	/// never fires.
	Epochs(usize)
}

impl CheckpointStrategy {
	/// Returns whether a checkpoint should be saved at this point in the training loop.
	// NOTE(review): `_global_step` is currently unused — firing is keyed off `iter_step`;
	// confirm intended (mirrors `EvaluationStrategy::should_fire`).
	pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
		match self {
			Self::None => false,
			Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
			Self::Epochs(epochs) => {
				if let Some(dataloader_size) = dataloader_size {
					iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
				} else {
					false
				}
			}
		}
	}
}
|
||||
|
||||
/// Configuration for the simple training loop (`Trainer::train`).
pub struct TrainingArguments<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize> {
	// Training data source.
	loader: Box<dyn DataLoader<I, L>>,
	// Optional evaluation data source; evaluation is skipped when `None`.
	eval_loader: Option<Box<dyn DataLoader<I, L>>>,
	// Schedule for running evaluation.
	eval_strategy: EvaluationStrategy,
	// Schedule for saving checkpoints.
	ckpt_strategy: CheckpointStrategy,
	// Directory checkpoints are written into (created on demand).
	ckpt_path: PathBuf,
	// Learning rate applied to the optimizer before training starts.
	lr: f32,
	// Oldest checkpoints beyond this count are deleted.
	max_saved_ckpts: usize,
	// Iteration steps to accumulate gradients over before each optimizer step.
	gradient_accumulation_steps: usize,
	// Upper bound on training loop iterations.
	max_steps: usize,
	// Upper bound on batches per evaluation pass.
	max_eval_steps: usize
}
|
||||
|
||||
impl<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>
	TrainingArguments<I, L, NI, NL>
{
	/// Creates training arguments with the given training data loader and defaults: no
	/// evaluation, checkpoint every epoch into `checkpoints/`, learning rate `1e-4`, no
	/// gradient accumulation, at most 1 retained checkpoint, and unbounded train/eval steps.
	pub fn new<D: DataLoader<I, L> + 'static>(train_loader: D) -> Self {
		Self {
			loader: Box::new(train_loader),
			eval_loader: None,
			eval_strategy: EvaluationStrategy::None,
			ckpt_strategy: CheckpointStrategy::Epochs(1),
			ckpt_path: PathBuf::from("checkpoints"),
			lr: 1e-4,
			gradient_accumulation_steps: 1,
			max_saved_ckpts: 1,
			max_steps: usize::MAX,
			max_eval_steps: usize::MAX
		}
	}

	/// Sets the optimizer learning rate.
	pub fn with_lr(mut self, lr: f32) -> Self {
		self.lr = lr;
		self
	}

	/// Limits training to at most `steps` loop iterations.
	pub fn with_max_steps(mut self, steps: usize) -> Self {
		self.max_steps = steps;
		self
	}

	/// Limits each evaluation pass to at most `steps` batches.
	pub fn with_max_eval_steps(mut self, steps: usize) -> Self {
		self.max_eval_steps = steps;
		self
	}

	/// Accumulates gradients over `steps` iterations before each optimizer step.
	pub fn with_gradient_accumulation(mut self, steps: usize) -> Self {
		self.gradient_accumulation_steps = steps;
		self
	}

	/// Sets the directory checkpoints are saved into.
	pub fn with_ckpt_path(mut self, path: impl Into<PathBuf>) -> Self {
		self.ckpt_path = path.into();
		self
	}

	/// Sets when checkpoints are saved.
	pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self {
		self.ckpt_strategy = strategy;
		self
	}

	/// Sets how many recent checkpoints to retain; older ones are deleted.
	pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self {
		self.max_saved_ckpts = max_ckpts;
		self
	}

	/// Sets the data loader used for evaluation.
	pub fn with_eval_loader<D: DataLoader<I, L> + 'static>(mut self, eval_loader: D) -> Self {
		self.eval_loader = Some(Box::new(eval_loader));
		self
	}

	/// Sets when evaluation runs.
	pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self {
		self.eval_strategy = strategy;
		self
	}
}
|
||||
|
||||
impl super::Trainer {
|
||||
pub fn train<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
|
||||
&self,
|
||||
mut args: TrainingArguments<I, L, NI, NL>
|
||||
) -> crate::Result<()> {
|
||||
let optimizer = self.optimizer();
|
||||
optimizer.set_lr(args.lr)?;
|
||||
|
||||
let mut saved_ckpts = VecDeque::new();
|
||||
let mut global_step = 0;
|
||||
for (iter_step, _) in (0..args.max_steps).enumerate() {
|
||||
let epoch = iter_step / args.loader.len().unwrap_or(usize::MAX);
|
||||
let (inputs, labels) = args.loader.load(iter_step)?;
|
||||
let (inputs, labels) = (inputs.into(), labels.into());
|
||||
|
||||
let outputs = self.step(inputs, labels)?;
|
||||
let loss = outputs[0].try_extract_scalar::<f32>()?;
|
||||
println!("epoch={epoch} step={global_step} loss={loss}");
|
||||
|
||||
if iter_step % args.gradient_accumulation_steps == 0 {
|
||||
optimizer.step()?;
|
||||
optimizer.reset_grad()?;
|
||||
global_step += 1;
|
||||
}
|
||||
|
||||
if args.ckpt_strategy.should_fire(global_step, iter_step, args.loader.len()) {
|
||||
if !args.ckpt_path.exists() {
|
||||
let _ = fs::create_dir_all(&args.ckpt_path);
|
||||
}
|
||||
|
||||
let ckpt_path = args.ckpt_path.join(format!("epoch={epoch},step={global_step}.ortckpt"));
|
||||
self.checkpoint().save(&ckpt_path, true)?;
|
||||
|
||||
saved_ckpts.push_front(ckpt_path.clone());
|
||||
while saved_ckpts.len() > args.max_saved_ckpts {
|
||||
let Some(old_ckpt) = saved_ckpts.pop_back() else {
|
||||
break;
|
||||
};
|
||||
let _ = fs::remove_file(old_ckpt);
|
||||
}
|
||||
}
|
||||
|
||||
if args
|
||||
.eval_strategy
|
||||
.should_fire(global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len()))
|
||||
{
|
||||
let eval_loss = self.eval_inner(&mut args)?;
|
||||
println!("eval_loss={eval_loss}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn eval_inner<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
|
||||
&self,
|
||||
args: &mut TrainingArguments<I, L, NI, NL>
|
||||
) -> crate::Result<f32> {
|
||||
let Some(eval_loader) = &mut args.eval_loader else {
|
||||
return Ok(0.0);
|
||||
};
|
||||
|
||||
let mut total_loss = 0.0;
|
||||
for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) {
|
||||
let (inputs, labels) = eval_loader.load(step)?;
|
||||
let (inputs, labels) = (inputs.into(), labels.into());
|
||||
|
||||
let outputs = self.eval_step(inputs, labels)?;
|
||||
let loss = outputs[0].try_extract_scalar::<f32>()?;
|
||||
total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.);
|
||||
}
|
||||
|
||||
Ok(total_loss)
|
||||
}
|
||||
}
|
||||
235
src/training/trainer.rs
Normal file
235
src/training/trainer.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
use std::{
|
||||
ffi::CString,
|
||||
path::Path,
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
};
|
||||
|
||||
use ort_sys::c_char;
|
||||
|
||||
use super::{trainsys, Checkpoint, Optimizer};
|
||||
use crate::{
|
||||
char_p_to_string,
|
||||
error::{assert_non_null_pointer, status_to_result},
|
||||
Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value
|
||||
};
|
||||
|
||||
/// A wrapper around an ONNX Runtime training session, providing train/eval steps and access to
/// the optimizer and checkpoint state.
#[derive(Debug)]
pub struct Trainer {
	// Owned pointer to the native training session; released in `Drop`.
	pub(crate) ptr: NonNull<ort_sys::OrtTrainingSession>,
	// Output names of the training graph, used to key `SessionOutputs` from each step.
	train_output_names: Vec<String>,
	// Shares `ptr` with this struct; see `Optimizer`.
	optimizer: Optimizer,
	// The checkpoint state this session was created from.
	ckpt: Checkpoint,
	// Kept alive for the session's lifetime; presumably the session/output names reference
	// allocations owned by it — TODO confirm.
	_allocator: Allocator
}
|
||||
|
||||
impl Trainer {
	/// Creates a [`Trainer`] from a checkpoint and the three artifact graphs (training, eval &
	/// optimizer models) generated by ONNX Runtime's training artifact tooling.
	///
	/// # Errors
	/// Returns an error if the training session could not be created or the training graph's
	/// output names could not be read.
	pub fn new(
		session_options: SessionBuilder,
		allocator: Allocator,
		ckpt: Checkpoint,
		training_model_path: impl AsRef<Path>,
		eval_model_path: impl AsRef<Path>,
		optimizer_model_path: impl AsRef<Path>
	) -> Result<Self> {
		let training_model_path = crate::util::path_to_os_char(training_model_path);
		let eval_model_path = crate::util::path_to_os_char(eval_model_path);
		let optimizer_model_path = crate::util::path_to_os_char(optimizer_model_path);

		let env = crate::get_environment()?;

		let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut();
		trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.session_options_ptr.as_ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)];

		// SAFETY: checked non-null by the `nonNull` clause above.
		let ptr = unsafe { NonNull::new_unchecked(ptr) };

		// Collect the training graph's output names once, so step outputs can be keyed by name.
		let mut train_output_len = 0;
		trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len) -> Error::CreateSession];
		let train_output_names = (0..train_output_len)
			.map(|i| {
				let mut name_bytes: *mut c_char = std::ptr::null_mut();
				trainsys![unsafe TrainingSessionGetTrainingModelOutputName(ptr.as_ptr(), i, allocator.ptr.as_ptr(), &mut name_bytes) -> Error::CreateSession];
				// The name buffer is allocated through `allocator` and must be freed on both
				// the success and error paths.
				let name = match char_p_to_string(name_bytes) {
					Ok(name) => name,
					Err(e) => {
						unsafe { allocator.free(name_bytes) };
						return Err(e);
					}
				};
				unsafe { allocator.free(name_bytes) };
				Ok(name)
			})
			.collect::<Result<Vec<String>>>()?;

		Ok(Self {
			ptr,
			_allocator: allocator,
			train_output_names,
			optimizer: Optimizer(ptr),
			ckpt
		})
	}

	/// Creates a [`Trainer`] from a directory of training artifacts laid out as
	/// `checkpoint`, `training_model.onnx`, `eval_model.onnx` and `optimizer_model.onnx`.
	/// `override_ckpt` replaces the directory's default checkpoint, e.g. to resume from a
	/// later saved state.
	pub fn new_from_artifacts(
		session_options: SessionBuilder,
		allocator: Allocator,
		base_dir: impl AsRef<Path>,
		override_ckpt: Option<Checkpoint>
	) -> Result<Self> {
		let base_dir = base_dir.as_ref();
		let ckpt = if let Some(ckpt) = override_ckpt {
			ckpt
		} else {
			Checkpoint::load(base_dir.join("checkpoint"))?
		};
		Self::new(
			session_options,
			allocator,
			ckpt,
			base_dir.join("training_model.onnx"),
			base_dir.join("eval_model.onnx"),
			base_dir.join("optimizer_model.onnx")
		)
	}

	/// Performs one training step (forward + backward) with the given inputs and labels,
	/// returning the training graph's outputs. Gradients accumulate until
	/// [`Optimizer::step`]/[`Optimizer::reset_grad`] are called.
	///
	/// Named (map) inputs are not supported by the training API bindings yet.
	pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>(
		&'s self,
		inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
		labels: impl Into<SessionInputs<'i2, 'v2, N2>>
	) -> Result<SessionOutputs<'_, 's>> {
		// Inputs and labels are flattened into one positional input list for the graph.
		match inputs.into() {
			SessionInputs::ValueSlice(input_values) => match labels.into() {
				SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None),
				SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None),
				SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
			},
			SessionInputs::ValueArray(input_values) => match labels.into() {
				SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None),
				SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None),
				SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
			},
			SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
		}
	}

	// Shared implementation of `step`: runs `TrainStep` over the flattened positional inputs.
	fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
		&'s self,
		input_values: impl Iterator<Item = &'i1 SessionInputValue<'v1>>,
		run_options: Option<&'r RunOptions>
	) -> Result<SessionOutputs<'r, 's>> {
		let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()];

		let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect();

		let run_options_ptr = if let Some(run_options) = &run_options {
			run_options.run_options_ptr.as_ptr()
		} else {
			std::ptr::null_mut()
		};

		trainsys![unsafe TrainStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun];

		let outputs: Vec<Value> = output_tensor_ptrs
			.into_iter()
			.map(|tensor_ptr| unsafe {
				// TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`.
				// but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣
				Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None)
			})
			.collect();

		Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs))
	}

	/// Performs one evaluation step (forward only, no gradients) with the given inputs and
	/// labels, returning the eval graph's outputs.
	///
	/// Named (map) inputs are not supported by the training API bindings yet.
	pub fn eval_step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>(
		&'s self,
		inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
		labels: impl Into<SessionInputs<'i2, 'v2, N2>>
	) -> Result<SessionOutputs<'_, 's>> {
		// Mirrors `step`: inputs and labels flattened into one positional list.
		match inputs.into() {
			SessionInputs::ValueSlice(input_values) => match labels.into() {
				SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None),
				SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None),
				SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
			},
			SessionInputs::ValueArray(input_values) => match labels.into() {
				SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None),
				SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None),
				SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
			},
			SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
		}
	}

	// Shared implementation of `eval_step`: runs `EvalStep` over the flattened positional
	// inputs. NOTE(review): output names come from the *training* graph — presumably the eval
	// graph shares them; confirm.
	fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
		&'s self,
		input_values: impl Iterator<Item = &'i1 SessionInputValue<'v1>>,
		run_options: Option<&'r RunOptions>
	) -> Result<SessionOutputs<'r, 's>> {
		let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()];

		let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect();

		let run_options_ptr = if let Some(run_options) = &run_options {
			run_options.run_options_ptr.as_ptr()
		} else {
			std::ptr::null_mut()
		};

		trainsys![unsafe EvalStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun];

		let outputs: Vec<Value> = output_tensor_ptrs
			.into_iter()
			.map(|tensor_ptr| unsafe {
				// TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`.
				// but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣
				Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None)
			})
			.collect();

		Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs))
	}

	/// Exports the trained model as an inference-only ONNX graph at `out_path`, keeping only
	/// the listed graph outputs.
	pub fn export<O: AsRef<str>>(&self, out_path: impl AsRef<Path>, output_names: impl AsRef<[O]>) -> Result<()> {
		let out_path = crate::util::path_to_os_char(out_path);

		// Output names must be passed to C as raw, null-terminated strings; ownership is
		// temporarily leaked via `into_raw` and reclaimed below.
		let output_names_ptr: Vec<*const c_char> = output_names
			.as_ref()
			.iter()
			.map(|output| CString::new(output.as_ref()).unwrap_or_else(|_| unreachable!()))
			.map(|n| n.into_raw().cast_const())
			.collect();

		let res = trainsys![unsafe ExportModelForInferencing(self.ptr.as_ptr(), out_path.as_ptr(), output_names_ptr.len(), output_names_ptr.as_ptr())];

		// Reconvert name ptrs to CString so drop impl is called and memory is freed
		drop(
			output_names_ptr
				.into_iter()
				.map(|p| {
					assert_non_null_pointer(p, "c_char for CString")?;
					unsafe { Ok(CString::from_raw(p.cast_mut().cast())) }
				})
				.collect::<Result<Vec<_>>>()?
		);

		// Status is checked only after the name buffers are reclaimed, so errors do not leak.
		status_to_result(res).map_err(Error::CreateSession)?;

		Ok(())
	}

	/// Returns a handle to the session's optimizer.
	pub fn optimizer(&self) -> &Optimizer {
		&self.optimizer
	}

	/// Returns the checkpoint state this trainer was created from.
	pub fn checkpoint(&self) -> &Checkpoint {
		&self.ckpt
	}
}
|
||||
|
||||
impl Drop for Trainer {
	fn drop(&mut self) {
		tracing::trace!("dropping trainer");
		// Releases the native training session; `self.ptr` must not be used afterwards.
		trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())];
	}
}
|
||||
26
src/util.rs
Normal file
26
src/util.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
#[cfg(not(target_family = "windows"))]
|
||||
use std::os::raw::c_char;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
#[cfg(target_family = "windows")]
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
use std::{ffi::OsString, path::Path};
|
||||
|
||||
// Platform-native character array for paths passed to ONNX Runtime's C API:
// UTF-16 code units on Windows, raw `c_char` bytes elsewhere.
#[cfg(target_family = "windows")]
type OsCharArray = Vec<u16>;
#[cfg(not(target_family = "windows"))]
type OsCharArray = Vec<c_char>;

/// Converts a path into a null-terminated, platform-native character array suitable for
/// passing to ONNX Runtime C API functions that take a file path.
pub fn path_to_os_char(path: impl AsRef<Path>) -> OsCharArray {
	let model_path = OsString::from(path.as_ref());
	// Windows: wide (UTF-16) encoding with a trailing NUL.
	#[cfg(target_family = "windows")]
	let model_path: Vec<u16> = model_path.encode_wide().chain(std::iter::once(0)).collect();
	// Elsewhere: the OS byte encoding reinterpreted as `c_char`, with a trailing NUL.
	#[cfg(not(target_family = "windows"))]
	let model_path: Vec<c_char> = model_path
		.as_encoded_bytes()
		.iter()
		.chain(std::iter::once(&b'\0'))
		.map(|b| *b as c_char)
		.collect();
	model_path
}
|
||||
4
tools/requirements.txt
Normal file
4
tools/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
torch~=2.3
|
||||
torch-ort~=1.17
|
||||
onnx~=1.16
|
||||
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0
|
||||
140
tools/train-data/mini-clm.py
Normal file
140
tools/train-data/mini-clm.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import math
|
||||
|
||||
import onnx
|
||||
from onnxruntime.training import artifacts
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Zhang & Sennrich, 2019).

    Normalizes the last dimension of the input by its root mean square, then
    applies a learned per-dimension scale. Unlike LayerNorm there is no mean
    subtraction and no bias.
    """

    def __init__(self, dim: int, *, eps: float = 1e-6):
        super().__init__()

        # Small constant added to the mean square for numerical stability.
        self.eps = eps
        # Learned per-dimension gain, initialized to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in float32 for numerical stability even for half-precision
        # inputs; cast back to the input dtype afterwards.
        if x.dtype != torch.float32:
            xf = x.to(dtype=torch.float32)
        else:
            xf = x
        # Fix: divide by the RMS via `rsqrt`. The previous code multiplied by
        # sqrt(mean-square), which scales activations *up* by their own
        # magnitude instead of normalizing them.
        output = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        if x.dtype != torch.float32:
            output = output.to(dtype=x.dtype)
        return output * self.weight
|
||||
|
||||
class RoPE(nn.Module):
    """Additive sinusoidal positional encoding.

    NOTE(review): despite the name, this is the classic absolute sinusoidal
    positional encoding (added to the input), not rotary position embedding —
    confirm the name is intentional.
    """

    def __init__(self, embedding_dim: int, *, max_seq_length: int = 2048, base: float = 10000.0):
        super().__init__()

        positions = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, embedding_dim, step=2).float() * (-math.log(base) / embedding_dim))
        table = torch.zeros(max_seq_length, embedding_dim)
        # Even columns carry sine, odd columns cosine, of position * frequency.
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)

        # Leading batch axis so the table broadcasts over the batch; not
        # persisted in the state dict since it is fully derived.
        self.register_buffer('pe', table.unsqueeze(0), persistent=False)

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        # x is (batch, seq, dim); add the encoding for the first `seq` positions.
        return x + self.pe[:, :x.shape[1], :]
|
||||
|
||||
class Attention(nn.Module):
    """Multi-head causal self-attention with additive positional encoding."""

    def __init__(self, embedding_dim: int, *, rope: RoPE, max_seq_length: int = 2048, n_heads: int = 4):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.n_heads = n_heads
        # Fused projection: queries, keys and values in one matmul.
        self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
        self.proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.rope = rope
        # Lower-triangular causal mask, shaped (1, 1, T, T) for broadcasting.
        self.register_buffer('bias', torch.tril(torch.ones(max_seq_length, max_seq_length))[None, None, :, :], persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        batch, seq, dim = x.size()

        x = self.rope(x)

        # Project once, split into Q/K/V, and fold the head axis in front of
        # the sequence axis: (batch, heads, seq, head_dim).
        q, k, v = self.qkv(x).split(self.embedding_dim, dim=2)
        head_shape = (batch, seq, self.n_heads, dim // self.n_heads)
        q = q.view(*head_shape).transpose(1, 2)
        k = k.view(*head_shape).transpose(1, 2)
        v = v.view(*head_shape).transpose(1, 2)

        # Scaled dot-product scores; future positions masked out before softmax.
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        scores = scores.masked_fill(self.bias[:, :, :seq, :seq] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # Merge heads back into the embedding dimension.
        out = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.proj(out)
|
||||
|
||||
class FFN(nn.Module):
|
||||
def __init__(self, embedding_dim: int, intermediate_dim: int | None = None):
|
||||
super().__init__()
|
||||
|
||||
intermediate_dim = intermediate_dim or embedding_dim * 4
|
||||
|
||||
self.w1 = nn.Linear(embedding_dim, intermediate_dim * 2, bias=False)
|
||||
self.w2 = nn.Linear(intermediate_dim, embedding_dim, bias=False)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x, gate = self.w1(x).chunk(2, dim=-1)
|
||||
return self.w2(F.gelu(gate) * x)
|
||||
|
||||
class Layer(nn.Module):
    """Pre-norm transformer block: attention then FFN, each with a residual."""

    def __init__(self, embedding_dim: int, rope: RoPE):
        super().__init__()

        self.attn = Attention(embedding_dim, rope=rope)
        self.norm1 = RMSNorm(embedding_dim)
        self.ffn = FFN(embedding_dim)
        self.norm2 = RMSNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        # Pre-normalization: normalize before each sublayer, add residual after.
        hidden = x + self.attn(self.norm1(x))
        return hidden + self.ffn(self.norm2(hidden))
|
||||
|
||||
class CLM(nn.Module):
    """Tiny decoder-only causal language model used to generate training artifacts."""

    def __init__(self, embedding_dim: int, n_layers: int, *, vocab_size: int):
        super().__init__()

        # A single positional-encoding module is shared by every layer's attention.
        shared_rope = RoPE(embedding_dim)
        self.layers = nn.ModuleList([Layer(embedding_dim, rope=shared_rope) for _ in range(n_layers)])
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.norm = RMSNorm(embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.word_embeddings(x)
        for block in self.layers:
            hidden = block(hidden)
        logits = self.lm_head(self.norm(hidden))
        # Flatten (batch, seq, vocab) -> (batch*seq, vocab), the shape expected
        # by the CrossEntropyLoss training artifact.
        return logits.view(-1, logits.size(-1))
|
||||
|
||||
# Instantiate a small 4-layer, 256-dim model over the GPT-2-sized vocabulary.
lm = CLM(256, 4, vocab_size=50257)

# Export the forward graph to ONNX with dynamic batch/sequence axes.
torch.onnx.export(
    lm,
    torch.randint(0, 50256, (1, 64)),
    # Fix: was a spurious f-string with no placeholders (ruff F541).
    'tools/train-data/mini-clm/model.onnx',
    input_names=['input_ids'],
    output_names=['probs'],
    dynamic_axes={
        'input_ids': {0: 'batch', 1: 'seq'},
        'probs': {0: 'batch_seq'}
    },
    opset_version=14
)

# Mark every initializer as trainable and generate the training, eval and
# optimizer graphs plus the initial checkpoint next to the exported model.
onnx_model = onnx.load('tools/train-data/mini-clm/model.onnx')
requires_grad = [param.name for param in onnx_model.graph.initializer]

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=[],
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory='tools/train-data/mini-clm'
)
|
||||
Reference in New Issue
Block a user