mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat(training): simple trainer callbacks
This commit is contained in:
@@ -4,14 +4,36 @@ use std::{
|
||||
path::Path
|
||||
};
|
||||
|
||||
use kdam::BarExt;
|
||||
use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis};
|
||||
use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments};
|
||||
use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainerCallbacks, TrainingArguments};
|
||||
use rand::RngCore;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const BATCH_SIZE: usize = 16;
|
||||
const SEQUENCE_LENGTH: usize = 256;
|
||||
|
||||
struct LoggerCallback {
|
||||
progress_bar: kdam::Bar
|
||||
}
|
||||
|
||||
impl LoggerCallback {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
progress_bar: kdam::Bar::builder().leave(true).build().unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TrainerCallbacks for LoggerCallback {
|
||||
fn train_step(&mut self, train_loss: f32, state: &ort::TrainerState, _: &mut ort::TrainerControl<'_>) -> ort::Result<()> {
|
||||
self.progress_bar.total = state.max_steps;
|
||||
self.progress_bar.set_postfix(format!("loss={train_loss:.3}"));
|
||||
let _ = self.progress_bar.update_to(state.iter_step);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> ort::Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
@@ -78,6 +100,7 @@ fn main() -> ort::Result<()> {
|
||||
.with_lr(7e-5)
|
||||
.with_max_steps(5000)
|
||||
.with_ckpt_strategy(CheckpointStrategy::Steps(500))
|
||||
.with_callbacks(LoggerCallback::new())
|
||||
)?;
|
||||
|
||||
trainer.export("trained-clm.onnx", ["probs"])?;
|
||||
|
||||
@@ -27,6 +27,17 @@ unsafe impl Sync for EnvironmentSingleton {}
|
||||
|
||||
static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::new(None) };
|
||||
|
||||
/// An `Environment` is a process-global structure, under which [`Session`](crate::Session)s are created.
|
||||
///
|
||||
/// Environments can be used to [configure global thread pools](EnvironmentBuilder::with_global_thread_pool), in
|
||||
/// which all sessions share threads from the environment's pool, and configuring [default execution
|
||||
/// providers](EnvironmentBuilder::with_execution_providers) for all sessions. In the context of `ort` specifically,
|
||||
/// environments are also used to configure ONNX Runtime to send log messages through the [`tracing`] crate in Rust.
|
||||
///
|
||||
/// For ease of use, and since sessions require an environment to be created, `ort` will automatically create an
|
||||
/// environment if one is not configured via [`init`] (or [`init_from`]). [`init`] can be called at any point in the
|
||||
/// program (even after an environment has been automatically created), though every session created before the
|
||||
/// re-configuration would need to be re-created in order to use the config from the new environment.
|
||||
#[derive(Debug)]
|
||||
pub struct Environment {
|
||||
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
|
||||
|
||||
@@ -13,7 +13,10 @@ mod simple;
|
||||
mod trainer;
|
||||
|
||||
pub use self::{
|
||||
simple::{iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainingArguments},
|
||||
simple::{
|
||||
iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainerCallbacks, TrainerControl, TrainerState,
|
||||
TrainingArguments
|
||||
},
|
||||
trainer::Trainer
|
||||
};
|
||||
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
use std::{collections::VecDeque, fs, path::PathBuf};
|
||||
|
||||
use crate::{Result, SessionInputs};
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
pub trait DataLoader<I, L> {
|
||||
fn load(&mut self, idx: usize) -> Result<(I, L)>;
|
||||
|
||||
fn len(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IterableDataLoader<T, I, L, C: Fn(&T) -> Result<(I, L)>> {
|
||||
items: Box<[T]>,
|
||||
collator: C
|
||||
}
|
||||
|
||||
impl<T, I, L, C: Fn(&T) -> Result<(I, L)>> DataLoader<I, L> for IterableDataLoader<T, I, L, C> {
|
||||
fn load(&mut self, idx: usize) -> Result<(I, L)> {
|
||||
(self.collator)(&self.items[idx])
|
||||
}
|
||||
|
||||
fn len(&self) -> Option<usize> {
|
||||
Some(self.items.len())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iterable_data_loader<T, I, L, C: Fn(&T) -> Result<(I, L)>>(iterable: impl Iterator<Item = T>, collator: C) -> IterableDataLoader<T, I, L, C> {
|
||||
IterableDataLoader {
|
||||
items: iterable.collect::<Vec<T>>().into_boxed_slice(),
|
||||
collator
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, L, F: FnMut(usize) -> Result<(I, L)>> DataLoader<I, L> for F {
|
||||
fn load(&mut self, idx: usize) -> Result<(I, L)> {
|
||||
(self)(idx)
|
||||
}
|
||||
|
||||
fn len(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub enum EvaluationStrategy {
|
||||
None,
|
||||
Steps(usize),
|
||||
Epochs(usize)
|
||||
}
|
||||
|
||||
impl EvaluationStrategy {
|
||||
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
|
||||
match self {
|
||||
Self::None => false,
|
||||
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
|
||||
Self::Epochs(epochs) => {
|
||||
if let Some(dataloader_size) = dataloader_size {
|
||||
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum CheckpointStrategy {
|
||||
None,
|
||||
Steps(usize),
|
||||
Epochs(usize)
|
||||
}
|
||||
|
||||
impl CheckpointStrategy {
|
||||
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
|
||||
match self {
|
||||
Self::None => false,
|
||||
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
|
||||
Self::Epochs(epochs) => {
|
||||
if let Some(dataloader_size) = dataloader_size {
|
||||
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TrainingArguments<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize> {
|
||||
loader: Box<dyn DataLoader<I, L>>,
|
||||
eval_loader: Option<Box<dyn DataLoader<I, L>>>,
|
||||
eval_strategy: EvaluationStrategy,
|
||||
ckpt_strategy: CheckpointStrategy,
|
||||
ckpt_path: PathBuf,
|
||||
lr: f32,
|
||||
max_saved_ckpts: usize,
|
||||
gradient_accumulation_steps: usize,
|
||||
max_steps: usize,
|
||||
max_eval_steps: usize
|
||||
}
|
||||
|
||||
impl<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>
|
||||
TrainingArguments<I, L, NI, NL>
|
||||
{
|
||||
pub fn new<D: DataLoader<I, L> + 'static>(train_loader: D) -> Self {
|
||||
Self {
|
||||
loader: Box::new(train_loader),
|
||||
eval_loader: None,
|
||||
eval_strategy: EvaluationStrategy::None,
|
||||
ckpt_strategy: CheckpointStrategy::Epochs(1),
|
||||
ckpt_path: PathBuf::from("checkpoints"),
|
||||
lr: 1e-4,
|
||||
gradient_accumulation_steps: 1,
|
||||
max_saved_ckpts: 1,
|
||||
max_steps: usize::MAX,
|
||||
max_eval_steps: usize::MAX
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_lr(mut self, lr: f32) -> Self {
|
||||
self.lr = lr;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_steps(mut self, steps: usize) -> Self {
|
||||
self.max_steps = steps;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_eval_steps(mut self, steps: usize) -> Self {
|
||||
self.max_eval_steps = steps;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_gradient_accumulation(mut self, steps: usize) -> Self {
|
||||
self.gradient_accumulation_steps = steps;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_ckpt_path(mut self, path: impl Into<PathBuf>) -> Self {
|
||||
self.ckpt_path = path.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self {
|
||||
self.ckpt_strategy = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self {
|
||||
self.max_saved_ckpts = max_ckpts;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_eval_loader<D: DataLoader<I, L> + 'static>(mut self, eval_loader: D) -> Self {
|
||||
self.eval_loader = Some(Box::new(eval_loader));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self {
|
||||
self.eval_strategy = strategy;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl super::Trainer {
|
||||
pub fn train<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
|
||||
&self,
|
||||
mut args: TrainingArguments<I, L, NI, NL>
|
||||
) -> crate::Result<()> {
|
||||
let optimizer = self.optimizer();
|
||||
optimizer.set_lr(args.lr)?;
|
||||
|
||||
let mut saved_ckpts = VecDeque::new();
|
||||
let mut global_step = 0;
|
||||
for (iter_step, _) in (0..args.max_steps).enumerate() {
|
||||
let epoch = iter_step / args.loader.len().unwrap_or(usize::MAX);
|
||||
let (inputs, labels) = args.loader.load(iter_step)?;
|
||||
let (inputs, labels) = (inputs.into(), labels.into());
|
||||
|
||||
let outputs = self.step(inputs, labels)?;
|
||||
let loss = outputs[0].try_extract_scalar::<f32>()?;
|
||||
println!("epoch={epoch} step={global_step} loss={loss}");
|
||||
|
||||
if iter_step % args.gradient_accumulation_steps == 0 {
|
||||
optimizer.step()?;
|
||||
optimizer.reset_grad()?;
|
||||
global_step += 1;
|
||||
}
|
||||
|
||||
if args.ckpt_strategy.should_fire(global_step, iter_step, args.loader.len()) {
|
||||
if !args.ckpt_path.exists() {
|
||||
let _ = fs::create_dir_all(&args.ckpt_path);
|
||||
}
|
||||
|
||||
let ckpt_path = args.ckpt_path.join(format!("epoch={epoch},step={global_step}.ortckpt"));
|
||||
self.checkpoint().save(&ckpt_path, true)?;
|
||||
|
||||
saved_ckpts.push_front(ckpt_path.clone());
|
||||
while saved_ckpts.len() > args.max_saved_ckpts {
|
||||
let Some(old_ckpt) = saved_ckpts.pop_back() else {
|
||||
break;
|
||||
};
|
||||
let _ = fs::remove_file(old_ckpt);
|
||||
}
|
||||
}
|
||||
|
||||
if args
|
||||
.eval_strategy
|
||||
.should_fire(global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len()))
|
||||
{
|
||||
let eval_loss = self.eval_inner(&mut args)?;
|
||||
println!("eval_loss={eval_loss}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn eval_inner<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
|
||||
&self,
|
||||
args: &mut TrainingArguments<I, L, NI, NL>
|
||||
) -> crate::Result<f32> {
|
||||
let Some(eval_loader) = &mut args.eval_loader else {
|
||||
return Ok(0.0);
|
||||
};
|
||||
|
||||
let mut total_loss = 0.0;
|
||||
for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) {
|
||||
let (inputs, labels) = eval_loader.load(step)?;
|
||||
let (inputs, labels) = (inputs.into(), labels.into());
|
||||
|
||||
let outputs = self.eval_step(inputs, labels)?;
|
||||
let loss = outputs[0].try_extract_scalar::<f32>()?;
|
||||
total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.);
|
||||
}
|
||||
|
||||
Ok(total_loss)
|
||||
}
|
||||
}
|
||||
137
src/training/simple/args.rs
Normal file
137
src/training/simple/args.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::{DataLoader, TrainerCallbacks};
|
||||
use crate::session::input::SessionInputs;
|
||||
|
||||
pub enum EvaluationStrategy {
|
||||
None,
|
||||
Steps(usize),
|
||||
Epochs(usize)
|
||||
}
|
||||
|
||||
impl EvaluationStrategy {
|
||||
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
|
||||
match self {
|
||||
Self::None => false,
|
||||
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
|
||||
Self::Epochs(epochs) => {
|
||||
if let Some(dataloader_size) = dataloader_size {
|
||||
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum CheckpointStrategy {
|
||||
None,
|
||||
Steps(usize),
|
||||
Epochs(usize)
|
||||
}
|
||||
|
||||
impl CheckpointStrategy {
|
||||
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
|
||||
match self {
|
||||
Self::None => false,
|
||||
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
|
||||
Self::Epochs(epochs) => {
|
||||
if let Some(dataloader_size) = dataloader_size {
|
||||
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TrainingArguments<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize> {
|
||||
pub(crate) loader: Box<dyn DataLoader<I, L>>,
|
||||
pub(crate) eval_loader: Option<Box<dyn DataLoader<I, L>>>,
|
||||
pub(crate) eval_strategy: EvaluationStrategy,
|
||||
pub(crate) ckpt_strategy: CheckpointStrategy,
|
||||
pub(crate) ckpt_path: PathBuf,
|
||||
pub(crate) lr: f32,
|
||||
pub(crate) max_saved_ckpts: usize,
|
||||
pub(crate) gradient_accumulation_steps: usize,
|
||||
pub(crate) max_steps: usize,
|
||||
pub(crate) max_eval_steps: usize,
|
||||
pub(crate) callbacks: Vec<Box<dyn TrainerCallbacks>>
|
||||
}
|
||||
|
||||
impl<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>
|
||||
TrainingArguments<I, L, NI, NL>
|
||||
{
|
||||
pub fn new<D: DataLoader<I, L> + 'static>(train_loader: D) -> Self {
|
||||
Self {
|
||||
loader: Box::new(train_loader),
|
||||
eval_loader: None,
|
||||
eval_strategy: EvaluationStrategy::None,
|
||||
ckpt_strategy: CheckpointStrategy::Epochs(1),
|
||||
ckpt_path: PathBuf::from("checkpoints"),
|
||||
lr: 1e-4,
|
||||
gradient_accumulation_steps: 1,
|
||||
max_saved_ckpts: 1,
|
||||
max_steps: usize::MAX,
|
||||
max_eval_steps: usize::MAX,
|
||||
callbacks: Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_lr(mut self, lr: f32) -> Self {
|
||||
self.lr = lr;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_steps(mut self, steps: usize) -> Self {
|
||||
self.max_steps = steps;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_epochs(mut self, epochs: f32) -> Self {
|
||||
self.max_steps = self.loader.len().map(|l| (l as f32 * epochs).trunc() as usize).unwrap_or(usize::MAX);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_eval_steps(mut self, steps: usize) -> Self {
|
||||
self.max_eval_steps = steps;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_gradient_accumulation(mut self, steps: usize) -> Self {
|
||||
self.gradient_accumulation_steps = steps.max(1);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_ckpt_path(mut self, path: impl Into<PathBuf>) -> Self {
|
||||
self.ckpt_path = path.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self {
|
||||
self.ckpt_strategy = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self {
|
||||
self.max_saved_ckpts = max_ckpts;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_eval_loader<D: DataLoader<I, L> + 'static>(mut self, eval_loader: D) -> Self {
|
||||
self.eval_loader = Some(Box::new(eval_loader));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self {
|
||||
self.eval_strategy = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_callbacks(mut self, callbacks: impl TrainerCallbacks + 'static) -> Self {
|
||||
self.callbacks.push(Box::new(callbacks));
|
||||
self
|
||||
}
|
||||
}
|
||||
111
src/training/simple/callbacks.rs
Normal file
111
src/training/simple/callbacks.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use std::path::Path;
|
||||
|
||||
use super::TrainingArguments;
|
||||
use crate::{
|
||||
error::Result,
|
||||
session::input::SessionInputs,
|
||||
training::{Checkpoint, Optimizer, Trainer}
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
#[non_exhaustive]
|
||||
pub struct TrainerState {
|
||||
pub epoch: Option<f32>,
|
||||
/// The total number of weight updates performed on the model.
|
||||
pub global_step: usize,
|
||||
/// The total number of training batches the model has seen.
|
||||
pub iter_step: usize,
|
||||
pub gradient_accumulation_steps: usize,
|
||||
pub max_steps: usize,
|
||||
pub current_lr: f32
|
||||
}
|
||||
|
||||
impl TrainerState {
|
||||
pub(crate) fn new<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
|
||||
args: &TrainingArguments<I, L, NI, NL>
|
||||
) -> Self {
|
||||
Self {
|
||||
epoch: None,
|
||||
global_step: 0,
|
||||
iter_step: 0,
|
||||
gradient_accumulation_steps: args.gradient_accumulation_steps,
|
||||
max_steps: args.max_steps,
|
||||
current_lr: args.lr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allows callbacks in [`TrainerCallbacks`] to control the training of the model. This includes halting training,
|
||||
/// updating the learning rate, or exporting the model.
|
||||
pub struct TrainerControl<'t> {
|
||||
pub(crate) halt: bool,
|
||||
pub(crate) lr: Option<f32>,
|
||||
trainer: &'t Trainer
|
||||
}
|
||||
|
||||
impl<'t> TrainerControl<'t> {
|
||||
pub(crate) fn new(trainer: &'t Trainer) -> Self {
|
||||
Self { halt: false, trainer, lr: None }
|
||||
}
|
||||
|
||||
/// Halts training. Once all callbacks have been called, training will immediately end.
|
||||
///
|
||||
/// Halting training will fire [`TrainerCallbacks::end`].
|
||||
pub fn halt(&mut self) {
|
||||
self.halt = true;
|
||||
}
|
||||
|
||||
/// Sets the optimizer's learning rate.
|
||||
pub fn set_lr(&mut self, lr: f32) {
|
||||
self.lr = Some(lr);
|
||||
}
|
||||
|
||||
/// Export the model as a complete ONNX graph.
|
||||
pub fn export<O: AsRef<str>>(&self, out_path: impl AsRef<Path>, output_names: impl AsRef<[O]>) -> Result<()> {
|
||||
self.trainer.export(out_path, output_names)
|
||||
}
|
||||
|
||||
pub fn optimizer(&self) -> &Optimizer {
|
||||
self.trainer.optimizer()
|
||||
}
|
||||
|
||||
pub fn checkpoint(&self) -> &Checkpoint {
|
||||
self.trainer.checkpoint()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
pub trait TrainerCallbacks: Send {
|
||||
/// Called at the beginning of a new epoch.
|
||||
fn epoch(&mut self, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Called when evaluation is about to begin.
|
||||
fn eval_begin(&mut self, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
/// Called when evaluation has ended. The `eval_loss` is the average loss over all batches in the evaluation
|
||||
/// dataset.
|
||||
fn eval_end(&mut self, eval_loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Called immediately after performing a single forward & backward pass. See also
|
||||
/// [`TrainerCallbacks::optimizer_step`], which is called immediately after updating the optimizer.
|
||||
///
|
||||
/// In the case where [`TrainingArguments::with_gradient_accumulation`] > 1, this will fire as many times as the
|
||||
/// gradient accumulation steps is configured before [`TrainerCallbacks::optimizer_step`] is fired.
|
||||
fn train_step(&mut self, train_loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
/// Called immediately after updating the model weights. `loss` is the loss of the last batch.
|
||||
fn optimizer_step(&mut self, loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Called when training ends, either via [`TrainerControl::halt`], or if the maximum steps have been reached.
|
||||
fn end(&mut self, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
52
src/training/simple/dataloader.rs
Normal file
52
src/training/simple/dataloader.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use crate::error::Result;
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
pub trait DataLoader<I, L> {
|
||||
/// Synchronously loads the batch at index `idx`.
|
||||
fn load(&mut self, idx: usize) -> Result<(I, L)>;
|
||||
|
||||
/// The total number of batches in this data loader. The default implementation returns `None`, which indicates the
|
||||
/// data loader is 'infinite'.
|
||||
///
|
||||
/// If `len` does not return `Some` (i.e., it is 'infinite'), you will not be able to use configuration options
|
||||
/// based on epochs.
|
||||
fn len(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// A definitively-sized [`DataLoader`] created from any type that implements [`Iterator`].
|
||||
///
|
||||
/// To create an iterable data loader, use [`iterable_data_loader`].
|
||||
pub struct IterableDataLoader<T, I, L, C: Fn(&T) -> Result<(I, L)>> {
|
||||
items: Box<[T]>,
|
||||
collator: C
|
||||
}
|
||||
|
||||
impl<T, I, L, C: Fn(&T) -> Result<(I, L)>> DataLoader<I, L> for IterableDataLoader<T, I, L, C> {
|
||||
fn load(&mut self, idx: usize) -> Result<(I, L)> {
|
||||
(self.collator)(&self.items[idx])
|
||||
}
|
||||
|
||||
fn len(&self) -> Option<usize> {
|
||||
Some(self.items.len())
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a definitively-sized [`DataLoader`] from an [`Iterator`] and a corresponding collator function.
|
||||
pub fn iterable_data_loader<T, I, L, C: Fn(&T) -> Result<(I, L)>>(iterable: impl Iterator<Item = T>, collator: C) -> IterableDataLoader<T, I, L, C> {
|
||||
IterableDataLoader {
|
||||
items: iterable.collect::<Vec<T>>().into_boxed_slice(),
|
||||
collator
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, L, F: FnMut(usize) -> Result<(I, L)>> DataLoader<I, L> for F {
|
||||
fn load(&mut self, idx: usize) -> Result<(I, L)> {
|
||||
(self)(idx)
|
||||
}
|
||||
|
||||
fn len(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
}
|
||||
139
src/training/simple/mod.rs
Normal file
139
src/training/simple/mod.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use std::{collections::VecDeque, fs};
|
||||
|
||||
use crate::{error::Result, session::input::SessionInputs, training::Trainer};
|
||||
|
||||
mod dataloader;
|
||||
pub use self::dataloader::{iterable_data_loader, DataLoader, IterableDataLoader};
|
||||
mod args;
|
||||
pub use self::args::{CheckpointStrategy, EvaluationStrategy, TrainingArguments};
|
||||
mod callbacks;
|
||||
pub use self::callbacks::{TrainerCallbacks, TrainerControl, TrainerState};
|
||||
|
||||
macro_rules! callback {
|
||||
($which:ident($self:expr, $optimizer:expr, $args:expr, $state:expr)) => {
|
||||
let mut halt = false;
|
||||
for cb in &mut $args.callbacks {
|
||||
let mut control = TrainerControl::new($self);
|
||||
cb.$which(&$state, &mut control)?;
|
||||
halt = halt || control.halt;
|
||||
if let Some(lr) = control.lr {
|
||||
$optimizer.set_lr(lr)?;
|
||||
}
|
||||
}
|
||||
if halt {
|
||||
return $self.handle_halt(&mut $args.callbacks, &$state);
|
||||
}
|
||||
};
|
||||
($which:ident($self:expr, $optimizer:expr, $args:expr, $state:expr), $($addt:expr),*) => {
|
||||
let mut halt = false;
|
||||
for cb in &mut $args.callbacks {
|
||||
let mut control = TrainerControl::new($self);
|
||||
cb.$which($($addt,)* &$state, &mut control)?;
|
||||
halt = halt || control.halt;
|
||||
if let Some(lr) = control.lr {
|
||||
$optimizer.set_lr(lr)?;
|
||||
}
|
||||
}
|
||||
if halt {
|
||||
return $self.handle_halt(&mut $args.callbacks, &$state);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl Trainer {
|
||||
pub fn train<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
|
||||
&self,
|
||||
mut args: TrainingArguments<I, L, NI, NL>
|
||||
) -> Result<()> {
|
||||
let optimizer = self.optimizer();
|
||||
optimizer.set_lr(args.lr)?;
|
||||
|
||||
let mut saved_ckpts = VecDeque::new();
|
||||
let mut state = TrainerState::new(&args);
|
||||
let mut last_epoch = -1.0;
|
||||
for (iter_step, _) in (0..args.max_steps).enumerate() {
|
||||
state.iter_step = iter_step;
|
||||
state.epoch = args.loader.len().map(|dl_len| iter_step as f32 / dl_len as f32);
|
||||
|
||||
if let Some(epoch) = state.epoch {
|
||||
if epoch.trunc() != last_epoch {
|
||||
callback!(epoch(self, optimizer, args, state));
|
||||
}
|
||||
|
||||
last_epoch = epoch.trunc();
|
||||
}
|
||||
|
||||
let (inputs, labels) = args.loader.load(iter_step)?;
|
||||
let (inputs, labels) = (inputs.into(), labels.into());
|
||||
|
||||
let outputs = self.step(inputs, labels)?;
|
||||
let loss = outputs[0].try_extract_scalar::<f32>()?;
|
||||
callback!(train_step(self, optimizer, args, state), loss);
|
||||
|
||||
if iter_step % args.gradient_accumulation_steps == 0 {
|
||||
optimizer.step()?;
|
||||
optimizer.reset_grad()?;
|
||||
state.global_step += 1;
|
||||
callback!(optimizer_step(self, optimizer, args, state), loss);
|
||||
}
|
||||
|
||||
if args.ckpt_strategy.should_fire(state.global_step, iter_step, args.loader.len()) {
|
||||
if !args.ckpt_path.exists() {
|
||||
let _ = fs::create_dir_all(&args.ckpt_path);
|
||||
}
|
||||
|
||||
let ckpt_path =
|
||||
args.ckpt_path
|
||||
.join(format!("epoch={},step={}.ortckpt", state.epoch.map(f32::trunc).unwrap_or(0.0) as usize, state.global_step));
|
||||
self.checkpoint().save(&ckpt_path, true)?;
|
||||
|
||||
saved_ckpts.push_front(ckpt_path.clone());
|
||||
while saved_ckpts.len() > args.max_saved_ckpts {
|
||||
let Some(old_ckpt) = saved_ckpts.pop_back() else {
|
||||
break;
|
||||
};
|
||||
let _ = fs::remove_file(old_ckpt);
|
||||
}
|
||||
}
|
||||
|
||||
if args
|
||||
.eval_strategy
|
||||
.should_fire(state.global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len()))
|
||||
{
|
||||
callback!(eval_begin(self, optimizer, args, state));
|
||||
let eval_loss = self.eval_inner(&mut args)?;
|
||||
callback!(eval_end(self, optimizer, args, state), eval_loss);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_halt(&self, cbs: &mut Vec<Box<dyn TrainerCallbacks>>, state: &TrainerState) -> Result<()> {
|
||||
for cb in cbs {
|
||||
let mut control = TrainerControl::new(self);
|
||||
cb.end(state, &mut control)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn eval_inner<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
|
||||
&self,
|
||||
args: &mut TrainingArguments<I, L, NI, NL>
|
||||
) -> crate::Result<f32> {
|
||||
let Some(eval_loader) = &mut args.eval_loader else {
|
||||
return Ok(0.0);
|
||||
};
|
||||
|
||||
let mut total_loss = 0.0;
|
||||
for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) {
|
||||
let (inputs, labels) = eval_loader.load(step)?;
|
||||
let (inputs, labels) = (inputs.into(), labels.into());
|
||||
|
||||
let outputs = self.eval_step(inputs, labels)?;
|
||||
let loss = outputs[0].try_extract_scalar::<f32>()?;
|
||||
total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.);
|
||||
}
|
||||
|
||||
Ok(total_loss)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user