✨ Add support for training resuming from steps
arxyzan committed May 5, 2024
1 parent bd927ca commit 7af269f
Showing 2 changed files with 41 additions and 19 deletions.
58 changes: 39 additions & 19 deletions hezar/trainer/trainer.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import math
import os
import random
from typing import Any, Callable, Dict, Tuple
@@ -153,6 +154,8 @@ def __init__(
self.data_collator = data_collator or self.train_dataset.data_collator
self.train_dataloader, self.eval_dataloader = self._setup_dataloaders()
self.total_steps = len(self.train_dataloader) * self.config.num_epochs
self.config.save_steps = len(self.train_dataloader) if not self.config.save_steps else self.config.save_steps
self.save_steps_per_epoch = math.ceil(len(self.train_dataloader) / self.config.save_steps)

# Setup optimizer and (optionally) lr scheduler
self.optimizer, self.lr_scheduler = self._setup_optimizers(optimizer, lr_scheduler)
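A quick trace of the two added lines above, with assumed numbers (the dataloader length and the configured save_steps are illustrative, not taken from the repository):

    import math

    steps_per_epoch = 500   # assumed len(train_dataloader)

    # If save_steps is unset, it falls back to one save per epoch
    save_steps = None
    save_steps = steps_per_epoch if not save_steps else save_steps   # -> 500
    print(math.ceil(steps_per_epoch / save_steps))                   # save_steps_per_epoch -> 1

    # With an explicit save_steps of 125, checkpoints are written 4 times per epoch
    save_steps = 125
    print(math.ceil(steps_per_epoch / save_steps))                   # save_steps_per_epoch -> 4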
@@ -188,11 +191,9 @@ def __init__(
self.tensorboard = SummaryWriter(log_dir=self.logs_dir)
self.csv_logger = CSVLogger(logs_dir=self.logs_dir, csv_filename=self.trainer_csv_log_file)

self.current_epoch = 1

# Configure trainer state
self.state = TrainerState(
epoch=self.current_epoch,
epoch=1,
total_epochs=self.config.num_epochs,
metric_for_best_checkpoint=self.config.metric_for_best_model,
logs_dir=self.tensorboard.log_dir,
@@ -311,17 +312,30 @@ def load_from_checkpoint(self, checkpoint: str | bool = True, load_best: bool =
if os.path.isdir(checkpoint) and load_best:
self.logger.warning("The `load_best` parameter has no effect when `checkpoint` is a path!")

self.state = TrainerState.load(os.path.join(self.checkpoints_dir, self.trainer_state_file))
state_path = os.path.join(self.checkpoints_dir, self.trainer_state_file)

if os.path.isfile(state_path):
self.state = TrainerState.load(state_path)
else:
return

if isinstance(checkpoint, bool):
if load_best:
checkpoint = os.path.join(self.checkpoints_dir, str(self.state.best_checkpoint))
else:
checkpoint = os.path.join(self.checkpoints_dir, str(self.state.epoch))
# Get the most recent checkpoint based on global step
checkpoint = os.path.join(
self.checkpoints_dir,
str(self.state.global_step).zfill(len(str(self.total_steps))),
)
if os.path.isdir(checkpoint):
# Figure out the epoch number
epoch = os.path.basename(checkpoint) if os.path.basename(checkpoint).isdigit() else self.state.epoch
if str(epoch).isdigit():
self.state.epoch = int(epoch)
# Figure out the step and epoch number
step = os.path.basename(checkpoint)
self.state.global_step = int(step) if step.isdigit() else self.state.global_step
self.state.epoch = math.ceil(self.state.global_step / len(self.train_dataloader)) - 1
self.state.epoch_step = self.state.global_step % len(self.train_dataloader)
if self.state.epoch_step == 0:
self.state.epoch += 1
# Load model's state dict
model_path = os.path.join(checkpoint, self.model.model_filename)
if os.path.isfile(model_path):
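The step/epoch reconstruction above can be traced with assumed numbers (500 batches per epoch and a hypothetical checkpoint directory named 01250; neither value comes from the diff):

    import math

    steps_per_epoch = 500   # assumed len(train_dataloader)
    step = "01250"          # os.path.basename(checkpoint), the zero-padded global step

    global_step = int(step)                                  # -> 1250
    epoch = math.ceil(global_step / steps_per_epoch) - 1     # ceil(2.5) - 1 -> 2
    epoch_step = global_step % steps_per_epoch               # -> 250
    if epoch_step == 0:                                      # only hit on exact epoch boundaries
        epoch += 1

    print(epoch, epoch_step)   # 2 250 -> the inner training loop will skip the first 250 batches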
@@ -337,11 +351,8 @@ def load_from_checkpoint(self, checkpoint: str | bool = True, load_best: bool =
f"{checkpoint} does not seem to be a valid checkpoint!"
)

self.current_epoch = self.state.epoch + 1
if self.current_epoch >= self.config.num_epochs:
self.logger.warning(
f"The checkpoint at `{os.path.join(self.checkpoints_dir, str(self.state.epoch))}` belongs to the last epoch!"
)
if self.state.global_step >= self.total_steps:
self.logger.warning(f"The checkpoint at `{checkpoint}` belongs to the last training step!")

def load_csv_logs(self, logs_dir=None):
"""
@@ -494,17 +505,23 @@ def inner_training_loop(self, epoch_num: int):
Returns:
Metrics averages through the full iteration
"""
losses_sum = 0.0
losses_sum = 0
self.model.train()
with tqdm(
self.train_dataloader,
initial=self.state.epoch_step,
unit="batch",
desc=f"Epoch: {epoch_num}/{self.config.num_epochs} ",
bar_format=TQDM_BAR_FORMAT,
ascii=" #",
disable=not self.accelerator.is_local_main_process,
) as iterator:
for step, input_batch in enumerate(iterator):
# TODO make this more efficient in a way that the data loader skips batches without iterating them
# Skip the first batches to reach `epoch_step`
if step < self.state.epoch_step:
continue
# Prepare inputs
input_batch = self.prepare_input_batch(input_batch)
# Training on one batch
with self.accelerator.accumulate(self.model):
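One possible direction for the TODO above (not part of this commit): recent versions of Hugging Face accelerate ship a skip_first_batches helper that, for map-style datasets, skips at the sampler level rather than iterating over batches and discarding them. A minimal sketch, assuming that helper is available and reusing the names from the loop above; it is not a drop-in replacement (tqdm and gradient accumulation are omitted):

    from accelerate import skip_first_batches

    # Wrap the prepared dataloader so the first `epoch_step` batches are never yielded,
    # then keep `step` aligned with the original enumeration by offsetting `start`.
    resumed_dataloader = skip_first_batches(self.train_dataloader, num_batches=self.state.epoch_step)
    for step, input_batch in enumerate(resumed_dataloader, start=self.state.epoch_step):
        input_batch = self.prepare_input_batch(input_batch)
        ...  # training step as in inner_training_loop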
@@ -515,11 +532,13 @@ def inner_training_loop(self, epoch_num: int):
losses_sum += outputs["loss"].item()
avg_loss = losses_sum / (step + 1)
iterator.set_postfix(loss=avg_loss)
# Update steps states
self.state.global_step += 1
self.state.epoch_step += 1

# Save trainer outputs if `save_steps` is hit
if self.config.save_steps and self.state.global_step % self.config.save_steps == 0:
ckpt_path_name = f"step-{str(self.state.global_step).zfill(len(str(self.total_steps)))}"
ckpt_path_name = str(self.state.global_step).zfill(len(str(self.total_steps)))
self.save(os.path.join(self.checkpoints_dir, ckpt_path_name))
# Save Trainer state
self.state.save(
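The checkpoint directory name (ckpt_path_name above) is now just the global step, zero-padded to the width of total_steps so checkpoint directories sort in training order. With assumed numbers (500 batches per epoch, 20 epochs):

    steps_per_epoch = 500
    total_steps = steps_per_epoch * 20   # 10000, mirroring len(train_dataloader) * num_epochs
    global_step = 1250

    ckpt_path_name = str(global_step).zfill(len(str(total_steps)))
    print(ckpt_path_name)   # "01250" -> saved under <checkpoints_dir>/01250

    # Earlier and later saves pad consistently: 500 -> "00500", 10000 -> "10000",
    # so a plain lexicographic sort of the checkpoint directories follows training order.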
@@ -572,7 +591,7 @@ def print_info(self):
"""

def _print_info_line(key, value):
line = f" {colorize_text(key, 'bold')}: `{colorize_text(str(value), 'italic')}`"
line = f" {colorize_text(key, 'bold')}: {colorize_text(str(value), 'italic')}"
self.accelerator.print(line)

header = f"{'*' * 20} Training Info {'*' * 20}"
@@ -621,7 +640,7 @@ def train(self, resume_from_checkpoint: str | bool = None):

self.print_info()

for epoch in range(self.current_epoch, self.config.num_epochs + 1):
for epoch in range(self.state.epoch, self.config.num_epochs + 1):
self.accelerator.print()

# Train on the whole training set
@@ -630,7 +649,7 @@ def train(self, resume_from_checkpoint: str | bool = None):
# Save checkpoint
if self.accelerator.is_local_main_process:
if self.config.save_enabled:
ckpt_path_name = f"step-{str(self.state.global_step).zfill(len(str(self.total_steps)))}"
ckpt_path_name = str(self.state.global_step).zfill(len(str(self.total_steps)))
self.save(os.path.join(self.checkpoints_dir, ckpt_path_name))

# Evaluate the model on the evaluation set
@@ -646,6 +665,7 @@ def train(self, resume_from_checkpoint: str | bool = None):

# Update trainer state
self.state.epoch = epoch
self.state.epoch_step = 0
self.state.update_best_results(
metric_value=all_logs[self.config.metric_for_best_model],
objective=self.metrics_handler.objective,
2 changes: 2 additions & 0 deletions hezar/trainer/trainer_utils.py
@@ -28,6 +28,7 @@ class TrainerState:
epoch: Current epoch number
total_epochs: Total epochs to train the model
global_step: Number of update steps taken so far; one step is a full training step on one batch
epoch_step: Number of the update steps in the current epoch
metric_for_best_checkpoint: The metric key for choosing the best checkpoint (Also given in the TrainerConfig)
best_metric_value: The value of the best checkpoint saved so far
best_checkpoint: Path to the best model checkpoint so far
@@ -36,6 +37,7 @@
epoch: int = 1
total_epochs: int = None
global_step: int = 0
epoch_step: int = 0
metric_for_best_checkpoint: str = None
best_metric_value: float = None
best_checkpoint: str = None
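The resume-related fields live on TrainerState, which (judging by the field defaults above) appears to be a plain dataclass. A minimal sketch of the fields this commit relies on, with assumed values; the save(...) call is truncated in the diff, so only the documented load path is referenced:

    from hezar.trainer.trainer_utils import TrainerState

    # Assumed values: resuming mid-epoch at global step 1250 with 500 batches per epoch
    state = TrainerState(epoch=2, total_epochs=20, global_step=1250, epoch_step=250)

    # load_from_checkpoint above restores a saved state the same way:
    # state = TrainerState.load(os.path.join(checkpoints_dir, trainer_state_file))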
