-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path__init__.py
39 lines (32 loc) · 1.46 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Copyright (c) QIU Tian. All rights reserved.
import timm.scheduler as timm_scheduler
import torch.optim.lr_scheduler as torch_scheduler
def build_scheduler(args, optimizer, n_iter_per_epoch):
    """Build a learning-rate scheduler selected by ``args.scheduler``.

    Args:
        args: namespace providing ``scheduler`` plus the per-scheduler
            options it needs: ``epochs``, ``warmup_epochs``, ``warmup_steps``,
            ``min_lr``, ``warmup_lr`` (cosine); ``step_size``, ``gamma``
            (step); ``milestones``, ``gamma``, ``warmup_epochs``,
            ``warmup_lr`` (multistep).
        optimizer: optimizer whose learning rate will be scheduled.
        n_iter_per_epoch: optimizer steps per epoch; used to convert
            epoch counts into step counts for the step-based cosine schedule.

    Returns:
        A ``timm`` or ``torch`` LR-scheduler instance.

    Raises:
        AssertionError: if both ``args.warmup_epochs`` and
            ``args.warmup_steps`` are positive (ambiguous warmup length).
        ValueError: if ``args.scheduler`` names an unknown scheduler.
    """
    scheduler_name = args.scheduler.lower()

    if scheduler_name == 'cosine':
        # Warmup may be specified in epochs or in raw steps, but not both.
        if args.warmup_epochs > 0 and args.warmup_steps > 0:
            raise AssertionError("'args.warmup_epochs' and 'args.warmup_steps' cannot both be positive.")
        num_steps = int(args.epochs * n_iter_per_epoch)
        warmup_steps = int(args.warmup_epochs * n_iter_per_epoch) if args.warmup_epochs > 0 else args.warmup_steps
        return timm_scheduler.CosineLRScheduler(
            optimizer,
            t_initial=(num_steps - warmup_steps),  # cosine decay covers the post-warmup steps only
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=warmup_steps,
            cycle_limit=1,
            t_in_epochs=False,  # this scheduler is stepped per iteration, not per epoch
        )

    if scheduler_name == 'step':
        return torch_scheduler.StepLR(optimizer, args.step_size, args.gamma)

    if scheduler_name == 'multistep':
        return timm_scheduler.MultiStepLRScheduler(
            optimizer,
            decay_t=args.milestones,
            decay_rate=args.gamma,
            warmup_t=args.warmup_epochs,
            warmup_lr_init=args.warmup_lr,
            t_in_epochs=True,  # milestones/warmup are counted in epochs here
        )

    raise ValueError(f"Scheduler '{scheduler_name}' is not found.")