Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in resuming-from-checkpoint not working with mixed precision training #101

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def seed_worker(worker_id):
train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
print("LEN TRAIN LOADER", len(train_loader))
optimizer, scheduler = fetch_optimizer(args, model)
scaler = GradScaler(enabled=args.mixed_precision)

total_steps = 0
if self.global_rank == 0:
Expand All @@ -350,6 +351,9 @@ def seed_worker(worker_id):
if "scheduler" in ckpt:
logging.info("Load scheduler")
scheduler.load_state_dict(ckpt["scheduler"])
if "scaler" in ckpt:
logging.info("Load gradient scaler")
scaler.load_state_dict(ckpt["scaler"])
if "total_steps" in ckpt:
total_steps = ckpt["total_steps"]
logging.info(f"Load total_steps {total_steps}")
Expand All @@ -373,7 +377,6 @@ def seed_worker(worker_id):
model.train()

save_freq = args.save_freq
scaler = GradScaler(enabled=args.mixed_precision)

should_keep_training = True
global_batch_num = 0
Expand Down Expand Up @@ -456,6 +459,7 @@ def seed_worker(worker_id):
"model": model.module.module.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"scaler": scaler.state_dict(),
"total_steps": total_steps,
}

Expand Down