Add state_dict & load_state_dict for LR Schedulers #248

Open · wants to merge 1 commit into master
79 changes (76 additions, 3 deletions) in python/jittor/lr_scheduler.py
@@ -32,14 +32,30 @@ def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False
        self.cooldown = cooldown
        self.n_cd = 0
        self.mode = mode
        self.mode_worse = None  # the worse value for the chosen mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.loss_best = None
        self.n_bad = 0
        self.eps = eps
        self.last_epoch = 0
        self.loss_best = math.inf if mode=="min" else -math.inf


    @property
    def defaults(self):
        exclude = set(("defaults", "optimizer"))
        return { k:v for k, v in self.__dict__.items()
                 if k[0] != '_' and k not in exclude and not callable(v) }

    def state_dict(self):
        state = {"defaults": self.defaults}
        return state

    def load_state_dict(self, state):
        for k, v in state["defaults"].items():
            setattr(self, k, v)
        self.init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
    def step(self, loss, epoch=None):
        # convert `loss` to float, in case it's a zero-dim Tensor
        loss_now = float(loss)
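
For reviewers, a minimal round-trip sketch of the new API (the model, `nn.SGD` optimizer, and file name are illustrative; plain `pickle` works because `state_dict()` returns an ordinary dict):

```python
import pickle

from jittor import nn
from jittor.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(4, 2)
optimizer = nn.SGD(model.parameters(), lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

scheduler.step(loss=1.0)

# state_dict() returns {"defaults": {...}}, a plain dict of the scheduler's
# public, non-callable attributes, so any serializer works.
with open("scheduler.pkl", "wb") as f:
    pickle.dump(scheduler.state_dict(), f)

# Later: rebuild the scheduler (constructor arguments are overwritten on load)
# and restore its state.
with open("scheduler.pkl", "rb") as f:
    state = pickle.load(f)
scheduler2 = ReduceLROnPlateau(optimizer)
scheduler2.load_state_dict(state)
assert scheduler2.patience == 2
```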
@@ -86,6 +102,21 @@ def better(self, a, b):
        else:
            return a > b + self.threshold

    def init_is_better(self, mode, threshold, threshold_mode):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = math.inf
        else:  # mode == 'max'
            self.mode_worse = -math.inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
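
A quick check of `init_is_better`'s validation and of `mode_worse`, reusing the optimizer from the sketch above (values are illustrative):

```python
import math

sched = ReduceLROnPlateau(optimizer, mode='min')
sched.init_is_better(mode='max', threshold=1e-4, threshold_mode='abs')
assert sched.mode_worse == -math.inf  # for 'max' mode the worst metric is -inf

# Unknown modes are rejected up front rather than failing later in better().
try:
    sched.init_is_better(mode='median', threshold=1e-4, threshold_mode='rel')
except ValueError as err:
    print(err)  # mode median is unknown!
```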

class CosineAnnealingLR(object):
    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.T_max = T_max
@@ -96,6 +127,20 @@ def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.base_lr_pg = [pg.get("lr") for pg in optimizer.param_groups]
        # TODO: set last_epoch is not ready

    @property
    def defaults(self):
        exclude = set(("defaults", "optimizer"))
        return { k:v for k, v in self.__dict__.items()
                 if k[0] != '_' and k not in exclude and not callable(v) }

    def state_dict(self):
        state = {"defaults": self.defaults}
        return state

    def load_state_dict(self, state):
        for k, v in state["defaults"].items():
            setattr(self, k, v)

    def get_lr(self, base_lr, now_lr):
        if self.last_epoch == 0:
            return base_lr
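
Because `defaults` filters `__dict__` by leading underscore, an exclude set, and callability, the captured state here should be just the schedule parameters and the recorded base learning rates. An illustrative expectation, assuming the collapsed `__init__` lines also store `eta_min`, `last_epoch`, and `base_lr` as in the current file (and reusing the optimizer from the first sketch):

```python
sched = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)
print(sched.state_dict())
# Expected shape of the output (values are illustrative):
# {'defaults': {'T_max': 10, 'eta_min': 1e-05, 'last_epoch': -1,
#               'base_lr': 0.1, 'base_lr_pg': [0.1]}}
```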
@@ -123,7 +168,21 @@ def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        self.gamma = gamma
        self.last_epoch = last_epoch
        self.cur_epoch = 0


    @property
    def defaults(self):
        exclude = set(("defaults", "optimizer"))
        return { k:v for k, v in self.__dict__.items()
                 if k[0] != '_' and k not in exclude and not callable(v) }

    def state_dict(self):
        state = {"defaults": self.defaults}
        return state

    def load_state_dict(self, state):
        for k, v in state["defaults"].items():
            setattr(self, k, v)

    def get_gamma(self):
        if self.last_epoch < 0:
            if (self.cur_epoch != 0 and (self.cur_epoch + 1) % self.step_size == 0):
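
For StepLR the practical win is that `cur_epoch` survives a checkpoint, so the decay schedule resumes where it stopped. A sketch, assuming `step()` advances `cur_epoch` as in the existing class (and reusing the optimizer from the first sketch):

```python
sched = StepLR(optimizer, step_size=30, gamma=0.1)
for _ in range(10):
    sched.step()                  # advances cur_epoch

state = sched.state_dict()        # carries cur_epoch == 10 among the defaults

fresh = StepLR(optimizer, step_size=30, gamma=0.1)
fresh.load_state_dict(state)      # restores cur_epoch, step_size, gamma
assert fresh.cur_epoch == 10
```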
@@ -155,7 +214,21 @@ def __init__(self, optimizer, milestones=[], gamma=0.1, last_epoch=-1):
        self.gamma = gamma
        self.last_epoch = last_epoch
        # TODO: set last_epoch is not ready


    @property
    def defaults(self):
        exclude = set(("defaults", "optimizer"))
        return { k:v for k, v in self.__dict__.items()
                 if k[0] != '_' and k not in exclude and not callable(v) }

    def state_dict(self):
        state = {"defaults": self.defaults}
        return state

    def load_state_dict(self, state):
        for k, v in state["defaults"].items():
            setattr(self, k, v)

    def get_gamma(self):
        if (self.last_epoch in self.milestones):
            return self.gamma
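
The same pattern holds for MultiStepLR; since `load_state_dict` just `setattr`s each entry, a scheduler built with placeholder arguments is fully overwritten (milestones below are illustrative, optimizer as in the first sketch):

```python
sched = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
state = sched.state_dict()

restored = MultiStepLR(optimizer)   # placeholder args: milestones=[], gamma=0.1
restored.load_state_dict(state)
assert restored.milestones == [30, 80] and restored.gamma == 0.1
```

One design note: the dict comprehension in `defaults` copies references, not values, so a mutable entry like the `milestones` list is shared between the scheduler and the returned state dict; a `copy.deepcopy` inside `state_dict()` would decouple them.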