lr_scheduler.py
from torch.optim import lr_scheduler


class WarmupPolyLR(lr_scheduler._LRScheduler):
    """Polynomial LR decay with an initial warmup phase.

    For the first `warmup_iters` steps the learning rate ramps up using a
    constant or linear warmup; afterwards it decays polynomially from each
    base LR towards `target_lr` over the remaining `max_iters` steps.
    """

    def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3,
                 warmup_iters=500, warmup_method='linear', last_epoch=-1):
        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method))
        self.target_lr = target_lr
        self.max_iters = max_iters
        self.power = power
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupPolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # N: total number of decay steps; T: how far we are into the decay phase.
        N = self.max_iters - self.warmup_iters
        T = self.last_epoch - self.warmup_iters
        if self.last_epoch < self.warmup_iters:
            # Warmup phase: scale the gap between base LR and target LR by a
            # warmup factor in (0, 1].
            if self.warmup_method == 'constant':
                warmup_factor = self.warmup_factor
            elif self.warmup_method == 'linear':
                # Interpolate linearly from warmup_factor up to 1.
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
            else:
                raise ValueError("Unknown warmup type.")
            return [self.target_lr + (base_lr - self.target_lr) * warmup_factor
                    for base_lr in self.base_lrs]
        # Polynomial decay phase: factor goes from 1 towards 0 as T approaches N.
        factor = pow(1 - T / N, self.power)
        return [self.target_lr + (base_lr - self.target_lr) * factor
                for base_lr in self.base_lrs]
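

# A minimal usage sketch (illustrative, not part of the original file): it
# assumes a small torch.nn.Linear model and a standard torch.optim.SGD
# optimizer, and steps the scheduler once per training iteration so that
# `last_epoch` counts iterations rather than epochs.
if __name__ == '__main__':
    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = WarmupPolyLR(optimizer, max_iters=1000, power=0.9,
                             warmup_iters=100, warmup_method='linear')

    for it in range(1000):
        # ... forward / backward would go here in a real training loop ...
        optimizer.step()
        scheduler.step()
        if it % 200 == 0:
            print(it, optimizer.param_groups[0]['lr'])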