From ee80b1c0c25c0c3b36e17264733c1bdd16301bca Mon Sep 17 00:00:00 2001
From: arxyzan
Date: Thu, 13 Jun 2024 13:46:23 +0330
Subject: [PATCH] :sparkles: Make LR scheduling more robust in the `Trainer`

---
 docs/tutorial/training/trainer.md | 10 ++++++----
 hezar/configs.py                  |  3 +++
 hezar/trainer/trainer.py          | 34 +++++++++++++++----------------
 hezar/trainer/trainer_utils.py    |  6 ++++++
 4 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/docs/tutorial/training/trainer.md b/docs/tutorial/training/trainer.md
index 72e5e5ca..9559a7e8 100644
--- a/docs/tutorial/training/trainer.md
+++ b/docs/tutorial/training/trainer.md
@@ -95,6 +95,7 @@ Let's explore all the available parameters:
 - **weight_decay** (float): Optimizer weight decay value.
 - **lr_scheduler** (LRSchedulerType): Optional learning rate scheduler among `LRSchedulerType` enum.
 - **lr_scheduler_kwargs** (Dict[str, Any]): LR scheduler instructor kwargs depending on the scheduler type
+- **lr_scheduling_steps** (int): Number of training steps between two scheduler steps. If left as None, defaults to the number of steps in one full epoch.
 - **batch_size** (int): Training batch size.
 - **eval_batch_size** (int): Evaluation batch size, defaults to `batch_size` if None.
 - **gradient_accumulation_steps** (int): Number of updates steps to accumulate before performing a backward/update pass,
@@ -242,7 +243,8 @@ the number of batches present in the data loader:
    does not happen here since it has its own method.
 3. Optimization step (`optimization_step`): Does the optimizer stepping and zeros gradients afterward. (Gradient
    accumulation is handled by the accelerator)
-4. Update loss tracker and the trainer states.
-5. Update and show the loss moving average in the progress bar.
-6. Perform saving and logging according to `save_steps` and `log_steps`.
-7. Return average loss up until now. (This value is accumulated and averaged since the beginning of the whole training
+4. LR scheduling: perform one LR scheduler step every `lr_scheduling_steps` steps.
+5. Update loss tracker and the trainer states.
+6. Update and show the loss moving average in the progress bar.
+7. Perform saving and logging according to `save_steps` and `log_steps`.
+8. Return average loss up until now. (This value is accumulated and averaged since the beginning of the whole training
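For readers trying out the new knob, here is a minimal sketch of a `TrainerConfig` using it. Only the three scheduler-related fields are touched by this patch; the `task` value, the `"step"` scheduler name, and the `StepLR` kwargs below are illustrative assumptions, not values taken from the patch:

```python
from hezar.configs import TrainerConfig

# Hypothetical values everywhere except the scheduler fields; the rest of
# the Trainer/model/dataset wiring is omitted for brevity.
config = TrainerConfig(
    task="text_classification",   # assumed task name, for illustration only
    batch_size=16,
    lr_scheduler="step",          # a str or `LRSchedulerType` member (assumed value)
    lr_scheduler_kwargs={"step_size": 1, "gamma": 0.9},  # forwarded to the scheduler
    lr_scheduling_steps=100,      # step the scheduler every 100 training steps
)
```

Leaving `lr_scheduling_steps` as `None` preserves the previous behavior: it falls back to the number of steps in one epoch, i.e. one scheduler step per epoch. `ReduceLROnPlateau` ignores this knob entirely and still steps once per epoch on `metric_for_best_model`.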
diff --git a/hezar/configs.py b/hezar/configs.py
index 923635e0..2c3a0073 100644
--- a/hezar/configs.py
+++ b/hezar/configs.py
@@ -422,6 +422,8 @@ class TrainerConfig(Config):
             Optional learning rate scheduler among `LRSchedulerType` enum.
         lr_scheduler_kwargs (Dict[str, Any]):
             LR scheduler instructor kwargs depending on the scheduler type
+        lr_scheduling_steps (int):
+            Number of training steps between two scheduler steps. If left as None, defaults to the number of steps in one full epoch.
         batch_size (int):
             Training batch size.
         eval_batch_size (int):
@@ -473,6 +475,7 @@ class TrainerConfig(Config):
     weight_decay: float = 0.0
     lr_scheduler: str | LRSchedulerType = None
     lr_scheduler_kwargs: Dict[str, Any] = None
+    lr_scheduling_steps: int = None
     batch_size: int = None
     eval_batch_size: int = None
     gradient_accumulation_steps: int = 1
diff --git a/hezar/trainer/trainer.py b/hezar/trainer/trainer.py
index 3fbe75ff..b6a98576 100644
--- a/hezar/trainer/trainer.py
+++ b/hezar/trainer/trainer.py
@@ -55,6 +55,7 @@
     get_distributed_logger,
     resolve_logdir,
     write_to_tensorboard,
+    get_lr_scheduler_type,
 )
 
 
@@ -72,6 +73,7 @@
 }
 lr_schedulers = {
     LRSchedulerType.LAMBDA: torch.optim.lr_scheduler.LambdaLR,
+    LRSchedulerType.REDUCE_ON_PLATEAU: torch.optim.lr_scheduler.ReduceLROnPlateau,
     LRSchedulerType.STEP: torch.optim.lr_scheduler.StepLR,
     LRSchedulerType.MULTI_STEP: torch.optim.lr_scheduler.MultiStepLR,
     LRSchedulerType.ONE_CYCLE: torch.optim.lr_scheduler.OneCycleLR,
@@ -186,7 +188,10 @@ def __init__(
 
         # Setup optimizer and (optionally) lr scheduler
         self.optimizer, self.lr_scheduler = self._create_optimizers(optimizer, lr_scheduler)
+        self.lr_scheduling_steps = self.config.lr_scheduling_steps or self.steps_in_epoch
+        self.lr_scheduler_type = get_lr_scheduler_type(self.lr_scheduler, lr_schedulers)
 
+        # Move objects to the accelerator
         self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
             self.model, self.optimizer, self.lr_scheduler
         )
@@ -464,21 +469,6 @@ def optimization_step(self):
         self.optimizer.step()
         self.optimizer.zero_grad()
 
-    def lr_scheduler_step(self, metrics=None):
-        """
-        Perform the learning rate scheduling step
-
-        Args:
-            metrics: one or multiple values that the scheduler watches to either perform step function or not. Only
-                works for `ReduceLROnPlateau`.
-        """
-        if self.lr_scheduler is not None:
-            if isinstance(self.lr_scheduler, lr_schedulers[LRSchedulerType.REDUCE_ON_PLATEAU]):
-                if metrics:
-                    self.lr_scheduler.step(metrics)
-                else:
-                    self.lr_scheduler.step()
-
     def training_step(self, input_batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
         """
         Train one batch of data and return loss and model outputs
@@ -582,6 +572,14 @@ def inner_training_loop(self, epoch_num: int):
                     self.state.loss_tracker_sum = self.train_loss_tracker.sum
                     accumulated_loss = 0
 
+                # Scheduler step (`ReduceLROnPlateau` is stepped per-epoch in `train()`)
+                if (
+                    self.lr_scheduler is not None and
+                    self.state.global_step % self.lr_scheduling_steps == 0 and
+                    self.lr_scheduler_type != LRSchedulerType.REDUCE_ON_PLATEAU
+                ):
+                    self.lr_scheduler.step()
+
                 # Save trainer outputs if `save_steps` is hit
                 if self.config.save_steps and self.state.global_step % self.config.save_steps == 0:
                     ckpt_path_name = str(self.state.global_step).zfill(len(str(self.total_steps)))
@@ -593,6 +591,7 @@ def inner_training_loop(self, epoch_num: int):
                             self.trainer_state_file,
                         )
                     )
+
                 # Log loss running mean
                 if self.config.log_steps and self.state.global_step % self.config.log_steps == 0:
                     loss_mean = {"train.loss": self.train_loss_tracker.avg}
@@ -733,8 +732,9 @@ def train(self, resume_from_checkpoint="deprecated"):
                 }
                 metrics_logs.update(evaluation_logs)
 
-            # LR scheduler step
-            self.lr_scheduler_step(metrics_logs[self.config.metric_for_best_model])
+            # LR scheduler step (only for `ReduceLROnPlateau`)
+            if self.lr_scheduler_type == LRSchedulerType.REDUCE_ON_PLATEAU:
+                self.lr_scheduler.step(metrics_logs[self.config.metric_for_best_model])
 
             # Update trainer state
             self.state.epoch = epoch
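The net effect of the trainer changes is a two-path policy: time-based schedulers step every `lr_scheduling_steps` global steps inside the batch loop, while `ReduceLROnPlateau` steps once per epoch on the tracked metric. A self-contained restatement against plain PyTorch (a sketch for illustration, not the Trainer's actual code):

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

def maybe_step_scheduler(scheduler, global_step, lr_scheduling_steps):
    # Plateau schedulers are excluded here; they step once per epoch with
    # scheduler.step(metric), mirroring the branch in Trainer.train().
    if scheduler is None or isinstance(scheduler, ReduceLROnPlateau):
        return
    if global_step % lr_scheduling_steps == 0:
        scheduler.step()

# Tiny demo: with lr_scheduling_steps=2, the scheduler fires on steps 2, 4, 6.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.5)
for global_step in range(1, 7):
    optimizer.step()
    maybe_step_scheduler(scheduler, global_step, lr_scheduling_steps=2)
print(optimizer.param_groups[0]["lr"])  # 0.1 * 0.5 ** 3 = 0.0125
```

Because `lr_scheduling_steps` defaults to the number of steps in one epoch, a config that never sets it keeps the old once-per-epoch behavior.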
diff --git a/hezar/trainer/trainer_utils.py b/hezar/trainer/trainer_utils.py
index 8cd9d759..071318f9 100644
--- a/hezar/trainer/trainer_utils.py
+++ b/hezar/trainer/trainer_utils.py
@@ -15,6 +15,7 @@
     "write_to_tensorboard",
     "resolve_logdir",
     "get_distributed_logger",
+    "get_lr_scheduler_type",
 ]
 
 
@@ -188,3 +189,8 @@ def get_distributed_logger(name: str, level: str = None, fmt: str = None):
         logger.logger.addHandler(handler)
 
     return logger
+
+def get_lr_scheduler_type(lr_scheduler, schedulers_mapping: dict):
+    for name, scheduler_cls in schedulers_mapping.items():
+        if isinstance(lr_scheduler, scheduler_cls):
+            return name
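A note on the helper: when `lr_scheduler` is `None` or its class is missing from the mapping, the loop falls through and the function implicitly returns `None`, so the `Trainer`'s equality checks never match and no scheduler step is attempted. An explicit version of the same logic (a sketch, not part of the patch) makes that contract visible:

```python
def get_lr_scheduler_type(lr_scheduler, schedulers_mapping: dict):
    """Return the first mapping key whose class matches `lr_scheduler`, else None."""
    # The first isinstance() match wins, so if one mapped scheduler class
    # subclasses another, dict insertion order decides which type is reported.
    for name, scheduler_cls in schedulers_mapping.items():
        if isinstance(lr_scheduler, scheduler_cls):
            return name
    return None  # no scheduler configured, or an unmapped scheduler class
```

Note also that the patched `__init__` computes `lr_scheduler_type` before `accelerator.prepare()` wraps the scheduler, which is what keeps these `isinstance` checks reliable once accelerate has wrapped the object.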