From 9eaf3d1211629810623b1bd8f6a41999b401bccd Mon Sep 17 00:00:00 2001
From: Kyle Matoba <22180455+kylematoba@users.noreply.github.com>
Date: Sat, 14 Sep 2024 23:01:03 +0200
Subject: [PATCH] Add Lion optimizer and wire it into the optimizer factory

Vendor the Lion optimizer (reference, foreach, and Triton implementations)
as a lion_pytorch package and add a "lion" option to the optimizer factory
in src/nanotron/helpers.py.
---
 lion_pytorch/__init__.py     |  1 +
 lion_pytorch/foreach.py      | 87 ++++++++++++++++++++++++++++++++
 lion_pytorch/lion_pytorch.py | 97 +++++++++++++++++++++++++++++++++++
 lion_pytorch/triton.py       | 98 ++++++++++++++++++++++++++++++++++++
 src/nanotron/helpers.py      | 10 +++-
 5 files changed, 292 insertions(+), 1 deletion(-)
 create mode 100644 lion_pytorch/__init__.py
 create mode 100644 lion_pytorch/foreach.py
 create mode 100644 lion_pytorch/lion_pytorch.py
 create mode 100644 lion_pytorch/triton.py

diff --git a/lion_pytorch/__init__.py b/lion_pytorch/__init__.py
new file mode 100644
index 00000000..b3a7799d
--- /dev/null
+++ b/lion_pytorch/__init__.py
@@ -0,0 +1 @@
+from lion_pytorch.lion_pytorch import Lion
diff --git a/lion_pytorch/foreach.py b/lion_pytorch/foreach.py
new file mode 100644
index 00000000..746952f0
--- /dev/null
+++ b/lion_pytorch/foreach.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+from typing import Tuple, Callable, Union
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+def exists(val):
+    return val is not None
+
+
+class Lion(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-4,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0.0,
+        decoupled_weight_decay: bool = False
+    ):
+        assert lr > 0.
+        assert all([0. <= beta <= 1. for beta in betas])
+        assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions'
+
+        self._init_lr = lr
+        self.decoupled_wd = decoupled_weight_decay
+
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            weight_decay=weight_decay
+        )
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(
+        self,
+        closure: Union[Callable, None] = None
+    ):
+
+        loss = None
+        if exists(closure):
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            lr = group['lr']
+            wd = group['weight_decay']
+
+            beta1, beta2 = group['betas']
+            decoupled_wd = self.decoupled_wd
+            init_lr = self._init_lr
+
+            # maybe decoupled weight decay
+            if decoupled_wd:
+                wd /= init_lr
+
+            # accumulate List[Tensor] for foreach inplace updates
+            params = []
+            grads = []
+            exp_avgs = []
+
+            for p in filter(lambda p: exists(p.grad), group['params']):
+                grad, state = p.grad, self.state[p]
+                # init state - exponential moving average of gradient values
+
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p)
+
+                exp_avg = state['exp_avg']
+
+                params.append(p)
+                grads.append(grad)
+                exp_avgs.append(exp_avg)
+
+            # stepweight decay
+            torch._foreach_mul_(params, 1. - lr * wd)
+
+            # weight update
+            updates = [t.clone() for t in exp_avgs]
+            torch._foreach_lerp_(updates, grads, 1. - beta1)
+            torch._foreach_sign_(updates)
+            torch._foreach_add_(params, updates, alpha=-lr)
+
+            # decay momentum running average
+            torch._foreach_lerp_(exp_avgs, grads, 1. - beta2)
- beta2) + return loss diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py new file mode 100644 index 00000000..b0d3a3f8 --- /dev/null +++ b/lion_pytorch/lion_pytorch.py @@ -0,0 +1,97 @@ +from __future__ import annotations +from typing import Tuple, Callable, Union + +import torch +from torch.optim.optimizer import Optimizer + + +def exists(val): + return val is not None + + +def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): + # stepweight decay + p.data.mul_(1. - lr * wd) + + # weight update + update = exp_avg.clone().mul_(beta1).add(grad, alpha=1.0 - beta1).sign_() + p.add_(update, alpha=-lr) + + # decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2) + + +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + use_triton: bool = False, + decoupled_weight_decay: bool = False, + ): + assert lr > 0. + assert all([0. <= beta <= 1. for beta in betas]) + + self._init_lr = lr + self.decoupled_wd = decoupled_weight_decay + + defaults = dict( + lr=lr, + betas=betas, + weight_decay=weight_decay + ) + + super().__init__(params, defaults) + self.update_fn = update_fn + + if use_triton: + from lion_pytorch.triton import update_fn as triton_update_fn + self.update_fn = triton_update_fn + + @torch.no_grad() + def step( + self, + closure: Union[Callable, None] = None + ): + + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in filter(lambda p: exists(p.grad), group['params']): + + # grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr + grad = p.grad + lr = group['lr'] + wd = group['weight_decay'] + beta1, beta2 = group['betas'] + state= self.state[p] + decoupled_wd = self.decoupled_wd + init_lr = self._init_lr + + # maybe decoupled weight decay + + if decoupled_wd: + wd /= init_lr + + # init state - exponential moving average of gradient values + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + + self.update_fn( + p, + grad, + exp_avg, + lr, + wd, + beta1, + beta2 + ) + + return loss diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py new file mode 100644 index 00000000..1dd4696b --- /dev/null +++ b/lion_pytorch/triton.py @@ -0,0 +1,98 @@ +import torch + +try: + import triton + import triton.language as tl +except ImportError as e: + print('triton is not installed, please install by running `pip install triton>=2.2.0`') + exit() + +# triton cuda kernel + +@triton.autotune(configs = [ + triton.Config({'BLOCK_SIZE': 128}, num_warps = 4), + triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8), +], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr']) +@triton.jit +def update_fn_kernel( + p_ptr, + grad_ptr, + exp_avg_ptr, + lr, + wd, + beta1, + beta2, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis = 0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + # offsetted pointers + + offset_p_ptr = p_ptr + offsets + offset_grad_ptr = grad_ptr + offsets + offset_exp_avg_ptr = exp_avg_ptr + offsets + + # load + + p = tl.load(offset_p_ptr, mask = mask) + grad = tl.load(offset_grad_ptr, mask = mask) + exp_avg = tl.load(offset_exp_avg_ptr, mask = mask) + + # stepweight decay + + p = p * (1 - lr * wd) + + # diff 
+
+    diff = exp_avg - grad
+
+    # weight update
+
+    update = diff * beta1 + grad
+
+    # torch.sign
+
+    can_update = update != 0
+    update_sign = tl.where(update > 0, -lr, lr)
+
+    p = p + update_sign * can_update
+
+    # decay the momentum running average coefficient
+
+    exp_avg = diff * beta2 + grad
+
+    # store new params and momentum running average coefficient
+
+    tl.store(offset_p_ptr, p, mask = mask)
+    tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
+
+def update_fn(
+    p: torch.Tensor,
+    grad: torch.Tensor,
+    exp_avg: torch.Tensor,
+    lr: float,
+    wd: float,
+    beta1: float,
+    beta2: float
+):
+    assert all([t.is_cuda for t in (p, grad, exp_avg)])
+    n_elements = p.numel()
+
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+
+    update_fn_kernel[grid](
+        p,
+        grad,
+        exp_avg,
+        lr,
+        wd,
+        beta1,
+        beta2,
+        n_elements
+    )
diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index a82f0294..f52fd9b2 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -44,6 +44,8 @@
 from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
 from nanotron.serialize.metadata import TrainingMetadata
 
+from lion_pytorch import Lion
+
 logger = logging.get_logger(__name__)
 
 
@@ -341,7 +343,6 @@ def optimizer(param_groups):
             )
 
     elif optimizer_args.optimizer_factory.name == "sgd":
-
         def optimizer(param_groups):
             return torch.optim.SGD(
                 param_groups,
@@ -349,6 +350,13 @@ def optimizer(param_groups):
                 weight_decay=optimizer_args.weight_decay,
             )
 
+    elif optimizer_args.optimizer_factory.name == "lion":
+        def optimizer(param_groups):
+            return Lion(
+                param_groups,
+                lr=optimizer_args.learning_rate_scheduler.learning_rate,
+                weight_decay=optimizer_args.weight_decay,
+            )
     else:
         raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported")
 
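
Note (not part of the patch): a minimal, untested usage sketch of the vendored
optimizer exercised directly, outside the nanotron factory. The toy model, data,
and hyperparameters below are illustrative assumptions, not anything introduced
by this change.

import torch
from lion_pytorch import Lion

# toy model; Lion is typically run with a smaller lr (and often a larger
# weight decay) than AdamW would use for the same task
model = torch.nn.Linear(16, 4)
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)

for _ in range(3):
    x = torch.randn(8, 16)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()       # applies the sign-based Lion update
    optimizer.zero_grad()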