From 7af269f65f7093d442ae39d4ad035023e4fcee3d Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sun, 5 May 2024 14:11:11 +0330 Subject: [PATCH] :sparkles: Add support for training resuming from steps --- hezar/trainer/trainer.py | 58 +++++++++++++++++++++++----------- hezar/trainer/trainer_utils.py | 2 ++ 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/hezar/trainer/trainer.py b/hezar/trainer/trainer.py index c6396730..afc98c54 100644 --- a/hezar/trainer/trainer.py +++ b/hezar/trainer/trainer.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import os import random from typing import Any, Callable, Dict, Tuple @@ -153,6 +154,8 @@ def __init__( self.data_collator = data_collator or self.train_dataset.data_collator self.train_dataloader, self.eval_dataloader = self._setup_dataloaders() self.total_steps = len(self.train_dataloader) * self.config.num_epochs + self.config.save_steps = len(self.train_dataloader) if not self.config.save_steps else self.config.save_steps + self.save_steps_per_epoch = math.ceil(len(self.train_dataloader) / self.config.save_steps) # Setup optimizer and (optionally) lr scheduler self.optimizer, self.lr_scheduler = self._setup_optimizers(optimizer, lr_scheduler) @@ -188,11 +191,9 @@ def __init__( self.tensorboard = SummaryWriter(log_dir=self.logs_dir) self.csv_logger = CSVLogger(logs_dir=self.logs_dir, csv_filename=self.trainer_csv_log_file) - self.current_epoch = 1 - # Configure trainer state self.state = TrainerState( - epoch=self.current_epoch, + epoch=1, total_epochs=self.config.num_epochs, metric_for_best_checkpoint=self.config.metric_for_best_model, logs_dir=self.tensorboard.log_dir, @@ -311,17 +312,30 @@ def load_from_checkpoint(self, checkpoint: str | bool = True, load_best: bool = if os.path.isdir(checkpoint) and load_best: self.logger.warning("The `load_best` parameter has no effect when `checkpoint` is a path!") - self.state = TrainerState.load(os.path.join(self.checkpoints_dir, self.trainer_state_file)) + state_path = os.path.join(self.checkpoints_dir, self.trainer_state_file) + + if os.path.isfile(state_path): + self.state = TrainerState.load(state_path) + else: + return + if isinstance(checkpoint, bool): if load_best: checkpoint = os.path.join(self.checkpoints_dir, str(self.state.best_checkpoint)) else: - checkpoint = os.path.join(self.checkpoints_dir, str(self.state.epoch)) + # Get the most recent checkpoint based on global step + checkpoint = os.path.join( + self.checkpoints_dir, + str(self.state.global_step).zfill(len(str(self.total_steps))), + ) if os.path.isdir(checkpoint): - # Figure out the epoch number - epoch = os.path.basename(checkpoint) if os.path.basename(checkpoint).isdigit() else self.state.epoch - if str(epoch).isdigit(): - self.state.epoch = int(epoch) + # Figure out the step and epoch number + step = os.path.basename(checkpoint) + self.state.global_step = int(step) if step.isdigit() else self.state.global_step + self.state.epoch = math.ceil(self.state.global_step / len(self.train_dataloader)) - 1 + self.state.epoch_step = self.state.global_step % len(self.train_dataloader) + if self.state.epoch_step == 0: + self.state.epoch += 1 # Load model's state dict model_path = os.path.join(checkpoint, self.model.model_filename) if os.path.isfile(model_path): @@ -337,11 +351,8 @@ def load_from_checkpoint(self, checkpoint: str | bool = True, load_best: bool = f"{checkpoint} does not seem to be a valid checkpoint!" ) - self.current_epoch = self.state.epoch + 1 - if self.current_epoch >= self.config.num_epochs: - self.logger.warning( - f"The checkpoint at `{os.path.join(self.checkpoints_dir, str(self.state.epoch))}` belongs to the last epoch!" - ) + if self.state.global_step >= self.total_steps: + self.logger.warning(f"The checkpoint at `{checkpoint}` belongs to the last training step!") def load_csv_logs(self, logs_dir=None): """ @@ -494,10 +505,11 @@ def inner_training_loop(self, epoch_num: int): Returns: Metrics averages through the full iteration """ - losses_sum = 0.0 + losses_sum = 0 self.model.train() with tqdm( self.train_dataloader, + initial=self.state.epoch_step, unit="batch", desc=f"Epoch: {epoch_num}/{self.config.num_epochs} ", bar_format=TQDM_BAR_FORMAT, @@ -505,6 +517,11 @@ def inner_training_loop(self, epoch_num: int): disable=not self.accelerator.is_local_main_process, ) as iterator: for step, input_batch in enumerate(iterator): + # TODO make this more efficient in a way that the data loader skips batches without iterating them + # Skip the first batches to reach `epoch_step` + if step < self.state.epoch_step: + continue + # Prepare inputs input_batch = self.prepare_input_batch(input_batch) # Training on one batch with self.accelerator.accumulate(self.model): @@ -515,11 +532,13 @@ def inner_training_loop(self, epoch_num: int): losses_sum += outputs["loss"].item() avg_loss = losses_sum / (step + 1) iterator.set_postfix(loss=avg_loss) + # Update steps states self.state.global_step += 1 + self.state.epoch_step += 1 # Save trainer outputs if `save_steps` is hit if self.config.save_steps and self.state.global_step % self.config.save_steps == 0: - ckpt_path_name = f"step-{str(self.state.global_step).zfill(len(str(self.total_steps)))}" + ckpt_path_name = str(self.state.global_step).zfill(len(str(self.total_steps))) self.save(os.path.join(self.checkpoints_dir, ckpt_path_name)) # Save Trainer state self.state.save( @@ -572,7 +591,7 @@ def print_info(self): """ def _print_info_line(key, value): - line = f" {colorize_text(key, 'bold')}: `{colorize_text(str(value), 'italic')}`" + line = f" {colorize_text(key, 'bold')}: {colorize_text(str(value), 'italic')}" self.accelerator.print(line) header = f"{'*' * 20} Training Info {'*' * 20}" @@ -621,7 +640,7 @@ def train(self, resume_from_checkpoint: str | bool = None): self.print_info() - for epoch in range(self.current_epoch, self.config.num_epochs + 1): + for epoch in range(self.state.epoch, self.config.num_epochs + 1): self.accelerator.print() # Train on the whole training set @@ -630,7 +649,7 @@ def train(self, resume_from_checkpoint: str | bool = None): # Save checkpoint if self.accelerator.is_local_main_process: if self.config.save_enabled: - ckpt_path_name = f"step-{str(self.state.global_step).zfill(len(str(self.total_steps)))}" + ckpt_path_name = str(self.state.global_step).zfill(len(str(self.total_steps))) self.save(os.path.join(self.checkpoints_dir, ckpt_path_name)) # Evaluate the model on the evaluation set @@ -646,6 +665,7 @@ def train(self, resume_from_checkpoint: str | bool = None): # Update trainer state self.state.epoch = epoch + self.state.epoch_step = 0 self.state.update_best_results( metric_value=all_logs[self.config.metric_for_best_model], objective=self.metrics_handler.objective, diff --git a/hezar/trainer/trainer_utils.py b/hezar/trainer/trainer_utils.py index a0a84e91..0f1457b7 100644 --- a/hezar/trainer/trainer_utils.py +++ b/hezar/trainer/trainer_utils.py @@ -28,6 +28,7 @@ class TrainerState: epoch: Current epoch number total_epochs: Total epochs to train the model global_step: Number of the update steps so far, one step is a full training step (one batch) + epoch_step: Number of the update steps in the current epoch metric_for_best_checkpoint: The metric key for choosing the best checkpoint (Also given in the TrainerConfig) best_metric_value: The value of the best checkpoint saved so far best_checkpoint: Path to the best model checkpoint so far @@ -36,6 +37,7 @@ class TrainerState: epoch: int = 1 total_epochs: int = None global_step: int = 0 + epoch_step: int = 0 metric_for_best_checkpoint: str = None best_metric_value: float = None best_checkpoint: str = None