Skip to content

Commit

Permalink
work
Browse files Browse the repository at this point in the history
  • Loading branch information
kylematoba committed Sep 14, 2024
1 parent 7b7ead9 commit 9eaf3d1
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 1 deletion.
1 change: 1 addition & 0 deletions lion_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from lion_pytorch.lion_pytorch import Lion
87 changes: 87 additions & 0 deletions lion_pytorch/foreach.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import annotations
from typing import Tuple, Callable, Union

import torch
from torch.optim.optimizer import Optimizer


def exists(val):
return val is not None


class Lion(Optimizer):
def __init__(
self,
params,
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
decoupled_weight_decay: bool = False
):
assert lr > 0.
assert all([0. <= beta <= 1. for beta in betas])
assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions'

self._init_lr = lr
self.decoupled_wd = decoupled_weight_decay

defaults = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay
)
super().__init__(params, defaults)

@torch.no_grad()
def step(
self,
closure: Union[Callable, None] = None
):

loss = None
if exists(closure):
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
lr = group['lr']
wd = group['weight_decay']

beta1, beta2 = group['betas']
decoupled_wd = self.decoupled_wd
init_lr = self._init_lr

# maybe decoupled weight decay
if decoupled_wd:
wd /= init_lr

# accumulate List[Tensor] for foreach inplace updates
params = []
grads = []
exp_avgs = []

for p in filter(lambda p: exists(p.grad), group['params']):
grad, state = p.grad, self.state[p]
# init state - exponential moving average of gradient values

if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p)

exp_avg = state['exp_avg']

params.append(p)
grads.append(grad)
exp_avgs.append(exp_avg)

# stepweight decay
torch._foreach_mul_(params, 1. - lr * wd)

# weight update
updates = [t.clone() for t in exp_avgs]
torch._foreach_lerp_(updates, grads, 1. - beta1)
torch._foreach_sign_(updates)
torch._foreach_add_(params, updates, alpha=-lr)

# decay momentum running average
torch._foreach_lerp_(exp_avgs, grads, 1. - beta2)
return loss
97 changes: 97 additions & 0 deletions lion_pytorch/lion_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations
from typing import Tuple, Callable, Union

import torch
from torch.optim.optimizer import Optimizer


def exists(val):
return val is not None


def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
# stepweight decay
p.data.mul_(1. - lr * wd)

# weight update
update = exp_avg.clone().mul_(beta1).add(grad, alpha=1.0 - beta1).sign_()
p.add_(update, alpha=-lr)

# decay the momentum running average coefficient
exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2)


class Lion(Optimizer):
def __init__(
self,
params,
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
use_triton: bool = False,
decoupled_weight_decay: bool = False,
):
assert lr > 0.
assert all([0. <= beta <= 1. for beta in betas])

self._init_lr = lr
self.decoupled_wd = decoupled_weight_decay

defaults = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay
)

super().__init__(params, defaults)
self.update_fn = update_fn

if use_triton:
from lion_pytorch.triton import update_fn as triton_update_fn
self.update_fn = triton_update_fn

@torch.no_grad()
def step(
self,
closure: Union[Callable, None] = None
):

loss = None
if exists(closure):
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in filter(lambda p: exists(p.grad), group['params']):

# grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr
grad = p.grad
lr = group['lr']
wd = group['weight_decay']
beta1, beta2 = group['betas']
state= self.state[p]
decoupled_wd = self.decoupled_wd
init_lr = self._init_lr

# maybe decoupled weight decay

if decoupled_wd:
wd /= init_lr

# init state - exponential moving average of gradient values
if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p)

exp_avg = state['exp_avg']

self.update_fn(
p,
grad,
exp_avg,
lr,
wd,
beta1,
beta2
)

return loss
98 changes: 98 additions & 0 deletions lion_pytorch/triton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import torch

try:
import triton
import triton.language as tl
except ImportError as e:
print('triton is not installed, please install by running `pip install triton>=2.2.0`')
exit()

# triton cuda kernel

@triton.autotune(configs = [
triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr'])
@triton.jit
def update_fn_kernel(
p_ptr,
grad_ptr,
exp_avg_ptr,
lr,
wd,
beta1,
beta2,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis = 0)

block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)

mask = offsets < n_elements

# offsetted pointers

offset_p_ptr = p_ptr + offsets
offset_grad_ptr = grad_ptr + offsets
offset_exp_avg_ptr = exp_avg_ptr + offsets

# load

p = tl.load(offset_p_ptr, mask = mask)
grad = tl.load(offset_grad_ptr, mask = mask)
exp_avg = tl.load(offset_exp_avg_ptr, mask = mask)

# stepweight decay

p = p * (1 - lr * wd)

# diff between momentum running average and grad

diff = exp_avg - grad

# weight update

update = diff * beta1 + grad

# torch.sign

can_update = update != 0
update_sign = tl.where(update > 0, -lr, lr)

p = p + update_sign * can_update

# decay the momentum running average coefficient

exp_avg = diff * beta2 + grad

# store new params and momentum running average coefficient

tl.store(offset_p_ptr, p, mask = mask)
tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)

def update_fn(
p: torch.Tensor,
grad: torch.Tensor,
exp_avg: torch.Tensor,
lr: float,
wd: float,
beta1: float,
beta2: float
):
assert all([t.is_cuda for t in (p, grad, exp_avg)])
n_elements = p.numel()

grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

update_fn_kernel[grid](
p,
grad,
exp_avg,
lr,
wd,
beta1,
beta2,
n_elements
)
10 changes: 9 additions & 1 deletion src/nanotron/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
from nanotron.serialize.metadata import TrainingMetadata

from lion_pytorch import Lion

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -341,14 +343,20 @@ def optimizer(param_groups):
)

elif optimizer_args.optimizer_factory.name == "sgd":

def optimizer(param_groups):
return torch.optim.SGD(
param_groups,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
weight_decay=optimizer_args.weight_decay,
)

elif optimizer_args.optimizer_factory.name == "lion":
def optimizer(param_groups):
return Lion(
param_groups,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
weight_decay=optimizer_args.weight_decay,
)
else:
raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported")

Expand Down

0 comments on commit 9eaf3d1

Please sign in to comment.