Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simCLR with temperature schedule #1413

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions lightly/utils/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import warnings
from typing import Optional

import numpy as np
import torch


def cosine_schedule(
step: int, max_steps: int, start_value: float, end_value: float
step: int,
max_steps: int,
start_value: float,
end_value: float,
period: Optional[int] = None,
) -> float:
"""Use cosine decay to gradually modify start_value to reach target end_value during
iterations.
Expand All @@ -19,6 +24,9 @@ def cosine_schedule(
Starting value.
end_value:
Target value.
period (optional):
The number of steps over which the cosine function completes a full cycle.
If not provided, it defaults to max_steps.

Returns:
Cosine decay value.
Expand All @@ -28,13 +36,21 @@ def cosine_schedule(
raise ValueError("Current step number can't be negative")
if max_steps < 1:
raise ValueError("Total step number must be >= 1")
if step > max_steps:
if period is None and step > max_steps:
warnings.warn(
f"Current step number {step} exceeds max_steps {max_steps}.",
category=RuntimeWarning,
)
if period is not None and period <= 0:
raise ValueError("Period must be >= 1")

if max_steps == 1:
decay: float
if period is not None: # "cycle" based on period, if provided
decay = (
end_value
- (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2
)
elif max_steps == 1:
# Avoid division by zero
decay = end_value
elif step == max_steps:
Expand Down Expand Up @@ -83,12 +99,14 @@ def __init__(
last_epoch: int = -1,
start_value: float = 1.0,
end_value: float = 0.001,
period: Optional[int] = None,
verbose: bool = False,
) -> None:
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs
self.start_value = start_value
self.end_value = end_value
self.period = period
super().__init__(
optimizer=optimizer,
lr_lambda=self.scale_lr,
Expand All @@ -110,6 +128,14 @@ def scale_lr(self, epoch: int) -> float:
"""
if epoch < self.warmup_epochs:
return self.start_value * (epoch + 1) / self.warmup_epochs
elif self.period is not None:
return cosine_schedule(
step=epoch - self.warmup_epochs,
max_steps=1,
start_value=self.start_value,
end_value=self.end_value,
period=self.period,
)
else:
return cosine_schedule(
step=epoch - self.warmup_epochs,
Expand Down
8 changes: 8 additions & 0 deletions tests/utils/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def test_cosine_schedule(self) -> None:
):
cosine_schedule(11, 10, 0.0, 1.0)

def test_cosine_schedule__period(self) -> None:
self.assertAlmostEqual(cosine_schedule(0, 1, 0, 1.0, period=10), 0.0, 6)
self.assertAlmostEqual(cosine_schedule(3, 1, 0, 2.0, period=10), 1.30901706, 6)
self.assertAlmostEqual(cosine_schedule(10, 1, 0, 1.0, period=10), 0.0, 6)
self.assertAlmostEqual(cosine_schedule(15, 1, 0, 1.0, period=10), 1.0, 6)
with self.assertRaises(ValueError):
cosine_schedule(1, 10, 0.0, 1.0, period=-1)

def test_CosineWarmupScheduler(self) -> None:
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(
Expand Down