Commit

Revert "work"
This reverts commit 761d253.
kylematoba committed Sep 14, 2024
1 parent 761d253 commit 7b7ead9
Showing 5 changed files with 5 additions and 300 deletions.
1 change: 0 additions & 1 deletion lion_pytorch/__init__.py

This file was deleted.

95 changes: 0 additions & 95 deletions lion_pytorch/foreach.py

This file was deleted.

97 changes: 0 additions & 97 deletions lion_pytorch/lion_pytorch.py

This file was deleted.

98 changes: 0 additions & 98 deletions lion_pytorch/triton.py

This file was deleted.
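The four deleted modules match the layout of the lion-pytorch package (the Lion optimizer with plain, foreach, and Triton-backed step implementations). Their contents are collapsed in this diff, so the block below is a minimal sketch of the Lion update rule they implement (the step is the sign of an interpolated momentum, plus decoupled weight decay), written from the published rule rather than recovered from the deleted files:

import torch

class LionSketch(torch.optim.Optimizer):
    """Minimal Lion step for illustration; not the deleted lion_pytorch code."""

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay))

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            lr, (beta1, beta2), wd = group["lr"], group["betas"], group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                m = self.state[p].setdefault("exp_avg", torch.zeros_like(p))
                p.mul_(1 - lr * wd)                                  # decoupled weight decay
                update = (beta1 * m + (1 - beta1) * p.grad).sign()   # sign of interpolated momentum
                p.add_(update, alpha=-lr)
                m.mul_(beta2).add_(p.grad, alpha=1 - beta2)          # momentum tracked with beta2
        return loss

Because the update is a sign vector, the Lion paper suggests a smaller learning rate (and correspondingly larger weight decay) than one would use with AdamW.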

14 changes: 5 additions & 9 deletions src/nanotron/helpers.py
@@ -44,8 +44,6 @@
 from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
 from nanotron.serialize.metadata import TrainingMetadata
 
-from lion_pytorch import Lion
-
 logger = logging.get_logger(__name__)


@@ -329,7 +327,9 @@ def init_optimizer_and_grad_accumulator(
     # Basic optimizer builder
     def basic_optimizer_builder(named_param_groups):
         optimizer = None
+
         if optimizer_args.optimizer_factory.name == "adamW":
+
             def optimizer(param_groups):
                 return torch.optim.AdamW(
                     param_groups,
@@ -339,20 +339,16 @@ def optimizer(param_groups):
                     betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2),
                     fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
                 )
+
         elif optimizer_args.optimizer_factory.name == "sgd":
+
             def optimizer(param_groups):
                 return torch.optim.SGD(
                     param_groups,
                     lr=optimizer_args.learning_rate_scheduler.learning_rate,
                     weight_decay=optimizer_args.weight_decay,
                 )
-        elif optimizer_args.optimizer_factory.name == "lion":
-            def optimizer(param_groups):
-                return Lion(
-                    param_groups,
-                    lr=optimizer_args.learning_rate_scheduler.learning_rate,
-                    weight_decay=optimizer_args.weight_decay,
-                )
+
         else:
             raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported")

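With this revert, optimizer_factory.name only accepts "adamW" and "sgd"; a config that still requests "lion" now falls through to the ValueError. Below is a self-contained sketch of the same dispatch shape, using hypothetical names (make_optimizer_builder, name, lr, wd) rather than nanotron's actual arguments:

import torch

def make_optimizer_builder(name: str, lr: float, wd: float):
    # Mirrors the if/elif/else dispatch in basic_optimizer_builder above.
    if name == "adamW":
        return lambda param_groups: torch.optim.AdamW(param_groups, lr=lr, weight_decay=wd)
    elif name == "sgd":
        return lambda param_groups: torch.optim.SGD(param_groups, lr=lr, weight_decay=wd)
    # The "lion" branch removed by this commit returned a Lion instance here.
    else:
        raise ValueError(f"Optimizer {name} is not supported")

model = torch.nn.Linear(4, 4)
opt = make_optimizer_builder("adamW", lr=1e-3, wd=0.1)(model.parameters())  # builds
# make_optimizer_builder("lion", lr=1e-4, wd=0.1) now raises ValueError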
