-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptimizer.py
33 lines (28 loc) · 1.17 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""
Model optimizer and scheduler
"""
import torch
def get_optimizer(model, configs):
    """Build a torch optimizer for *model* from a config dict.

    Args:
        model: module whose ``.parameters()`` will be optimized.
        configs: dict containing ``'_name_'`` (one of ``'adamw'``,
            ``'sgd'``, ``'adam'``); every other key is forwarded as a
            keyword argument to the optimizer constructor (e.g. ``lr``).

    Returns:
        A configured ``torch.optim.Optimizer``.

    Raises:
        ValueError: if ``configs['_name_']`` is not a supported name
            (the original silently fell through and returned ``None``,
            which crashed far away from the cause).
    """
    optim_configs = {k: v for k, v in configs.items() if k != '_name_'}
    name = configs['_name_']
    # Dispatch table instead of an if/elif chain; also closes the
    # implicit-None hole for unrecognized names.
    optimizers = {
        'adamw': torch.optim.AdamW,
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
    }
    if name not in optimizers:
        raise ValueError(
            f"Unknown optimizer {name!r}; expected one of {sorted(optimizers)}"
        )
    return optimizers[name](model.parameters(), **optim_configs)
def get_scheduler(model, optimizer, configs):
    """Build a learning-rate scheduler for *optimizer* from a config dict.

    Args:
        model: unused; kept for interface compatibility with existing callers.
        optimizer: the ``torch.optim.Optimizer`` to attach the scheduler to.
        configs: either the scheduler config itself, or a dict holding it
            under a ``'scheduler'`` key. Must contain ``'_name_'``
            (``'timm_cosine'`` or ``'plateau'``); all other keys are
            forwarded as scheduler keyword arguments.

    Returns:
        A scheduler instance, or ``None`` when ``'_name_'`` is unrecognized.
    """
    if 'scheduler' in configs:
        configs = configs['scheduler']
    scheduler_configs = {k: v for k, v in configs.items() if k != '_name_'}
    name = configs['_name_']
    if name == 'timm_cosine':
        # Imported lazily so timm is only required when this branch runs.
        from timm.scheduler.cosine_lr import CosineLRScheduler
        return CosineLRScheduler(optimizer=optimizer, **scheduler_configs)
    if name == 'plateau':
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        try:
            return ReduceLROnPlateau(optimizer=optimizer, **scheduler_configs)
        except TypeError as err:
            # Best-effort fallback kept from the original, but: the bare
            # ``except:`` is narrowed to TypeError (bad/unknown kwargs), and
            # the debug print is replaced by a warning so silently discarded
            # configuration no longer goes unnoticed.
            import warnings
            warnings.warn(
                f"ReduceLROnPlateau rejected kwargs {scheduler_configs!r} "
                f"({err}); falling back to defaults.",
                stacklevel=2,
            )
            return ReduceLROnPlateau(optimizer=optimizer)
    return None