diff --git a/python/jittor/lr_scheduler.py b/python/jittor/lr_scheduler.py
index e5febaa2..937bb43f 100644
--- a/python/jittor/lr_scheduler.py
+++ b/python/jittor/lr_scheduler.py
@@ -32,6 +32,7 @@ def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False
         self.cooldown = cooldown
         self.n_cd = 0
         self.mode = mode
+        self.mode_worse = None  # worst value for the chosen mode
         self.threshold = threshold
         self.threshold_mode = threshold_mode
         self.loss_best = None
@@ -39,7 +40,22 @@ def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False
         self.eps = eps
         self.last_epoch = 0
         self.loss_best = math.inf if mode=="min" else -math.inf
-
+
+    @property
+    def defaults(self):
+        exclude = set(("defaults", "optimizer"))
+        return { k:v for k, v in self.__dict__.items()
+            if k[0] != '_' and k not in exclude and not callable(v) }
+
+    def state_dict(self):
+        state = {"defaults": self.defaults}
+        return state
+
+    def load_state_dict(self, state):
+        for k,v in state["defaults"].items():
+            setattr(self, k, v)
+        self.init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
+
     def step(self, loss, epoch=None):
         # convert `metrics` to float, in case it's a zero-dim Tensor
         loss_now = float(loss)
@@ -86,6 +102,21 @@ def better(self, a, b):
         else:
             return a > b + self.threshold
 
+    def init_is_better(self, mode, threshold, threshold_mode):
+        if mode not in {'min', 'max'}:
+            raise ValueError('mode ' + mode + ' is unknown!')
+        if threshold_mode not in {'rel', 'abs'}:
+            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
+
+        if mode == 'min':
+            self.mode_worse = math.inf
+        else:  # mode == 'max':
+            self.mode_worse = -math.inf
+
+        self.mode = mode
+        self.threshold = threshold
+        self.threshold_mode = threshold_mode
+
 class CosineAnnealingLR(object):
     def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
         self.T_max = T_max
@@ -96,6 +127,20 @@ def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
         self.base_lr_pg = [pg.get("lr") for pg in optimizer.param_groups]
         #TODO set last_epoch is not ready
 
+    @property
+    def defaults(self):
+        exclude = set(("defaults", "optimizer"))
+        return { k:v for k, v in self.__dict__.items()
+            if k[0] != '_' and k not in exclude and not callable(v) }
+
+    def state_dict(self):
+        state = {"defaults": self.defaults}
+        return state
+
+    def load_state_dict(self, state):
+        for k,v in state["defaults"].items():
+            setattr(self, k, v)
+
     def get_lr(self, base_lr, now_lr):
         if self.last_epoch == 0:
             return base_lr
@@ -123,7 +168,21 @@ def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
         self.gamma = gamma
         self.last_epoch = last_epoch
         self.cur_epoch = 0
-
+
+    @property
+    def defaults(self):
+        exclude = set(("defaults", "optimizer"))
+        return { k:v for k, v in self.__dict__.items()
+            if k[0] != '_' and k not in exclude and not callable(v) }
+
+    def state_dict(self):
+        state = {"defaults": self.defaults}
+        return state
+
+    def load_state_dict(self, state):
+        for k,v in state["defaults"].items():
+            setattr(self, k, v)
+
     def get_gamma(self):
         if self.last_epoch < 0:
             if (self.cur_epoch != 0 and (self.cur_epoch + 1) % self.step_size == 0):
@@ -155,7 +214,21 @@ def __init__(self, optimizer, milestones=[], gamma=0.1, last_epoch=-1):
         self.gamma = gamma
         self.last_epoch = last_epoch
         #TODO set last_epoch is not ready
-
+
+    @property
+    def defaults(self):
+        exclude = set(("defaults", "optimizer"))
+        return { k:v for k, v in self.__dict__.items()
+            if k[0] != '_' and k not in exclude and not callable(v) }
+
+    def state_dict(self):
+        state = {"defaults": self.defaults}
+        return state
+
+    def load_state_dict(self, state):
+        for k,v in state["defaults"].items():
+            setattr(self, k, v)
+
     def get_gamma(self):
         if (self.last_epoch in self.milestones):
             return self.gamma
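
For reference, a minimal round-trip sketch of the new hooks. This is a usage illustration, not part of the patch: the toy model, optimizer settings, dummy loss values, and the "sched.pkl" file name are all assumptions. `state_dict()` returns a plain dict, so it can be persisted with `jt.save` or pickle.

    import jittor as jt
    from jittor import nn, optim
    from jittor.lr_scheduler import ReduceLROnPlateau

    model = nn.Linear(10, 2)                        # toy model, illustrative only
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5)

    for epoch in range(10):                         # stand-in training loop
        scheduler.step(1.0 / (epoch + 1))           # dummy decreasing loss

    # state_dict() snapshots every non-private, non-callable attribute
    # except the optimizer itself, so the result is a plain picklable dict.
    jt.save(scheduler.state_dict(), "sched.pkl")    # file name is illustrative

    # Restoring copies the saved attributes onto a fresh scheduler;
    # ReduceLROnPlateau additionally re-runs init_is_better() to rebuild
    # mode_worse from the restored mode/threshold settings.
    scheduler2 = ReduceLROnPlateau(optimizer, mode='min', patience=5)
    scheduler2.load_state_dict(jt.load("sched.pkl"))
    assert scheduler2.loss_best == scheduler.loss_best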