From 47bafdf657f231bea35a4514d2470b890a2c2824 Mon Sep 17 00:00:00 2001 From: JannikWirtz Date: Sun, 1 Oct 2023 15:16:18 +0200 Subject: [PATCH 1/5] Add optional 'period' parameter to cosine_schedule --- lightly/utils/scheduler.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/lightly/utils/scheduler.py b/lightly/utils/scheduler.py index c454f757a..5c6c48415 100644 --- a/lightly/utils/scheduler.py +++ b/lightly/utils/scheduler.py @@ -1,11 +1,16 @@ import warnings +from typing import Optional import numpy as np import torch def cosine_schedule( - step: int, max_steps: int, start_value: float, end_value: float + step: int, + max_steps: int, + start_value: float, + end_value: float, + period: Optional[int] = None, ) -> float: """Use cosine decay to gradually modify start_value to reach target end_value during iterations. @@ -19,6 +24,9 @@ def cosine_schedule( Starting value. end_value: Target value. + period (optional): + The number of steps over which the cosine function completes its half-cycle. + If not provided, it defaults to max_steps. Returns: Cosine decay value. @@ -28,7 +36,7 @@ def cosine_schedule( raise ValueError("Current step number can't be negative") if max_steps < 1: raise ValueError("Total step number must be >= 1") - if step > max_steps: + if not period and step > max_steps: warnings.warn( f"Current step number {step} exceeds max_steps {max_steps}.", category=RuntimeWarning, @@ -37,6 +45,11 @@ def cosine_schedule( if max_steps == 1: # Avoid division by zero decay = end_value + elif period: # "cycle" based on period, if provided + decay = ( + end_value + - (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2 + ) elif step == max_steps: # Special case for Pytorch Lightning which updates LR scheduler also for epoch # after last training epoch. 
From 8bfc3cde82d0bf3a17a3ac2c98da7ecf306f2e86 Mon Sep 17 00:00:00 2001 From: JannikWirtz Date: Mon, 2 Oct 2023 16:57:31 +0200 Subject: [PATCH 2/5] Address Feedback --- lightly/utils/scheduler.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/lightly/utils/scheduler.py b/lightly/utils/scheduler.py index 5c6c48415..5a9fc6d99 100644 --- a/lightly/utils/scheduler.py +++ b/lightly/utils/scheduler.py @@ -25,7 +25,7 @@ def cosine_schedule( end_value: Target value. period (optional): - The number of steps over which the cosine function completes its half-cycle. + The number of steps over which the cosine function completes a full cycle. If not provided, it defaults to max_steps. Returns: @@ -36,20 +36,24 @@ def cosine_schedule( raise ValueError("Current step number can't be negative") if max_steps < 1: raise ValueError("Total step number must be >= 1") - if not period and step > max_steps: + if period is None and step > max_steps: warnings.warn( f"Current step number {step} exceeds max_steps {max_steps}.", category=RuntimeWarning, ) + if period is not None and period <= 0: + raise ValueError("Period must be >= 1") - if max_steps == 1: - # Avoid division by zero - decay = end_value - elif period: # "cycle" based on period, if provided + if period is not None: # "cycle" based on period, if provided decay = ( end_value - - (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2 + - (end_value - start_value) + * (np.cos(2 * np.pi * step / period) + 1) + / 2 ) + elif max_steps == 1: + # Avoid division by zero + decay = end_value elif step == max_steps: # Special case for Pytorch Lightning which updates LR scheduler also for epoch # after last training epoch. 
From 51ed466c9765eac3fc8202c99db891fc54466344 Mon Sep 17 00:00:00 2001 From: JannikWirtz Date: Mon, 2 Oct 2023 18:08:05 +0200 Subject: [PATCH 3/5] Add test_cosine_schedule__period --- tests/utils/test_scheduler.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/utils/test_scheduler.py b/tests/utils/test_scheduler.py index 00d4bf4b4..7beb6754f 100644 --- a/tests/utils/test_scheduler.py +++ b/tests/utils/test_scheduler.py @@ -23,6 +23,14 @@ def test_cosine_schedule(self) -> None: ): cosine_schedule(11, 10, 0.0, 1.0) + def test_cosine_schedule__period(self) -> None: + self.assertAlmostEqual(cosine_schedule(0, 1, 0, 1.0, period=10), 0.0, 6) + self.assertAlmostEqual(cosine_schedule(3, 1, 0, 2.0, period=10), 1.30901706, 6) + self.assertAlmostEqual(cosine_schedule(10, 1, 0, 1.0, period=10), 0.0, 6) + self.assertAlmostEqual(cosine_schedule(15, 1, 0, 1.0, period=10), 1.0, 6) + with self.assertRaises(ValueError): + cosine_schedule(1, 10, 0.0, 1.0, period=-1) + def test_CosineWarmupScheduler(self) -> None: model = nn.Linear(10, 1) optimizer = torch.optim.SGD( From 4501ca981e9efb4cdb4b4750a836b975a0f43c26 Mon Sep 17 00:00:00 2001 From: JannikWirtz Date: Mon, 2 Oct 2023 18:14:36 +0200 Subject: [PATCH 4/5] Add 'period' parameter to CosineWarmupScheduler --- lightly/utils/scheduler.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/lightly/utils/scheduler.py b/lightly/utils/scheduler.py index 5a9fc6d99..07e4fd2c2 100644 --- a/lightly/utils/scheduler.py +++ b/lightly/utils/scheduler.py @@ -47,9 +47,7 @@ def cosine_schedule( if period is not None: # "cycle" based on period, if provided decay = ( end_value - - (end_value - start_value) - * (np.cos(2 * np.pi * step / period) + 1) - / 2 + - (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2 ) elif max_steps == 1: # Avoid division by zero @@ -65,7 +63,7 @@ def cosine_schedule( * (np.cos(np.pi * step / (max_steps - 1)) + 1) / 2 ) - return decay + return 
float(decay) class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR): @@ -100,12 +98,14 @@ def __init__( last_epoch: int = -1, start_value: float = 1.0, end_value: float = 0.001, + period: Optional[int] = None, verbose: bool = False, ) -> None: self.warmup_epochs = warmup_epochs self.max_epochs = max_epochs self.start_value = start_value self.end_value = end_value + self.period = period super().__init__( optimizer=optimizer, lr_lambda=self.scale_lr, @@ -127,6 +127,14 @@ def scale_lr(self, epoch: int) -> float: """ if epoch < self.warmup_epochs: return self.start_value * (epoch + 1) / self.warmup_epochs + elif self.period is not None: + return cosine_schedule( + step=epoch - self.warmup_epochs, + max_steps=1, + start_value=self.start_value, + end_value=self.end_value, + period=self.period, + ) else: return cosine_schedule( step=epoch - self.warmup_epochs, From 40c80d566908528413347180adc794938c8fe996 Mon Sep 17 00:00:00 2001 From: guarin Date: Tue, 3 Oct 2023 11:33:51 +0000 Subject: [PATCH 5/5] Fix decay typing --- lightly/utils/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightly/utils/scheduler.py b/lightly/utils/scheduler.py index 07e4fd2c2..a698e09ac 100644 --- a/lightly/utils/scheduler.py +++ b/lightly/utils/scheduler.py @@ -44,6 +44,7 @@ def cosine_schedule( if period is not None and period <= 0: raise ValueError("Period must be >= 1") + decay: float if period is not None: # "cycle" based on period, if provided decay = ( end_value @@ -63,7 +64,7 @@ def cosine_schedule( * (np.cos(np.pi * step / (max_steps - 1)) + 1) / 2 ) - return float(decay) + return decay class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR):