From 96dc67eeac3098371b953938552161e7cce68359 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:29:12 +0800 Subject: [PATCH 01/14] Add CAME Optimizer --- src/lmflow/optim/came.py | 176 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 src/lmflow/optim/came.py diff --git a/src/lmflow/optim/came.py b/src/lmflow/optim/came.py new file mode 100644 index 000000000..009cf3f2a --- /dev/null +++ b/src/lmflow/optim/came.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math + +import torch +from torch.optim import Optimizer + +class CAME(Optimizer): + """Implements CAME algorithm. + This implementation is based on: + `CAME: Confidence-guided Adaptive Memory Efficient Optimization` + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and instability respectively (default: (1e-30, 1e-16)) + clip_threshold (float): threshold of root-mean-square of + final gradient update (default: 1.0) + betas (tuple[float, float, float]): coefficient used for computing running averages of + update, square gradient and instability (default: (0.9, 0.999, 0.9999))) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + """ + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-16), + clip_threshold=1.0, + betas=(0.9, 0.999, 0.9999), + weight_decay=0.0, + ): + assert lr > 0. + assert all([0. <= beta <= 1. for beta in betas]) + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + betas=betas, + weight_decay=weight_decay, + ) + super(CAME, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return False + + + def _get_options(self, param_shape): + factored = len(param_shape) >= 2 + return factored + + def _rms(self, tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = ( + (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) + .rsqrt_() + .unsqueeze(-1) + ) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
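+
+        Example (illustrative usage only; ``model``, ``loss_fn``, ``input`` and
+        ``target`` are assumed to be defined elsewhere, and ``lr=2e-4`` is just
+        a placeholder value):
+            >>> optimizer = CAME(model.parameters(), lr=2e-4)
+            >>> optimizer.zero_grad()
+            >>> loss_fn(model(input), target).backward()
+            >>> optimizer.step()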
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("CAME does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored = self._get_options(grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).type_as(grad) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).type_as(grad) + + state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1]).type_as(grad) + state["exp_avg_res_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).type_as(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + + state["step"] += 1 + state["RMS"] = self._rms(p.data) + + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(group["betas"][1]).add_( + update.mean(dim=-1), alpha=1.0 - group["betas"][1] + ) + exp_avg_sq_col.mul_(group["betas"][1]).add_( + update.mean(dim=-2), alpha=1.0 - group["betas"][1] + ) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1]) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_( + (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0) + ) + + exp_avg = state["exp_avg"] + exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0]) + + # Confidence-guided strategy + # Calculation of instability + res = (update - exp_avg)**2 + group["eps"][1] + + if factored: + exp_avg_res_row = state["exp_avg_res_row"] + exp_avg_res_col = state["exp_avg_res_col"] + + exp_avg_res_row.mul_(group["betas"][2]).add_( + res.mean(dim=-1), alpha=1.0 - group["betas"][2] + ) + exp_avg_res_col.mul_(group["betas"][2]).add_( + res.mean(dim=-2), alpha=1.0 - group["betas"][2] + ) + + # Approximation of exponential moving average of instability + res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col) + update = res_approx.mul_(exp_avg) + else: + update = exp_avg + + if group["weight_decay"] != 0: + p.data.add_( + p.data, alpha=-group["weight_decay"] * group["lr"] + ) + + update.mul_(group["lr"]) + p.data.add_(-update) + + return loss \ No newline at end of file From 9f75573fe231a8b4ab7e27bd9aea534f7ffb0513 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:31:22 +0800 Subject: [PATCH 02/14] Add Adafactor Optimizer --- src/lmflow/optim/adafactor.py | 158 ++++++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 src/lmflow/optim/adafactor.py diff --git a/src/lmflow/optim/adafactor.py b/src/lmflow/optim/adafactor.py new file mode 100644 index 000000000..d9a467f63 --- /dev/null +++ b/src/lmflow/optim/adafactor.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import torch +import math +from torch.optim.optimizer import Optimizer +class Adafactor(Optimizer): + """Implements Adafactor algorithm. 
+ This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate depending on the + *scale_parameter*, *relative_step* and *warmup_init* options. + + To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, + decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): + relative_step = not lr + if warmup_init and not relative_step: + raise ValueError('warmup_init requires relative_step=True') + + beta1 = None if betas is None else betas[0] # make it compat with standard betas arg + defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, + beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, + relative_step=relative_step, warmup_init=warmup_init) + super(Adafactor, self).__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + if param_group['relative_step']: + min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 + lr_t = min(min_step, 1.0 / math.sqrt(param_state['step'])) + param_scale = 1.0 + if param_group['scale_parameter']: + param_scale = max(param_group['eps_scale'], param_state['RMS']) + param_group['lr'] = lr_t * param_scale + return param_group['lr'] + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group['beta1'] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. 
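+
+        Example (illustrative usage only; ``model``, ``loss_fn``, ``input`` and
+        ``target`` are assumed to be defined elsewhere; ``lr=None`` selects the
+        internal relative-step schedule):
+            >>> optimizer = Adafactor(model.parameters(), lr=None)
+            >>> optimizer.zero_grad()
+            >>> loss_fn(model(input), target).backward()
+            >>> optimizer.step()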
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError('Adafactor does not support sparse gradients.') + + state = self.state[p] + + factored, use_first_moment = self._get_options(group, grad.shape) + # State Initialization + if len(state) == 0: + state['step'] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + if factored: + state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad) + else: + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['RMS'] = 0 + else: + if use_first_moment: + state['exp_avg'] = state['exp_avg'].to(grad) + if factored: + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + else: + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + + p_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_fp32 = p_fp32.float() + + state['step'] += 1 + state['RMS'] = self._rms(p_fp32) + lr_t = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) + update = grad ** 2 + group['eps'] + if factored: + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0)) + update.mul_(lr_t) + + if use_first_moment: + exp_avg = state['exp_avg'] + exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) + update = exp_avg + + if group['weight_decay'] != 0: + p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t) + + p_fp32.add_(-update) + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_fp32) + + return loss From 89afff82930535095605c772b2af27eda78ba914 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:32:10 +0800 Subject: [PATCH 03/14] Add QHAdam Optimizer --- src/lmflow/optim/qhadam.py | 151 +++++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 src/lmflow/optim/qhadam.py diff --git a/src/lmflow/optim/qhadam.py b/src/lmflow/optim/qhadam.py new file mode 100644 index 000000000..5b7c42c9e --- /dev/null +++ b/src/lmflow/optim/qhadam.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import torch +from torch.optim.optimizer import Optimizer + +class QHAdam(Optimizer): + r"""Implements the QHAdam optimization algorithm. + + It has been proposed in `Adaptive methods for Nonconvex Optimization`__. 
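+
+    With the default ``nus = (1.0, 1.0)`` the update reduces to bias-corrected
+    Adam; smaller ``nu`` values interpolate the numerator toward the raw
+    gradient and the denominator toward the raw squared gradient, which is the
+    quasi-hyperbolic modification described in the paper.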
+ + Arguments: + params: iterable of parameters to optimize or dicts defining + parameter groups + lr: learning rate (default: 1e-3) + betas: coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + nus: immediate discount factors used to estimate the gradient and its + square (default: (1.0, 1.0)) + eps: term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay: weight decay (L2 penalty) (default: 0) + decouple_weight_decay: whether to decouple the weight + decay from the gradient-based optimization step (default: False) + + Example: + >>> import torch_optimizer as optim + >>> optimizer = optim.QHAdam(model.parameters(), lr=0.1) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ https://arxiv.org/abs/1810.06801 + + Note: + Reference code: https://github.com/facebookresearch/qhoptim + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas = (0.9, 0.999), + nus = (1.0, 1.0), + weight_decay: float = 0.0, + decouple_weight_decay: bool = False, + eps: float = 1e-8, + ): + if lr <= 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if eps < 0.0: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) + if weight_decay < 0: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + + defaults = { + "lr": lr, + "betas": betas, + "nus": nus, + "weight_decay": weight_decay, + "decouple_weight_decay": decouple_weight_decay, + "eps": eps, + } + super(QHAdam, self).__init__(params, defaults) + + def step(self, closure = None): + """Performs a single optimization step. + + Arguments: + closure: A closure that reevaluates the model and returns the loss. 
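+
+        Example (illustrative usage with non-default ``nus``; ``model``,
+        ``loss_fn``, ``input`` and ``target`` are assumed to be defined
+        elsewhere and the hyperparameter values are placeholders):
+            >>> optimizer = QHAdam(model.parameters(), lr=1e-3, nus=(0.7, 1.0), betas=(0.995, 0.999))
+            >>> optimizer.zero_grad()
+            >>> loss_fn(model(input), target).backward()
+            >>> optimizer.step()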
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + beta1, beta2 = group["betas"] + nu1, nu2 = group["nus"] + weight_decay = group["weight_decay"] + decouple_weight_decay = group["decouple_weight_decay"] + eps = group["eps"] + + for p in group["params"]: + if p.grad is None: + continue + + d_p = p.grad.data + if d_p.is_sparse: + raise RuntimeError( + "QHAdam does not support sparse gradients, " + "please consider SparseAdam instead" + ) + + state = self.state[p] + + if weight_decay != 0: + if decouple_weight_decay: + p.data.mul_(1 - lr * weight_decay) + else: + d_p.add_(p.data, alpha=weight_decay) + + d_p_sq = d_p.mul(d_p) + + if len(state) == 0: + state["beta1_weight"] = 0.0 + state["beta2_weight"] = 0.0 + state["exp_avg"] = torch.zeros_like( + p.data, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + p.data, memory_format=torch.preserve_format + ) + + state["beta1_weight"] = 1.0 + beta1 * state["beta1_weight"] + state["beta2_weight"] = 1.0 + beta2 * state["beta2_weight"] + + beta1_weight = state["beta1_weight"] + beta2_weight = state["beta2_weight"] + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + + beta1_adj = 1.0 - (1.0 / beta1_weight) + beta2_adj = 1.0 - (1.0 / beta2_weight) + exp_avg.mul_(beta1_adj).add_(d_p, alpha=1.0 - beta1_adj) + exp_avg_sq.mul_(beta2_adj).add_(d_p_sq, alpha=1.0 - beta2_adj) + + avg_grad = exp_avg.mul(nu1) + if nu1 != 1.0: + avg_grad.add_(d_p, alpha=1.0 - nu1) + + avg_grad_rms = exp_avg_sq.mul(nu2) + if nu2 != 1.0: + avg_grad_rms.add_(d_p_sq, alpha=1.0 - nu2) + avg_grad_rms.sqrt_() + if eps != 0.0: + avg_grad_rms.add_(eps) + + p.data.addcdiv_(avg_grad, avg_grad_rms, value=-lr) + + return loss \ No newline at end of file From 5866ff81d98ed7ed971c9a72969e584ff13e0611 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:32:38 +0800 Subject: [PATCH 04/14] Add QHM Optimizer --- src/lmflow/optim/qhm.py | 116 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 src/lmflow/optim/qhm.py diff --git a/src/lmflow/optim/qhm.py b/src/lmflow/optim/qhm.py new file mode 100644 index 000000000..769feb9bc --- /dev/null +++ b/src/lmflow/optim/qhm.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import torch +from torch.optim.optimizer import Optimizer + + +class QHM(Optimizer): + GRAD = "grad" + DIRECT = "direct" + + r"""Implements quasi-hyperbolic momentum (QHM) optimization algorithm. + + It has been proposed in `Quasi-hyperbolic momentum and Adam for deep + learning`__. 
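+
+    The update rule implemented below is
+    ``buf = momentum * buf + (1 - momentum) * g`` followed by
+    ``p = p - lr * (nu * buf + (1 - nu) * g)``; ``nu = 0`` recovers plain SGD
+    and ``nu = 1`` recovers SGD with a normalized momentum buffer.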
+ + Arguments: + params: iterable of parameters to optimize or dicts defining + parameter groups + lr: learning rate (default: 1e-3) + momentum: momentum factor (:math:`\beta` from the paper) + nu: immediate discount factor (:math:`\nu` from the paper) + weight_decay: weight decay (L2 regularization coefficient, times two) + (default: 0.0) + weight_decay_type: method of applying the weight decay: + ``"grad"`` for accumulation in the gradient + (same as :class:`torch.optim.SGD`) or + ``"direct"`` for direct application to the parameters + (default: ``"grad"``) + + Example: + >>> import torch_optimizer as optim + >>> optimizer = optim.QHM(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + + __ https://arxiv.org/abs/1810.06801 + + Note: + Reference code: https://github.com/facebookresearch/qhoptim + """ + + def __init__( + self, + params, + lr: float = 1e-3, + momentum: float = 0.0, + nu: float = 0.7, + weight_decay: float = 0.0, + weight_decay_type: str = "grad", + ) -> None: + if lr <= 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + if weight_decay_type not in (self.GRAD, self.DIRECT): + _type = weight_decay_type + msg = "Invalid weight_decay_type value: {}".format(_type) + raise ValueError(msg) + + defaults = { + "lr": lr, + "momentum": momentum, + "nu": nu, + "weight_decay": weight_decay, + "weight_decay_type": weight_decay_type, + } + super(QHM, self).__init__(params, defaults) + + def step(self, closure = None): + """Performs a single optimization step. + + Arguments: + closure: A closure that reevaluates the model and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + lr, nu, momentum = group["lr"], group["nu"], group["momentum"] + weight_decay, weight_decay_type = ( + group["weight_decay"], + group["weight_decay_type"], + ) + + for p in group["params"]: + if p.grad is None: + continue + d_p = p.grad.data + param_state = self.state[p] + + if weight_decay != 0: + if weight_decay_type == self.GRAD: + d_p.add_(p.data, alpha=weight_decay) + else: + p.data.mul_(1.0 - lr * weight_decay) + + if len(param_state) == 0: + param_state["momentum_buffer"] = torch.zeros_like( + p.data, memory_format=torch.preserve_format + ) + + momentum_buffer = param_state["momentum_buffer"] + momentum_buffer.mul_(momentum).add_(d_p, alpha=1.0 - momentum) + + p.data.add_(momentum_buffer, alpha=-lr * nu) + p.data.add_(d_p, alpha=-lr * (1.0 - nu)) + + return loss \ No newline at end of file From 21f5aeb5b78e1ddd2143c97c588352c65da84564 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:33:15 +0800 Subject: [PATCH 05/14] Add RMSprop Optimizer --- src/lmflow/optim/rmsprop.py | 155 ++++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 src/lmflow/optim/rmsprop.py diff --git a/src/lmflow/optim/rmsprop.py b/src/lmflow/optim/rmsprop.py new file mode 100644 index 000000000..42c1dd14e --- /dev/null +++ b/src/lmflow/optim/rmsprop.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +r"""Implementation for the RMSprop algorithm.""" +from typing import List, Optional + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + +class RMSprop(Optimizer): + def __init__( + self, + params, + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0, + momentum: float = 0, + centered=False, + capturable=False, + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= momentum: + raise ValueError(f"Invalid momentum value: {momentum}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= alpha: + raise ValueError(f"Invalid alpha value: {alpha}") + + defaults = dict( + lr=lr, + momentum=momentum, + alpha=alpha, + eps=eps, + centered=centered, + weight_decay=weight_decay, + capturable=capturable, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("momentum", 0) + group.setdefault("centered", False) + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=torch.get_default_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=torch.get_default_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + square_avgs, + momentum_buffer_list, + grad_avgs, + state_steps, + ): + has_complex = False + for p 
in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + + if p.grad.is_sparse: + raise RuntimeError("RMSprop does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=torch.get_default_dtype(), device=p.device) + if group["capturable"] + else torch.zeros((), dtype=torch.get_default_dtype()) + ) + state["square_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["momentum"] > 0: + state["momentum_buffer"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["centered"]: + state["grad_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + square_avgs.append(state["square_avg"]) + state_steps.append(state["step"]) + + if group["momentum"] > 0: + momentum_buffer_list.append(state["momentum_buffer"]) + if group["centered"]: + grad_avgs.append(state["grad_avg"]) + + return has_complex + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + square_avgs: List[Tensor] = [] + grad_avgs: List[Tensor] = [] + momentum_buffer_list: List[Tensor] = [] + state_steps: List[Tensor] = [] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + square_avgs, + momentum_buffer_list, + grad_avgs, + state_steps, + ) + + return loss \ No newline at end of file From f1e3974af25dc2a507e7ef80d8fee75b09152a29 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:35:07 +0800 Subject: [PATCH 06/14] Add SGD Optimizer --- src/lmflow/optim/sgd.py | 119 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 src/lmflow/optim/sgd.py diff --git a/src/lmflow/optim/sgd.py b/src/lmflow/optim/sgd.py new file mode 100644 index 000000000..67647a30a --- /dev/null +++ b/src/lmflow/optim/sgd.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +r"""Implementation for Stochastic Gradient Descent optimizer.""" +from typing import List, Optional + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + + +class SGD(Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov=False, + *, + maximize: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + maximize=maximize, + foreach=foreach, + differentiable=differentiable, + fused=fused, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + if fused: + self._step_supports_amp_scaling = True + 
+ fused_supported_devices = _get_fused_kernels_supported_devices() + if not all( + p.device.type in fused_supported_devices and torch.is_floating_point(p) + for pg in self.param_groups + for p in pg["params"] + ): + raise RuntimeError( + "`fused=True` requires all the params to be floating point Tensors of " + f"supported devices: {fused_supported_devices}." + ) + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("differentiable", False) + group.setdefault("fused", False) + + def _init_group(self, group, params, grads, momentum_buffer_list): + has_sparse_grad = False + + for p in group["params"]: + if p.grad is not None: + params.append(p) + grads.append(p.grad) + if p.grad.is_sparse: + has_sparse_grad = True + + if group["momentum"] != 0: + state = self.state[p] + momentum_buffer_list.append(state.get("momentum_buffer")) + + return has_sparse_grad + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params: List[Tensor] = [] + grads: List[Tensor] = [] + momentum_buffer_list: List[Optional[Tensor]] = [] + + has_sparse_grad = self._init_group( + group, params, grads, momentum_buffer_list + ) + + if group["momentum"] != 0: + # update momentum_buffers in state + for p, momentum_buffer in zip(params, momentum_buffer_list): + state = self.state[p] + state["momentum_buffer"] = momentum_buffer + + return loss From 77b3619cc6a9a75e6d88a0a296593a35c1f8545b Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:35:40 +0800 Subject: [PATCH 07/14] Add SWATS Optimizer --- src/lmflow/optim/swats.py | 204 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 src/lmflow/optim/swats.py diff --git a/src/lmflow/optim/swats.py b/src/lmflow/optim/swats.py new file mode 100644 index 000000000..8743d336c --- /dev/null +++ b/src/lmflow/optim/swats.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import torch +from torch.optim.optimizer import Optimizer + +class SWATS(Optimizer): + r"""Implements SWATS Optimizer Algorithm. + It has been proposed in `Improving Generalization Performance by + Switching from Adam to SGD`__. 
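+
+    Optimization starts in an Adam phase; once the bias-corrected running
+    estimate of the SGD-equivalent step scale (``exp_avg2`` in ``step``)
+    stabilizes, the optimizer switches to the SGD phase and reuses that
+    converged value as the SGD learning rate.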
+ + Arguments: + params: iterable of parameters to optimize or dicts defining + parameter groups + lr: learning rate (default: 1e-2) + betas: coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps: term added to the denominator to improve + numerical stability (default: 1e-3) + weight_decay: weight decay (L2 penalty) (default: 0) + amsgrad: whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond` + (default: False) + nesterov: enables Nesterov momentum (default: False) + + + Example: + >>> import torch_optimizer as optim + >>> optimizer = optim.SWATS(model.parameters(), lr=0.01) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ https://arxiv.org/pdf/1712.07628.pdf + + Note: + Reference code: https://github.com/Mrpatekful/swats + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas = (0.9, 0.999), + eps: float = 1e-3, + weight_decay: float = 0, + amsgrad: bool = False, + nesterov: bool = False, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) + if weight_decay < 0: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + phase="ADAM", + weight_decay=weight_decay, + amsgrad=amsgrad, + nesterov=nesterov, + ) + + super().__init__(params, defaults) + + def __setstate__(self, state) -> None: + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + group.setdefault("nesterov", False) + + def step(self, closure = None): + r"""Performs a single optimization step. + + Arguments: + closure: A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for w in group["params"]: + if w.grad is None: + continue + grad = w.grad.data + + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, " + "please consider SparseAdam instead" + ) + + amsgrad = group["amsgrad"] + + state = self.state[w] + + # state initialization + if len(state) == 0: + state["step"] = 0 + # exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + w.data, memory_format=torch.preserve_format + ) + # exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + w.data, memory_format=torch.preserve_format + ) + # moving average for the non-orthogonal projection scaling + state["exp_avg2"] = w.new(1).fill_(0) + if amsgrad: + # maintains max of all exp. moving avg. + # of sq. grad. 
values + state["max_exp_avg_sq"] = torch.zeros_like( + w.data, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg2, exp_avg_sq = ( + state["exp_avg"], + state["exp_avg2"], + state["exp_avg_sq"], + ) + + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + if group["weight_decay"] != 0: + grad.add_(w.data, alpha=group["weight_decay"]) + + # if its SGD phase, take an SGD update and continue + if group["phase"] == "SGD": + if "momentum_buffer" not in state: + buf = state["momentum_buffer"] = torch.clone( + grad + ).detach() + else: + buf = state["momentum_buffer"] + buf.mul_(beta1).add_(grad) + grad = buf + + grad.mul_(1 - beta1) + if group["nesterov"]: + grad.add_(buf, alpha=beta1) + + w.data.add_(grad, alpha=-group["lr"]) + continue + + # decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # maintains the maximum of all 2nd + # moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group["eps"]) + else: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = ( + group["lr"] * (bias_correction2**0.5) / bias_correction1 + ) + + p = -step_size * (exp_avg / denom) + w.data.add_(p) + + p_view = p.view(-1) + pg = p_view.dot(grad.view(-1)) + + if pg != 0: + # the non-orthognal scaling estimate + scaling = p_view.dot(p_view) / -pg + exp_avg2.mul_(beta2).add_(scaling, alpha=1 - beta2) + + # bias corrected exponential average + corrected_exp_avg = exp_avg2 / bias_correction2 + + # checking criteria of switching to SGD training + if ( + state["step"] > 1 + and corrected_exp_avg.allclose(scaling, rtol=1e-6) + and corrected_exp_avg > 0 + ): + group["phase"] = "SGD" + group["lr"] = corrected_exp_avg.item() + return loss + From 5539993469975a1b9431be9c59e8355334918c20 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:36:49 +0800 Subject: [PATCH 08/14] Add Adahessian Optimizer --- src/lmflow/optim/adahessian.py | 199 +++++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 src/lmflow/optim/adahessian.py diff --git a/src/lmflow/optim/adahessian.py b/src/lmflow/optim/adahessian.py new file mode 100644 index 000000000..50ece3455 --- /dev/null +++ b/src/lmflow/optim/adahessian.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math +from typing import List, Optional +import torch +from torch.optim.optimizer import Optimizer + +class Adahessian(Optimizer): + r"""Implements Adahessian Algorithm. + It has been proposed in `ADAHESSIAN: An Adaptive Second Order Optimizer + for Machine Learning`. 
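+
+    The Hessian diagonal is estimated with Hutchinson's method (see
+    ``get_trace``): a random Rademacher vector ``v`` is drawn, the
+    Hessian-vector product is obtained through a second backward pass, and its
+    absolute value (spatially averaged for 4D conv kernels) serves as the
+    curvature estimate. This is why ``backward`` must be called with
+    ``create_graph=True``.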
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 0.15) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-4) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + hessian_power (float, optional): Hessian power (default: 0.5) + seed (int, optional): Random number generator seed (default: None) + + Example: + >>> import torch_optimizer as optim + >>> optimizer = optim.Adahessian(model.parameters(), lr = 1.0) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward(create_graph=True) + >>> optimizer.step() + + __ https://arxiv.org/abs/2006.00719 + + Note: + Reference code: https://github.com/amirgholami/adahessian + """ + + def __init__( + self, + params, + lr: float = 0.15, + betas = (0.9, 0.999), + eps: float = 1e-4, + weight_decay: float = 0, + hessian_power: float = 0.5, + seed: Optional[int] = None, + ) -> None: + if lr <= 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if eps <= 0.0: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) + if not 0.0 <= hessian_power <= 1.0: + raise ValueError( + "Invalid Hessian power value: {}".format(hessian_power) + ) + if seed is not None: + torch.manual_seed(seed) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + hessian_power=hessian_power, + ) + super(Adahessian, self).__init__(params, defaults) + + def get_trace(self, params, grads) -> List[torch.Tensor]: + """Get an estimate of Hessian Trace. + This is done by computing the Hessian vector product with a random + vector v at the current gradient point, to estimate Hessian trace by + computing the gradient of . + :param gradsH: a list of torch variables + :return: a list of torch tensors + """ + + # Check backward was called with create_graph set to True + for i, grad in enumerate(grads): + if grad.grad_fn is None: + msg = ( + "Gradient tensor {:} does not have grad_fn. When " + "calling loss.backward(), make sure the option " + "create_graph is set to True." + ) + raise RuntimeError(msg.format(i)) + + v = [ + 2 + * torch.randint_like( + p, high=2, memory_format=torch.preserve_format + ) + - 1 + for p in params + ] + + # this is for distributed setting with single node and multi-gpus, + # for multi nodes setting, we have not support it yet. + hvs = torch.autograd.grad( + grads, params, grad_outputs=v, only_inputs=True, retain_graph=True + ) + + hutchinson_trace = [] + for hv in hvs: + param_size = hv.size() + if len(param_size) <= 2: # for 0/1/2D tensor + # Hessian diagonal block size is 1 here. + # We use that torch.abs(hv * vi) = hv.abs() + tmp_output = hv.abs() + + elif len(param_size) == 4: # Conv kernel + # Hessian diagonal block size is 9 here: torch.sum() reduces + # the dim 2/3. + # We use that torch.abs(hv * vi) = hv.abs() + tmp_output = torch.mean(hv.abs(), dim=[2, 3], keepdim=True) + hutchinson_trace.append(tmp_output) + + return hutchinson_trace + + def step(self, closure = None): + """Perform a single optimization step. 
+ + Arguments: + closure: A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + params = [] + groups = [] + grads = [] + + # Flatten groups into lists, so that + # hut_traces can be called with lists of parameters + # and grads + for group in self.param_groups: + for p in group["params"]: + if p.grad is not None: + params.append(p) + groups.append(group) + grads.append(p.grad) + + # get the Hessian diagonal + + hut_traces = self.get_trace(params, grads) + + for p, group, grad, hut_trace in zip( + params, groups, grads, hut_traces + ): + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of Hessian diagonal square values + state["exp_hessian_diag_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_hessian_diag_sq = ( + state["exp_avg"], + state["exp_hessian_diag_sq"], + ) + + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad.detach_(), alpha=1 - beta1) + exp_hessian_diag_sq.mul_(beta2).addcmul_( + hut_trace, hut_trace, value=1 - beta2 + ) + + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # make the square root, and the Hessian power + k = group["hessian_power"] + denom = ( + (exp_hessian_diag_sq.sqrt() ** k) + / math.sqrt(bias_correction2) ** k + ).add_(group["eps"]) + + # make update + p.data = p.data - group["lr"] * ( + exp_avg / bias_correction1 / denom + + group["weight_decay"] * p.data + ) + + return loss \ No newline at end of file From 53c84e5d2d5031ef28b843f464fb5855b660488b Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:38:18 +0800 Subject: [PATCH 09/14] Add AdamW Optimizer --- src/lmflow/optim/adamw.py | 48 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 src/lmflow/optim/adamw.py diff --git a/src/lmflow/optim/adamw.py b/src/lmflow/optim/adamw.py new file mode 100644 index 000000000..07eb7d94c --- /dev/null +++ b/src/lmflow/optim/adamw.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import torch +from torch.optim.optimizer import Optimizer + +class AdamW(Optimizer): + def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(AdamW, self).__init__(params, defaults) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + denom = exp_avg_sq.sqrt().add_(group['eps']) + step_size = group['lr'] * (bias_correction2 ** 0.5) / bias_correction1 + + if group['weight_decay'] != 0: + p.data.addcdiv_(-step_size, exp_avg, 
denom).add_(-group['weight_decay'] * group['lr'], p.data) + else: + p.data.addcdiv_(-step_size, exp_avg, denom) + + return loss From 29f547be2b7f2b96652c3b599e0a25fd41188f90 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:39:27 +0800 Subject: [PATCH 10/14] Add Lion Optimizer --- src/lmflow/optim/lion.py | 98 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 src/lmflow/optim/lion.py diff --git a/src/lmflow/optim/lion.py b/src/lmflow/optim/lion.py new file mode 100644 index 000000000..c9201f24b --- /dev/null +++ b/src/lmflow/optim/lion.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import torch +from torch.optim.optimizer import Optimizer + +class Lion(Optimizer): + r"""Implements Lion algorithm. + + Addapted from https://github.com/google/automl/tree/master/lion + + The Lion - EvoLved SIgn MOmeNtum - algorithm was proposed in + https://arxiv.org/pdf/2302.06675.pdf. + Lion aims to be more memory efficient than Adam by only tracking momentum. + + Caveats: As detailed in the paper, Lion requires a smaller learning rate + lr, and larger decoupled weight decay to maintain effective weight decay + strength. Also, the gain of Lion increases with the batch size. + Furthermore, Lion was not found to outperform AdamW on some large language + and text/image datasets. + + Arguments: + params: iterable of parameters to optimize or dicts defining + parameter groups + lr: learning rate (default: 1e-3) + betas: coefficients used for computing + running averages of gradient and its square (default: (0.95, 0)) + weight_decay: weight decay (L2 penalty) (default: 0) + + Example: + >>> import torch_optimizer as optim + >>> optimizer = optim.Lion(model.parameters(), lr=0.001) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + """ + + def __init__( + self, + params, + lr: float = 1e-4, + betas = (0.9, 0.99), + weight_decay: float = 0.0, + ): + if lr <= 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) + if weight_decay < 0: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure = None): + r"""Performs a single optimization step. + + Arguments: + closure: A closure that reevaluates the model and returns the loss. 
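+
+        The update applied below is
+        ``p = p - lr * sign(beta1 * m + (1 - beta1) * g)`` after decoupled
+        weight decay, followed by the momentum update
+        ``m = beta2 * m + (1 - beta2) * g``.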
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group["lr"] * group["weight_decay"]) + + grad = p.grad + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) + + exp_avg = state["exp_avg"] + beta1, beta2 = group["betas"] + + # Weight update + update = exp_avg * beta1 + grad * (1 - beta1) + p.add_(torch.sign(update), alpha=-group["lr"]) + # Decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + return loss \ No newline at end of file From 76d66a9bcf8ff6282511bbc897b5122e42e87c2c Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:40:38 +0800 Subject: [PATCH 11/14] Update optimizers.py --- src/lmflow/optim/optimizers.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/lmflow/optim/optimizers.py b/src/lmflow/optim/optimizers.py index d27e0a9f3..2e40ae753 100644 --- a/src/lmflow/optim/optimizers.py +++ b/src/lmflow/optim/optimizers.py @@ -17,7 +17,17 @@ from lmflow.optim.adan import Adan from lmflow.optim.novograd import NovoGrad from lmflow.optim.adam import Adam +from lmflow.optim.adamw import AdamW +from lmflow.optim.came import CAME +from lmflow.optim.lion import Lion +from lmflow.optim.qhm import QHM +from lmflow.optim.qhadam import QHAdam +from lmflow.optim.rmsprop import RMSprop +from lmflow.optim.swats import SWATS +from lmflow.optim.adafactor import Adafactor +from lmflow.optim.sgd import SGD +from lmflow.optim.adahessian import Adahessian from lmflow.optim.adadelta import Adadelta from lmflow.optim.adagrad import AdaGrad from lmflow.optim.adamw_schedule_free import AdamWScheduleFree -from lmflow.optim.sgd_schedule_free import SGDScheduleFree \ No newline at end of file +from lmflow.optim.sgd_schedule_free import SGDScheduleFree From 27268f074851d403abba6ccf8d88343382ed9ebd Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:41:52 +0800 Subject: [PATCH 12/14] Update args.py --- src/lmflow/args.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/lmflow/args.py b/src/lmflow/args.py index 48cf913d4..78637e004 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -35,7 +35,9 @@ class OptimizerNames(): ADABOUND = "adabound" LARS = "lars" LAMB = "lamb" + LION = "lion" ADAMAX = "adamax" + ADAFACTOR = "adafactor" NADAM = "nadam" RADAM = "radam" ADAMP = "adamp" @@ -44,6 +46,14 @@ class OptimizerNames(): SOPHIA = "sophia" ADAN = "adan" ADAM = "adam" + ADAMW = "adamw" + ADAHESSIAN = "adahessian" + CAME = "came" + QHADAM = "qhadam" + QHM = "qhm" + SWATS = "swats" + SGD = "sgd" + RMSPROP = "rmsprop" NOVOGRAD = "novograd" ADADELTA = "adadelta" ADAGRAD = "adagrad" From 7f65cb7cd8a036f767536f718c37f372a1dda525 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:43:55 +0800 Subject: [PATCH 13/14] Update finetuner.py --- src/lmflow/pipeline/finetuner.py | 72 +++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/src/lmflow/pipeline/finetuner.py b/src/lmflow/pipeline/finetuner.py index de5eb6279..7b2d0ab1f 100644 --- a/src/lmflow/pipeline/finetuner.py +++ 
b/src/lmflow/pipeline/finetuner.py @@ -317,6 +317,76 @@ def get_optimizer_cls_and_kwargs( "betas": (args.optim_beta1, args.optim_beta2), } optimizer_kwargs.update(adam_kwargs) + elif args.customized_optim == OptimizerNames.RMSPROP: + optimizer_cls = optim.RMSprop + rmsprop_kwargs = { + "momentum": (args.optim_momentum), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(rmsprop_kwargs) + elif args.customized_optim == OptimizerNames.ADAHESSIAN: + optimizer_cls = optim.Adahessian + adahessian_kwargs = { + "betas": (args.optim_beta1, args.optim_beta2), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(adahessian_kwargs) + elif args.customized_optim == OptimizerNames.CAME: + optimizer_cls = optim.CAME + came_kwargs = { + "betas": (args.optim_beta1, args.optim_beta2, args.optim_beta3), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(came_kwargs) + elif args.customized_optim == OptimizerNames.ADAFACTOR: + optimizer_cls = optim.Adafactor + adafactor_kwargs = { + "betas": (args.optim_beta1, args.optim_beta2), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(adafactor_kwargs) + elif args.customized_optim == OptimizerNames.SWATS: + optimizer_cls = optim.SWATS + swats_kwargs = { + "betas": (args.optim_beta1, args.optim_beta2), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(swats_kwargs) + elif args.customized_optim == OptimizerNames.LION: + optimizer_cls = optim.Lion + lion_kwargs = { + "betas": (args.optim_beta1, args.optim_beta2), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(lion_kwargs) + elif args.customized_optim == OptimizerNames.QHADAM: + optimizer_cls = optim.QHAdam + qhadam_kwargs = { + "betas": (args.optim_beta1, args.optim_beta2), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(qhadam_kwargs) + elif args.customized_optim == OptimizerNames.QHM: + optimizer_cls = optim.QHM + qhm_kwargs = { + "momentum": (args.optim_momentum), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(qhm_kwargs) + elif args.customized_optim == OptimizerNames.ADAMW: + optimizer_cls = optim.AdamW + adamw_kwargs = { + "betas": (args.optim_beta1, args.optim_beta2), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(adamw_kwargs) + elif args.customized_optim == OptimizerNames.SGD: + optimizer_cls = optim.SGD + sgd_kwargs = { + "momentum": (args.optim_momentum), + "weight_decay": (args.optim_weight_decay), + } + optimizer_kwargs.update(sgd_kwargs) elif args.customized_optim == OptimizerNames.NOVOGRAD: optimizer_cls = optim.NovoGrad novograd_kwargs = { @@ -629,4 +699,4 @@ def switch_active_layers(self): else: trainer.create_model_card(**kwargs) - return model \ No newline at end of file + return model From 577eb94da593fce72fdefddd40356fddfdf20c27 Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:47:19 +0800 Subject: [PATCH 14/14] Update run_finetune_with_custom_optim.sh --- scripts/run_finetune_with_custom_optim.sh | 64 +++++++++++++++++++++-- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/scripts/run_finetune_with_custom_optim.sh b/scripts/run_finetune_with_custom_optim.sh index dadef185d..cda48c970 100644 --- a/scripts/run_finetune_with_custom_optim.sh +++ b/scripts/run_finetune_with_custom_optim.sh @@ -18,10 +18,8 @@ optim=dummy # Select an optimizer from the following options: # - 'adamw_torch' # - 
'adafactor' +# - 'lamb' # - 'sgd' -# - 'lion_8bit' -# - 'lion_32bit' -# - 'rmsprop' # Additional optimizers are shown below learning_rate=1e-5 lr_schedule=cosine @@ -187,6 +185,64 @@ elif [ "${optim}" == "lamb" ]; then optim_suffix_args+=" --optim_beta1 ${beta1}" optim_suffix_args+=" --optim_beta2 ${beta2}" optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "lion" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_beta1 ${beta1}" + optim_suffix_args+=" --optim_beta2 ${beta2}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "adamw" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_beta1 ${beta1}" + optim_suffix_args+=" --optim_beta2 ${beta2}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "adafactor" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_beta1 ${beta1}" + optim_suffix_args+=" --optim_beta2 ${beta2}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "came" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_beta1 ${beta1}" + optim_suffix_args+=" --optim_beta2 ${beta2}" + optim_suffix_args+=" --optim_beta3 ${beta3}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "qhadam" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_beta1 ${beta1}" + optim_suffix_args+=" --optim_beta2 ${beta2}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "adahessian" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_beta1 ${beta1}" + optim_suffix_args+=" --optim_beta2 ${beta2}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "swats" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_beta1 ${beta1}" + optim_suffix_args+=" --optim_beta2 ${beta2}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "qhm" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_momentum ${momentum}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "sgd" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_momentum ${momentum}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" +elif [ "${optim}" == "rmsprop" ]; then + optim_suffix_args="--use_customized_optim 1" + optim_suffix_args+=" --customized_optim ${optim}" + optim_suffix_args+=" --optim_momentum ${momentum}" + optim_suffix_args+=" --optim_weight_decay ${weight_decay}" elif [ "${optim}" == "adamax" ]; then optim_suffix_args="--use_customized_optim 1" optim_suffix_args+=" --customized_optim ${optim}" @@ -270,7 +326,7 @@ else fi # Finetune -exp_id=alpaca_${optim}_lr-${learning_rate}_beta1-${beta1}_beta2-${beta2}_lr-sched-${lr_schedule}_model-$(basename 
${model_name_or_path})_batch-size-${batch_size}x${gradient_accumulation_steps}_seed-${seed} +exp_id=alpaca_${optim}_lr-${learning_rate}_beta1-${beta1}_beta2-${beta2}_beta3-${beta3}_momentum-${momentum}_lr-sched-${lr_schedule}_model-$(basename ${model_name_or_path})_batch-size-${batch_size}x${gradient_accumulation_steps}_seed-${seed} echo "$(date): ${exp_id}..." tmp_dir=tmp