feat(training): simple trainer callbacks

Carson M.
2024-08-03 13:33:44 -05:00
parent 227bc8529b
commit 733b7fa329
8 changed files with 478 additions and 242 deletions


@@ -4,14 +4,36 @@ use std::{
path::Path
};
use kdam::BarExt;
use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis};
use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments};
use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainerCallbacks, TrainingArguments};
use rand::RngCore;
use tokenizers::Tokenizer;
const BATCH_SIZE: usize = 16;
const SEQUENCE_LENGTH: usize = 256;
struct LoggerCallback {
progress_bar: kdam::Bar
}
impl LoggerCallback {
pub fn new() -> Self {
Self {
progress_bar: kdam::Bar::builder().leave(true).build().unwrap()
}
}
}
impl TrainerCallbacks for LoggerCallback {
fn train_step(&mut self, train_loss: f32, state: &ort::TrainerState, _: &mut ort::TrainerControl<'_>) -> ort::Result<()> {
self.progress_bar.total = state.max_steps;
self.progress_bar.set_postfix(format!("loss={train_loss:.3}"));
let _ = self.progress_bar.update_to(state.iter_step);
Ok(())
}
}
fn main() -> ort::Result<()> {
tracing_subscriber::fmt::init();
@@ -78,6 +100,7 @@ fn main() -> ort::Result<()> {
.with_lr(7e-5)
.with_max_steps(5000)
.with_ckpt_strategy(CheckpointStrategy::Steps(500))
.with_callbacks(LoggerCallback::new())
)?;
trainer.export("trained-clm.onnx", ["probs"])?;


@@ -27,6 +27,17 @@ unsafe impl Sync for EnvironmentSingleton {}
static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::new(None) };
/// An `Environment` is a process-global structure, under which [`Session`](crate::Session)s are created.
///
/// Environments can be used to [configure global thread pools](EnvironmentBuilder::with_global_thread_pool), in
/// which all sessions share threads from the environment's pool, and to configure [default execution
/// providers](EnvironmentBuilder::with_execution_providers) for all sessions. In the context of `ort` specifically,
/// environments are also used to configure ONNX Runtime to send log messages through the [`tracing`] crate.
///
/// For ease of use, and since sessions require an environment to be created, `ort` will automatically create an
/// environment if one is not configured via [`init`] (or [`init_from`]). [`init`] can be called at any point in the
/// program (even after an environment has been automatically created), though every session created before the
/// re-configuration would need to be re-created in order to use the config from the new environment.
#[derive(Debug)]
pub struct Environment {
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,

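As an aside, a minimal sketch of the [`init`] flow the doc comment above describes, assuming the `ort::init()` builder with `with_execution_providers`/`commit` and the `CUDAExecutionProvider` already used in the example file; the provider choice is illustrative only:

fn main() -> ort::Result<()> {
	// Configure the process-global environment up front; sessions created before a later
	// re-configuration keep the environment they were created under.
	ort::init()
		.with_execution_providers([ort::CUDAExecutionProvider::default().build()])
		.commit()?;
	Ok(())
}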

@@ -13,7 +13,10 @@ mod simple;
mod trainer;
pub use self::{
simple::{iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainingArguments},
simple::{
iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainerCallbacks, TrainerControl, TrainerState,
TrainingArguments
},
trainer::Trainer
};


@@ -1,240 +0,0 @@
use std::{collections::VecDeque, fs, path::PathBuf};
use crate::{Result, SessionInputs};
#[allow(clippy::len_without_is_empty)]
pub trait DataLoader<I, L> {
fn load(&mut self, idx: usize) -> Result<(I, L)>;
fn len(&self) -> Option<usize> {
None
}
}
pub struct IterableDataLoader<T, I, L, C: Fn(&T) -> Result<(I, L)>> {
items: Box<[T]>,
collator: C
}
impl<T, I, L, C: Fn(&T) -> Result<(I, L)>> DataLoader<I, L> for IterableDataLoader<T, I, L, C> {
fn load(&mut self, idx: usize) -> Result<(I, L)> {
(self.collator)(&self.items[idx])
}
fn len(&self) -> Option<usize> {
Some(self.items.len())
}
}
pub fn iterable_data_loader<T, I, L, C: Fn(&T) -> Result<(I, L)>>(iterable: impl Iterator<Item = T>, collator: C) -> IterableDataLoader<T, I, L, C> {
IterableDataLoader {
items: iterable.collect::<Vec<T>>().into_boxed_slice(),
collator
}
}
impl<I, L, F: FnMut(usize) -> Result<(I, L)>> DataLoader<I, L> for F {
fn load(&mut self, idx: usize) -> Result<(I, L)> {
(self)(idx)
}
fn len(&self) -> Option<usize> {
None
}
}
pub enum EvaluationStrategy {
None,
Steps(usize),
Epochs(usize)
}
impl EvaluationStrategy {
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
match self {
Self::None => false,
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
Self::Epochs(epochs) => {
if let Some(dataloader_size) = dataloader_size {
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
} else {
false
}
}
}
}
}
pub enum CheckpointStrategy {
None,
Steps(usize),
Epochs(usize)
}
impl CheckpointStrategy {
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
match self {
Self::None => false,
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
Self::Epochs(epochs) => {
if let Some(dataloader_size) = dataloader_size {
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
} else {
false
}
}
}
}
}
pub struct TrainingArguments<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize> {
loader: Box<dyn DataLoader<I, L>>,
eval_loader: Option<Box<dyn DataLoader<I, L>>>,
eval_strategy: EvaluationStrategy,
ckpt_strategy: CheckpointStrategy,
ckpt_path: PathBuf,
lr: f32,
max_saved_ckpts: usize,
gradient_accumulation_steps: usize,
max_steps: usize,
max_eval_steps: usize
}
impl<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>
TrainingArguments<I, L, NI, NL>
{
pub fn new<D: DataLoader<I, L> + 'static>(train_loader: D) -> Self {
Self {
loader: Box::new(train_loader),
eval_loader: None,
eval_strategy: EvaluationStrategy::None,
ckpt_strategy: CheckpointStrategy::Epochs(1),
ckpt_path: PathBuf::from("checkpoints"),
lr: 1e-4,
gradient_accumulation_steps: 1,
max_saved_ckpts: 1,
max_steps: usize::MAX,
max_eval_steps: usize::MAX
}
}
pub fn with_lr(mut self, lr: f32) -> Self {
self.lr = lr;
self
}
pub fn with_max_steps(mut self, steps: usize) -> Self {
self.max_steps = steps;
self
}
pub fn with_max_eval_steps(mut self, steps: usize) -> Self {
self.max_eval_steps = steps;
self
}
pub fn with_gradient_accumulation(mut self, steps: usize) -> Self {
self.gradient_accumulation_steps = steps;
self
}
pub fn with_ckpt_path(mut self, path: impl Into<PathBuf>) -> Self {
self.ckpt_path = path.into();
self
}
pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self {
self.ckpt_strategy = strategy;
self
}
pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self {
self.max_saved_ckpts = max_ckpts;
self
}
pub fn with_eval_loader<D: DataLoader<I, L> + 'static>(mut self, eval_loader: D) -> Self {
self.eval_loader = Some(Box::new(eval_loader));
self
}
pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self {
self.eval_strategy = strategy;
self
}
}
impl super::Trainer {
pub fn train<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
&self,
mut args: TrainingArguments<I, L, NI, NL>
) -> crate::Result<()> {
let optimizer = self.optimizer();
optimizer.set_lr(args.lr)?;
let mut saved_ckpts = VecDeque::new();
let mut global_step = 0;
for (iter_step, _) in (0..args.max_steps).enumerate() {
let epoch = iter_step / args.loader.len().unwrap_or(usize::MAX);
let (inputs, labels) = args.loader.load(iter_step)?;
let (inputs, labels) = (inputs.into(), labels.into());
let outputs = self.step(inputs, labels)?;
let loss = outputs[0].try_extract_scalar::<f32>()?;
println!("epoch={epoch} step={global_step} loss={loss}");
if iter_step % args.gradient_accumulation_steps == 0 {
optimizer.step()?;
optimizer.reset_grad()?;
global_step += 1;
}
if args.ckpt_strategy.should_fire(global_step, iter_step, args.loader.len()) {
if !args.ckpt_path.exists() {
let _ = fs::create_dir_all(&args.ckpt_path);
}
let ckpt_path = args.ckpt_path.join(format!("epoch={epoch},step={global_step}.ortckpt"));
self.checkpoint().save(&ckpt_path, true)?;
saved_ckpts.push_front(ckpt_path.clone());
while saved_ckpts.len() > args.max_saved_ckpts {
let Some(old_ckpt) = saved_ckpts.pop_back() else {
break;
};
let _ = fs::remove_file(old_ckpt);
}
}
if args
.eval_strategy
.should_fire(global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len()))
{
let eval_loss = self.eval_inner(&mut args)?;
println!("eval_loss={eval_loss}");
}
}
Ok(())
}
pub(crate) fn eval_inner<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
&self,
args: &mut TrainingArguments<I, L, NI, NL>
) -> crate::Result<f32> {
let Some(eval_loader) = &mut args.eval_loader else {
return Ok(0.0);
};
let mut total_loss = 0.0;
for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) {
let (inputs, labels) = eval_loader.load(step)?;
let (inputs, labels) = (inputs.into(), labels.into());
let outputs = self.eval_step(inputs, labels)?;
let loss = outputs[0].try_extract_scalar::<f32>()?;
total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.);
}
Ok(total_loss)
}
}

src/training/simple/args.rs (new file, 137 lines)

@@ -0,0 +1,137 @@
use std::path::PathBuf;
use super::{DataLoader, TrainerCallbacks};
use crate::session::input::SessionInputs;
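/// Determines how often the trainer runs evaluation over the evaluation data loader.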
pub enum EvaluationStrategy {
None,
Steps(usize),
Epochs(usize)
}
impl EvaluationStrategy {
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
match self {
Self::None => false,
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
Self::Epochs(epochs) => {
if let Some(dataloader_size) = dataloader_size {
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
} else {
false
}
}
}
}
}
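/// Determines how often the trainer saves a checkpoint to the configured checkpoint path.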
pub enum CheckpointStrategy {
None,
Steps(usize),
Epochs(usize)
}
impl CheckpointStrategy {
pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option<usize>) -> bool {
match self {
Self::None => false,
Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0,
Self::Epochs(epochs) => {
if let Some(dataloader_size) = dataloader_size {
iter_step > 0 && iter_step % (dataloader_size * epochs) == 0
} else {
false
}
}
}
}
}
pub struct TrainingArguments<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize> {
pub(crate) loader: Box<dyn DataLoader<I, L>>,
pub(crate) eval_loader: Option<Box<dyn DataLoader<I, L>>>,
pub(crate) eval_strategy: EvaluationStrategy,
pub(crate) ckpt_strategy: CheckpointStrategy,
pub(crate) ckpt_path: PathBuf,
pub(crate) lr: f32,
pub(crate) max_saved_ckpts: usize,
pub(crate) gradient_accumulation_steps: usize,
pub(crate) max_steps: usize,
pub(crate) max_eval_steps: usize,
pub(crate) callbacks: Vec<Box<dyn TrainerCallbacks>>
}
impl<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>
TrainingArguments<I, L, NI, NL>
{
pub fn new<D: DataLoader<I, L> + 'static>(train_loader: D) -> Self {
Self {
loader: Box::new(train_loader),
eval_loader: None,
eval_strategy: EvaluationStrategy::None,
ckpt_strategy: CheckpointStrategy::Epochs(1),
ckpt_path: PathBuf::from("checkpoints"),
lr: 1e-4,
gradient_accumulation_steps: 1,
max_saved_ckpts: 1,
max_steps: usize::MAX,
max_eval_steps: usize::MAX,
callbacks: Vec::new()
}
}
pub fn with_lr(mut self, lr: f32) -> Self {
self.lr = lr;
self
}
pub fn with_max_steps(mut self, steps: usize) -> Self {
self.max_steps = steps;
self
}
pub fn with_epochs(mut self, epochs: f32) -> Self {
self.max_steps = self.loader.len().map(|l| (l as f32 * epochs).trunc() as usize).unwrap_or(usize::MAX);
self
}
pub fn with_max_eval_steps(mut self, steps: usize) -> Self {
self.max_eval_steps = steps;
self
}
pub fn with_gradient_accumulation(mut self, steps: usize) -> Self {
self.gradient_accumulation_steps = steps.max(1);
self
}
pub fn with_ckpt_path(mut self, path: impl Into<PathBuf>) -> Self {
self.ckpt_path = path.into();
self
}
pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self {
self.ckpt_strategy = strategy;
self
}
pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self {
self.max_saved_ckpts = max_ckpts;
self
}
pub fn with_eval_loader<D: DataLoader<I, L> + 'static>(mut self, eval_loader: D) -> Self {
self.eval_loader = Some(Box::new(eval_loader));
self
}
pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self {
self.eval_strategy = strategy;
self
}
pub fn with_callbacks(mut self, callbacks: impl TrainerCallbacks + 'static) -> Self {
self.callbacks.push(Box::new(callbacks));
self
}
}
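For illustration, a hedged sketch of how these builder methods compose; `dataset` and `collate` are assumed placeholders for user-provided items and a collator returning `(inputs, labels)`:

// Build a sized loader from in-memory items, then configure training arguments.
let loader = iterable_data_loader(dataset.into_iter(), collate);
let args = TrainingArguments::new(loader)
	.with_lr(1e-4)
	.with_epochs(3.0) // relies on the loader reporting a length
	.with_gradient_accumulation(4)
	.with_eval_strategy(EvaluationStrategy::Steps(500))
	.with_ckpt_strategy(CheckpointStrategy::Steps(1000))
	.with_max_saved_ckpts(2)
	.with_callbacks(LoggerCallback::new()); // e.g. the `LoggerCallback` from the example change above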

src/training/simple/callbacks.rs (new file, 111 lines)

@@ -0,0 +1,111 @@
use std::path::Path;
use super::TrainingArguments;
use crate::{
error::Result,
session::input::SessionInputs,
training::{Checkpoint, Optimizer, Trainer}
};
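/// A snapshot of training progress, passed by reference to each [`TrainerCallbacks`] hook.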
#[derive(Clone)]
#[non_exhaustive]
pub struct TrainerState {
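/// The current epoch, expressed fractionally (e.g. `1.5` means halfway through the second pass over the training
/// data); `None` if the data loader does not report a length.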
pub epoch: Option<f32>,
/// The total number of weight updates performed on the model.
pub global_step: usize,
/// The total number of training batches the model has seen.
pub iter_step: usize,
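/// The number of training batches accumulated before each optimizer step.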
pub gradient_accumulation_steps: usize,
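/// The maximum number of training batches that will be processed, as configured in [`TrainingArguments`].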
pub max_steps: usize,
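/// The optimizer's current learning rate.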
pub current_lr: f32
}
impl TrainerState {
pub(crate) fn new<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
args: &TrainingArguments<I, L, NI, NL>
) -> Self {
Self {
epoch: None,
global_step: 0,
iter_step: 0,
gradient_accumulation_steps: args.gradient_accumulation_steps,
max_steps: args.max_steps,
current_lr: args.lr
}
}
}
/// Allows callbacks in [`TrainerCallbacks`] to control the training of the model. This includes halting training,
/// updating the learning rate, or exporting the model.
pub struct TrainerControl<'t> {
pub(crate) halt: bool,
pub(crate) lr: Option<f32>,
trainer: &'t Trainer
}
impl<'t> TrainerControl<'t> {
pub(crate) fn new(trainer: &'t Trainer) -> Self {
Self { halt: false, trainer, lr: None }
}
/// Halts training. Once all callbacks have been called, training will immediately end.
///
/// Halting training will fire [`TrainerCallbacks::end`].
pub fn halt(&mut self) {
self.halt = true;
}
/// Sets the optimizer's learning rate.
pub fn set_lr(&mut self, lr: f32) {
self.lr = Some(lr);
}
/// Export the model as a complete ONNX graph.
pub fn export<O: AsRef<str>>(&self, out_path: impl AsRef<Path>, output_names: impl AsRef<[O]>) -> Result<()> {
self.trainer.export(out_path, output_names)
}
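/// Returns a reference to the trainer's [`Optimizer`].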
pub fn optimizer(&self) -> &Optimizer {
self.trainer.optimizer()
}
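/// Returns a reference to the trainer's [`Checkpoint`].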
pub fn checkpoint(&self) -> &Checkpoint {
self.trainer.checkpoint()
}
}
#[allow(unused_variables)]
pub trait TrainerCallbacks: Send {
/// Called at the beginning of a new epoch.
fn epoch(&mut self, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
/// Called when evaluation is about to begin.
fn eval_begin(&mut self, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
/// Called when evaluation has ended. The `eval_loss` is the average loss over all batches in the evaluation
/// dataset.
fn eval_end(&mut self, eval_loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
/// Called immediately after performing a single forward & backward pass. See also
/// [`TrainerCallbacks::optimizer_step`], which is called immediately after the optimizer updates the model weights.
///
/// When [`TrainingArguments::with_gradient_accumulation`] is configured to more than 1 step, this fires once per
/// accumulation step before each [`TrainerCallbacks::optimizer_step`].
fn train_step(&mut self, train_loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
/// Called immediately after updating the model weights. `loss` is the loss of the last batch.
fn optimizer_step(&mut self, loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
/// Called when training ends, either via [`TrainerControl::halt`], or if the maximum steps have been reached.
fn end(&mut self, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
}
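To illustrate the hooks above, a hedged sketch of an early-stopping callback; the `EarlyStop` name and the no-patience threshold logic are arbitrary:

struct EarlyStop {
	best_loss: f32
}

impl TrainerCallbacks for EarlyStop {
	fn eval_end(&mut self, eval_loss: f32, _state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
		// Halt training as soon as the evaluation loss stops improving.
		if eval_loss >= self.best_loss {
			control.halt();
		} else {
			self.best_loss = eval_loss;
		}
		Ok(())
	}
}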

src/training/simple/dataloader.rs (new file, 52 lines)

@@ -0,0 +1,52 @@
use crate::error::Result;
#[allow(clippy::len_without_is_empty)]
pub trait DataLoader<I, L> {
/// Synchronously loads the batch at index `idx`.
fn load(&mut self, idx: usize) -> Result<(I, L)>;
/// The total number of batches in this data loader. The default implementation returns `None`, which indicates the
/// data loader is 'infinite'.
///
/// If `len` returns `None` (i.e. the loader is 'infinite'), epoch-based configuration options cannot be used.
fn len(&self) -> Option<usize> {
None
}
}
/// A definitively-sized [`DataLoader`] created from any type that implements [`Iterator`].
///
/// To create an iterable data loader, use [`iterable_data_loader`].
pub struct IterableDataLoader<T, I, L, C: Fn(&T) -> Result<(I, L)>> {
items: Box<[T]>,
collator: C
}
impl<T, I, L, C: Fn(&T) -> Result<(I, L)>> DataLoader<I, L> for IterableDataLoader<T, I, L, C> {
fn load(&mut self, idx: usize) -> Result<(I, L)> {
(self.collator)(&self.items[idx])
}
fn len(&self) -> Option<usize> {
Some(self.items.len())
}
}
/// Creates a definitively-sized [`DataLoader`] from an [`Iterator`] and a corresponding collator function.
pub fn iterable_data_loader<T, I, L, C: Fn(&T) -> Result<(I, L)>>(iterable: impl Iterator<Item = T>, collator: C) -> IterableDataLoader<T, I, L, C> {
IterableDataLoader {
items: iterable.collect::<Vec<T>>().into_boxed_slice(),
collator
}
}
impl<I, L, F: FnMut(usize) -> Result<(I, L)>> DataLoader<I, L> for F {
fn load(&mut self, idx: usize) -> Result<(I, L)> {
(self)(idx)
}
fn len(&self) -> Option<usize> {
None
}
}
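For illustration, a hedged sketch of both loader forms defined above; `samples`, `collate`, and `generate` are assumed placeholders:

// A sized loader over in-memory items; the collator turns one item into (inputs, labels).
let n = samples.len();
let mut loader = iterable_data_loader(samples.into_iter(), |item| collate(item));
assert_eq!(loader.len(), Some(n));
let (inputs, labels) = loader.load(0)?;

// Any `FnMut(usize) -> Result<(I, L)>` is also a `DataLoader`, but it reports no length,
// so epoch-based configuration options are unavailable for it.
let streaming = |step: usize| collate(&generate(step));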

src/training/simple/mod.rs (new file, 139 lines)

@@ -0,0 +1,139 @@
use std::{collections::VecDeque, fs};
use crate::{error::Result, session::input::SessionInputs, training::Trainer};
mod dataloader;
pub use self::dataloader::{iterable_data_loader, DataLoader, IterableDataLoader};
mod args;
pub use self::args::{CheckpointStrategy, EvaluationStrategy, TrainingArguments};
mod callbacks;
pub use self::callbacks::{TrainerCallbacks, TrainerControl, TrainerState};
macro_rules! callback {
($which:ident($self:expr, $optimizer:expr, $args:expr, $state:expr)) => {
let mut halt = false;
for cb in &mut $args.callbacks {
let mut control = TrainerControl::new($self);
cb.$which(&$state, &mut control)?;
halt = halt || control.halt;
if let Some(lr) = control.lr {
$optimizer.set_lr(lr)?;
}
}
if halt {
return $self.handle_halt(&mut $args.callbacks, &$state);
}
};
($which:ident($self:expr, $optimizer:expr, $args:expr, $state:expr), $($addt:expr),*) => {
let mut halt = false;
for cb in &mut $args.callbacks {
let mut control = TrainerControl::new($self);
cb.$which($($addt,)* &$state, &mut control)?;
halt = halt || control.halt;
if let Some(lr) = control.lr {
$optimizer.set_lr(lr)?;
}
}
if halt {
return $self.handle_halt(&mut $args.callbacks, &$state);
}
};
}
impl Trainer {
pub fn train<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
&self,
mut args: TrainingArguments<I, L, NI, NL>
) -> Result<()> {
let optimizer = self.optimizer();
optimizer.set_lr(args.lr)?;
let mut saved_ckpts = VecDeque::new();
let mut state = TrainerState::new(&args);
let mut last_epoch = -1.0;
for iter_step in 0..args.max_steps {
state.iter_step = iter_step;
state.epoch = args.loader.len().map(|dl_len| iter_step as f32 / dl_len as f32);
if let Some(epoch) = state.epoch {
if epoch.trunc() != last_epoch {
callback!(epoch(self, optimizer, args, state));
}
last_epoch = epoch.trunc();
}
let (inputs, labels) = args.loader.load(iter_step)?;
let (inputs, labels) = (inputs.into(), labels.into());
let outputs = self.step(inputs, labels)?;
let loss = outputs[0].try_extract_scalar::<f32>()?;
callback!(train_step(self, optimizer, args, state), loss);
if iter_step % args.gradient_accumulation_steps == 0 {
optimizer.step()?;
optimizer.reset_grad()?;
state.global_step += 1;
callback!(optimizer_step(self, optimizer, args, state), loss);
}
if args.ckpt_strategy.should_fire(state.global_step, iter_step, args.loader.len()) {
if !args.ckpt_path.exists() {
let _ = fs::create_dir_all(&args.ckpt_path);
}
let ckpt_path =
args.ckpt_path
.join(format!("epoch={},step={}.ortckpt", state.epoch.map(f32::trunc).unwrap_or(0.0) as usize, state.global_step));
self.checkpoint().save(&ckpt_path, true)?;
saved_ckpts.push_front(ckpt_path.clone());
while saved_ckpts.len() > args.max_saved_ckpts {
let Some(old_ckpt) = saved_ckpts.pop_back() else {
break;
};
let _ = fs::remove_file(old_ckpt);
}
}
if args
.eval_strategy
.should_fire(state.global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len()))
{
callback!(eval_begin(self, optimizer, args, state));
let eval_loss = self.eval_inner(&mut args)?;
callback!(eval_end(self, optimizer, args, state), eval_loss);
}
}
Ok(())
}
fn handle_halt(&self, cbs: &mut Vec<Box<dyn TrainerCallbacks>>, state: &TrainerState) -> Result<()> {
for cb in cbs {
let mut control = TrainerControl::new(self);
cb.end(state, &mut control)?;
}
Ok(())
}
pub(crate) fn eval_inner<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
&self,
args: &mut TrainingArguments<I, L, NI, NL>
) -> crate::Result<f32> {
let Some(eval_loader) = &mut args.eval_loader else {
return Ok(0.0);
};
let mut total_loss = 0.0;
for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) {
let (inputs, labels) = eval_loader.load(step)?;
let (inputs, labels) = (inputs.into(), labels.into());
let outputs = self.eval_step(inputs, labels)?;
let loss = outputs[0].try_extract_scalar::<f32>()?;
total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.);
}
Ok(total_loss)
}
}
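Finally, since the `callback!` macro above feeds `control.lr` back into the optimizer, a callback can also drive a learning-rate schedule. A hedged sketch; the `LinearDecay` name and the decay formula are illustrative:

struct LinearDecay {
	initial_lr: f32
}

impl TrainerCallbacks for LinearDecay {
	fn optimizer_step(&mut self, _loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
		// Decay the learning rate linearly from `initial_lr` towards 0 over the configured steps.
		let progress = (state.iter_step as f32 / state.max_steps as f32).min(1.0);
		control.set_lr(self.initial_lr * (1.0 - progress));
		Ok(())
	}
}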