From 6b5c965cf8787c7158831106952d605017b98065 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Frano=20Raji=C4=8D?=
Date: Thu, 15 Aug 2024 10:31:14 +0200
Subject: [PATCH] fix bug in resuming-from-checkpoint not working with mixed
 precision training

---
 train.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index c2b354f..1384655 100644
--- a/train.py
+++ b/train.py
@@ -326,6 +326,7 @@ def seed_worker(worker_id):
         train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
         print("LEN TRAIN LOADER", len(train_loader))
         optimizer, scheduler = fetch_optimizer(args, model)
+        scaler = GradScaler(enabled=args.mixed_precision)
 
         total_steps = 0
         if self.global_rank == 0:
@@ -350,6 +351,9 @@ def seed_worker(worker_id):
             if "scheduler" in ckpt:
                 logging.info("Load scheduler")
                 scheduler.load_state_dict(ckpt["scheduler"])
+            if "scaler" in ckpt:
+                logging.info("Load gradient scaler")
+                scaler.load_state_dict(ckpt["scaler"])
             if "total_steps" in ckpt:
                 total_steps = ckpt["total_steps"]
                 logging.info(f"Load total_steps {total_steps}")
@@ -373,7 +377,6 @@ def seed_worker(worker_id):
         model.train()
 
         save_freq = args.save_freq
-        scaler = GradScaler(enabled=args.mixed_precision)
 
         should_keep_training = True
         global_batch_num = 0
@@ -456,6 +459,7 @@ def seed_worker(worker_id):
                     "model": model.module.module.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "scheduler": scheduler.state_dict(),
+                    "scaler": scaler.state_dict(),
                     "total_steps": total_steps,
                 }
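
Note: the pattern this patch applies is to create the GradScaler before any checkpoint is restored, load its saved state when resuming, and write its state into every checkpoint. Below is a minimal sketch of that pattern using the standard torch.cuda.amp.GradScaler API; the helper names save_checkpoint/load_checkpoint and the model/optimizer/scheduler arguments are illustrative placeholders, not the exact objects from train.py.

import torch
from torch.cuda.amp import GradScaler

# Create the scaler once, before restoring any checkpoint, so a resumed run
# can overwrite its state instead of restarting the loss-scale warm-up.
scaler = GradScaler(enabled=True)

def save_checkpoint(path, model, optimizer, scheduler, scaler, total_steps):
    # Persist the scaler alongside the rest of the training state.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "scaler": scaler.state_dict(),
            "total_steps": total_steps,
        },
        path,
    )

def load_checkpoint(path, model, optimizer, scheduler, scaler):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    # Checkpoints written before this fix may not contain the key.
    if "scaler" in ckpt:
        scaler.load_state_dict(ckpt["scaler"])
    return ckpt.get("total_steps", 0)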