Add early stop module #301

Merged Jan 25, 2025 · 31 commits · Changes from 22 commits

Commits
d0c454f  add early stopper (sanaAyrml, Dec 5, 2024)
8938bf5  Merge branch 'main' into sa_early_stop (sanaAyrml, Dec 5, 2024)
6136a90  test smoke tests (sanaAyrml, Dec 5, 2024)
6f20445  Separate early_stopper and snapshotter (sanaAyrml, Dec 5, 2024)
52546e9  Temporary commit (sanaAyrml, Jan 2, 2025)
fe610db  Merge branch 'sa_early_stop' of https://github.com/VectorInstitute/FL… (sanaAyrml, Jan 2, 2025)
92fe751  Firx extra early stopper implementation (sanaAyrml, Jan 2, 2025)
c73d04c  Merge branch 'main' into sa_early_stop (sanaAyrml, Jan 2, 2025)
fe57eea  add seriazable snapshotter tests (sanaAyrml, Jan 2, 2025)
2d939b3  [pre-commit.ci] Add auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 2, 2025)
5627c72  Merge branch 'main' into sa_early_stop (sanaAyrml, Jan 8, 2025)
311de1c  add metric manager snappshotter test (sanaAyrml, Jan 8, 2025)
0471905  Merge branch 'main' into sa_early_stop (sanaAyrml, Jan 8, 2025)
27bf361  Resolve conflict with main (sanaAyrml, Jan 8, 2025)
3142889  Merge branch 'main' into sa_early_stop (sanaAyrml, Jan 8, 2025)
f4ad580  update precommit type changes (sanaAyrml, Jan 8, 2025)
9d09429  Merge branch 'main' into sa_early_stop (sanaAyrml, Jan 8, 2025)
00320fc  add snappshotter other tests (sanaAyrml, Jan 9, 2025)
5c48474  Add docstrings (sanaAyrml, Jan 9, 2025)
bf2504d  Add docstring (sanaAyrml, Jan 9, 2025)
476343e  Update doc strings (sanaAyrml, Jan 9, 2025)
dd3e964  Ignoring a vulnerability without a fix yet (emersodb, Jan 9, 2025)
080fd90  Address review comments (sanaAyrml, Jan 23, 2025)
a9fe020  Merge branch 'main' into sa_early_stop (sanaAyrml, Jan 23, 2025)
0a4d6b8  Update memory issues (sanaAyrml, Jan 23, 2025)
fda2946  Add str to logging mode (sanaAyrml, Jan 23, 2025)
14da12d  Address some comments (sanaAyrml, Jan 24, 2025)
a1aecfd  [pre-commit.ci] Add auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 24, 2025)
a0a8128  Bring some checks into should_stop function of early_stopper (sanaAyrml, Jan 24, 2025)
da842cd  Merge branch 'main' into sa_early_stop (emersodb, Jan 24, 2025)
171fcc1  Small typo fix (emersodb, Jan 24, 2025)
2 changes: 2 additions & 0 deletions .github/workflows/static_code_checks.yaml
@@ -43,7 +43,9 @@ jobs:
virtual-environment: .venv/
# Ignoring vulnerability in cryptography
# Fix is 43.0.1 but flwr 1.9 depends on < 43
# GHSA-cjgq-5qmw-rcj6 is a Keras vulnerability that has no fix yet
ignore-vulns: |
GHSA-h4gh-qq45-vh27
GHSA-q34m-jh98-gwm2
GHSA-f9vj-2wh5-fj8j
GHSA-cjgq-5qmw-rcj6
38 changes: 34 additions & 4 deletions fl4health/clients/basic_client.py
@@ -11,7 +11,7 @@
from flwr.common.typing import Config, NDArrays, Scalar
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule
@@ -159,6 +159,9 @@ def get_parameters(self, config: Config) -> NDArrays:
            # Need all parameters even if normally exchanging partial
            return FullParameterExchanger().push_parameters(self.model, config=config)
        else:
            if hasattr(self, "early_stopper") and self.early_stopper.patience == 0:
                log(INFO, "Loading saved best model's state before sending model to server.")
                self.early_stopper.load_snapshot(["model"])
            assert self.model is not None and self.parameter_exchanger is not None
            return self.parameter_exchanger.push_parameters(self.model, config=config)

@@ -641,6 +644,10 @@ def train_by_epochs(
                self.reports_manager.report(report_data, current_round, self.total_epochs, self.total_steps)
                self.total_steps += 1
                steps_this_round += 1
                if hasattr(self, "early_stopper"):
                    if self.total_steps % self.early_stopper.interval_steps == 0 and self.early_stopper.should_stop():
                        log(INFO, "Early stopping criterion met. Stopping training.")
                        break

            # Log and report results
            metrics = self.train_metric_manager.compute()
@@ -709,6 +716,10 @@ def train_by_steps(
            report_data.update(self.get_client_specific_reports())
            self.reports_manager.report(report_data, current_round, None, self.total_steps)
            self.total_steps += 1
            if hasattr(self, "early_stopper"):
                if self.total_steps % self.early_stopper.interval_steps == 0 and self.early_stopper.should_stop():
                    log(INFO, "Early stopping criterion met. Stopping training.")
                    break

        loss_dict = self.train_loss_meter.compute().as_dict()
        metrics = self.train_metric_manager.compute()
@@ -879,9 +890,28 @@ def setup_client(self, config: Config) -> None:
        self.parameter_exchanger = self.get_parameter_exchanger(config)

        self.reports_manager.report({"host_type": "client", "initialized": str(datetime.datetime.now())})

        try:
            self.set_early_stopper()
        except NotImplementedError:
            log(
                INFO,
                "Early stopping not implemented for this client. "
                "Override set_early_stopper to activate early stopper.",
            )
        self.initialized = True

    def set_early_stopper(self) -> None:
        """
        User-defined method that sets the early stopper for the client. An override must set
        self.early_stopper to an instance of EarlyStopper, which is defined in
        fl4health.utils.early_stopper. Example implementation:

        from fl4health.utils.early_stopper import EarlyStopper
        self.early_stopper = EarlyStopper(client=self, patience=3, interval_steps=100)
        """
        raise NotImplementedError
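
For illustration, a minimal sketch of such an override in a hypothetical subclass (the subclass name and the patience/interval values here are assumptions for the example, not part of this diff):

from fl4health.utils.early_stopper import EarlyStopper

class MyClient(BasicClient):
    def set_early_stopper(self) -> None:
        # Check the validation loss every 100 training steps and stop after
        # 3 consecutive checks without improvement.
        self.early_stopper = EarlyStopper(client=self, patience=3, interval_steps=100)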

    def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
        """
        Returns Full Parameter Exchangers. Subclasses that require custom Parameter Exchangers can override this.
@@ -1113,7 +1143,7 @@ def get_model(self, config: Config) -> nn.Module:
"""
raise NotImplementedError

-    def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler | None:
+    def get_lr_scheduler(self, optimizer_key: str, config: Config) -> LRScheduler | None:
"""
Optional user defined method that returns learning rate scheduler
to be used throughout training for the given optimizer. Defaults to None.
@@ -1125,7 +1155,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler |
            config (Config): The config from the server.

        Returns:
-            _LRScheduler | None: Client learning rate schedulers.
+            LRScheduler | None: Client learning rate schedulers.
        """
        return None
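
As a sketch, an override could look like the following; StepLR and its parameters are illustrative assumptions, and the optimizer is assumed to live in the client's optimizers dict under optimizer_key:

from torch.optim.lr_scheduler import LRScheduler, StepLR

def get_lr_scheduler(self, optimizer_key: str, config: Config) -> LRScheduler | None:
    # Decay the learning rate of the selected optimizer by 10x every 30 steps (illustrative schedule).
    return StepLR(self.optimizers[optimizer_key], step_size=30, gamma=0.1)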

162 changes: 162 additions & 0 deletions fl4health/utils/early_stopper.py
@@ -0,0 +1,162 @@
from collections.abc import Callable
from logging import INFO
from pathlib import Path
from typing import Any

import torch.nn as nn
from flwr.common.logger import log
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
from fl4health.clients.basic_client import BasicClient
from fl4health.reporting.reports_manager import ReportsManager
from fl4health.utils.logging import LoggingMode
from fl4health.utils.losses import TrainingLosses
from fl4health.utils.metrics import MetricManager
from fl4health.utils.snapshotter import (
    LRSchedulerSnapshotter,
    NumberSnapshotter,
    OptimizerSnapshotter,
    SerizableObjectSnapshotter,
    Snapshotter,
    T,
    TorchModuleSnapshotter,
)


class EarlyStopper:
    def __init__(
        self,
        client: BasicClient,
        patience: int = 0,
        interval_steps: int = 5,
        snapshot_dir: Path | None = None,
    ) -> None:
"""
Early stopping class is an plugin for the client that allows to stop local training based on the validation
sanaAyrml marked this conversation as resolved.
Show resolved Hide resolved
loss. At each training step this class saves the best state of the client and restores it if the client is
stopped. If the client starts to overfit, the early stopper will stop the training process and restore the best
state of the client before sending the model to the server.

        Args:
            client (BasicClient): The client to be monitored.
            patience (int, optional): Number of consecutive validation checks without improvement to
                tolerate before stopping training. If set to 0, the client never stops early, but it still
                loads the best state before sending the model to the server. Defaults to 0.
            interval_steps (int, optional): Determines how often, in training steps, the early stopper
                checks the validation loss. Defaults to 5.
            snapshot_dir (Path | None, optional): Rather than keeping the best state in memory, it can be
                checkpointed to this directory. If None, the best state is kept in memory. Defaults to None.
        """

        self.client = client

        self.patience = patience
        self.count_down = patience
        self.interval_steps = interval_steps

        self.best_score: float | None = None
        self.snapshot_ckpt: dict[str, Any] = {}

        self.default_snapshot_attrs: dict = {
Collaborator: I think we can be more specific with the dict type here. At the very least, I think we can annotate it as dict[str, tuple[Any, type[T]]]. It's possible that we can also do dict[str, tuple[Snapshotter[T], type[T]]] but I'm less certain that will come out right.
Collaborator: Is T generic?
Collaborator Author: Yes, T is generic, and because of that I actually get an error even when I write tuple[Any, type[T]].
Collaborator Author: I can make it more specific with tuple[Snapshotter, Any].
Collaborator: Interesting. I'm good with the slightly stricter typing. Certain things with generics are always trickier.
"model": (TorchModuleSnapshotter(self.client), nn.Module),
"optimizers": (OptimizerSnapshotter(self.client), Optimizer),
"lr_schedulers": (
LRSchedulerSnapshotter(self.client),
LRScheduler,
),
"learning_rate": (NumberSnapshotter(self.client), float),
"total_steps": (NumberSnapshotter(self.client), int),
"total_epochs": (NumberSnapshotter(self.client), int),
"reports_manager": (
SerizableObjectSnapshotter(self.client),
ReportsManager,
),
"train_loss_meter": (
SerizableObjectSnapshotter(self.client),
TrainingLosses,
),
"train_metric_manager": (
SerizableObjectSnapshotter(self.client),
MetricManager,
),
}

        # Fall back to keeping snapshots in memory when no snapshot_dir is given.
        self.checkpointer = PerRoundStateCheckpointer(snapshot_dir) if snapshot_dir is not None else None

    def add_default_snapshot_attr(
        self, name: str, snapshot_class: Callable[[BasicClient], Snapshotter], input_type: type[T]
    ) -> None:
        """Register an additional client attribute to be captured in snapshots."""
        self.default_snapshot_attrs.update({name: (snapshot_class(self.client), input_type)})

    def delete_default_snapshot_attr(self, name: str) -> None:
        """Remove an attribute from the set of snapshotted attributes."""
        del self.default_snapshot_attrs[name]

    def save_snapshot(self) -> None:
        """
        Creates a snapshot of the client state and, if a checkpointer is configured, saves it to disk;
        otherwise the snapshot is kept in memory in self.snapshot_ckpt.
        """
        for attr, (snapshotter_function, expected_type) in self.default_snapshot_attrs.items():
            self.snapshot_ckpt.update(snapshotter_function.save(attr, expected_type))

        if self.checkpointer is not None:
            self.checkpointer.save_checkpoint(f"temp_{self.client.client_name}.pt", self.snapshot_ckpt)
            self.snapshot_ckpt.clear()

            log(
                INFO,
                f"Saving client best state to checkpoint at {self.checkpointer.checkpoint_dir} "
                f"with name temp_{self.client.client_name}.pt",
            )

    def load_snapshot(self, attrs: list[str]) -> None:
        """
        Load the snapshotted state for the given attributes from the checkpoint (or from memory).

        Args:
            attrs (list[str]): List of attributes to load from the snapshot.
        """
        checkpoint_name = f"temp_{self.client.client_name}.pt"
        assert (
            self.checkpointer is not None and self.checkpointer.checkpoint_exists(checkpoint_name)
        ) or self.snapshot_ckpt != {}, "No checkpoint to load"

        if self.checkpointer is not None and self.checkpointer.checkpoint_exists(checkpoint_name):
            self.snapshot_ckpt = self.checkpointer.load_checkpoint(checkpoint_name)

        for attr in attrs:
            snapshotter_function, expected_type = self.default_snapshot_attrs[attr]
            snapshotter_function.load(self.snapshot_ckpt, attr, expected_type)

    def should_stop(self) -> bool:
        """
        Determine if the client should stop training based on the early stopping criterion.

        Returns:
            bool: True if training should stop, otherwise False.
        """
        val_loss, _ = self.client._validate_or_test(
            loader=self.client.val_loader,
            loss_meter=self.client.val_loss_meter,
            metric_manager=self.client.val_metric_manager,
            logging_mode=LoggingMode.EARLY_STOP_VALIDATION,
            include_losses_in_metrics=False,
        )

        if val_loss is None:
            return False

        if self.best_score is None or val_loss < self.best_score:
            # New best validation loss: reset the patience countdown and snapshot this state.
            self.best_score = val_loss
            self.count_down = self.patience
            self.save_snapshot()
            return False

        self.count_down -= 1
        if self.count_down == 0:
            # Patience exhausted: restore the best state and signal the client to stop.
            self.load_snapshot(list(self.default_snapshot_attrs.keys()))
            return True

        return False
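
To make the snapshot API above concrete, a minimal usage sketch follows; my_client and the "num_checks" attribute are hypothetical, and only methods defined in this file are called:

from fl4health.utils.early_stopper import EarlyStopper
from fl4health.utils.snapshotter import NumberSnapshotter

early_stopper = EarlyStopper(client=my_client, patience=3, interval_steps=100)
# Track an extra numeric client attribute alongside the default snapshot set.
early_stopper.add_default_snapshot_attr("num_checks", NumberSnapshotter, int)
# Drop an attribute that does not need to be restored on early stop.
early_stopper.delete_default_snapshot_attr("reports_manager")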
1 change: 1 addition & 0 deletions fl4health/utils/logging.py
@@ -3,5 +3,6 @@

class LoggingMode(Enum):
    TRAIN = "Training"
    EARLY_STOP_VALIDATION = "Early_Stop_Validation"
    VALIDATION = "Validation"
    TEST = "Testing"