Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add early stop module #301

Merged
merged 31 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d0c454f
add early stopper
sanaAyrml Dec 5, 2024
8938bf5
Merge branch 'main' into sa_early_stop
sanaAyrml Dec 5, 2024
6136a90
test smoke tests
sanaAyrml Dec 5, 2024
6f20445
Separate early_stopper and snapshotter
sanaAyrml Dec 5, 2024
52546e9
Temporary commit
sanaAyrml Jan 2, 2025
fe610db
Merge branch 'sa_early_stop' of https://github.com/VectorInstitute/FL…
sanaAyrml Jan 2, 2025
92fe751
Firx extra early stopper implementation
sanaAyrml Jan 2, 2025
c73d04c
Merge branch 'main' into sa_early_stop
sanaAyrml Jan 2, 2025
fe57eea
add seriazable snapshotter tests
sanaAyrml Jan 2, 2025
2d939b3
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
5627c72
Merge branch 'main' into sa_early_stop
sanaAyrml Jan 8, 2025
311de1c
add metric manager snappshotter test
sanaAyrml Jan 8, 2025
0471905
Merge branch 'main' into sa_early_stop
sanaAyrml Jan 8, 2025
27bf361
Resolve conflict with main
sanaAyrml Jan 8, 2025
3142889
Merge branch 'main' into sa_early_stop
sanaAyrml Jan 8, 2025
f4ad580
update precommit type changes
sanaAyrml Jan 8, 2025
9d09429
Merge branch 'main' into sa_early_stop
sanaAyrml Jan 8, 2025
00320fc
add snappshotter other tests
sanaAyrml Jan 9, 2025
5c48474
Add docstrings
sanaAyrml Jan 9, 2025
bf2504d
Add docstring
sanaAyrml Jan 9, 2025
476343e
Update doc strings
sanaAyrml Jan 9, 2025
dd3e964
Ignoring a vulnerability without a fix yet
emersodb Jan 9, 2025
080fd90
Address review comments
sanaAyrml Jan 23, 2025
a9fe020
Merge branch 'main' into sa_early_stop
sanaAyrml Jan 23, 2025
0a4d6b8
Update memory issues
sanaAyrml Jan 23, 2025
fda2946
Add str to logging mode
sanaAyrml Jan 23, 2025
14da12d
Address some comments
sanaAyrml Jan 24, 2025
a1aecfd
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2025
a0a8128
Bring some checks into should_stop function of early_stopper
sanaAyrml Jan 24, 2025
da842cd
Merge branch 'main' into sa_early_stop
emersodb Jan 24, 2025
171fcc1
Small typo fix
emersodb Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions fl4health/clients/basic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from flwr.common.typing import Config, NDArrays, Scalar
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import LRScheduler
emersodb marked this conversation as resolved.
Show resolved Hide resolved
from torch.utils.data import DataLoader

from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule
Expand All @@ -27,6 +27,7 @@
set_pack_losses_with_val_metrics,
)
from fl4health.utils.config import narrow_dict_type, narrow_dict_type_and_set_attribute
from fl4health.utils.early_stopper import EarlyStopper
from fl4health.utils.logging import LoggingMode
from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType, TrainingLosses
from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, Metric, MetricManager
Expand Down Expand Up @@ -117,6 +118,11 @@ def __init__(
self.num_val_samples: int
self.num_test_samples: int | None
self.learning_rate: float | None

# User can set the early stopper for the client by instantiating the EarlyStopper class
# and setting the patience and interval_steps attributes. The early stopper will be used to
# stop training if the validation loss does not improve for a certain number of steps.
self.early_stopper: EarlyStopper | None = None
# Config can contain max_num_validation_steps key, which determines an upper bound
# for the validation steps taken. If not specified, no upper bound will be enforced.
# By specifying this in the config we cannot guarantee the validation set is the same
Expand Down Expand Up @@ -160,8 +166,16 @@ def get_parameters(self, config: Config) -> NDArrays:
return FullParameterExchanger().push_parameters(self.model, config=config)
else:
assert self.model is not None and self.parameter_exchanger is not None
# If the client has early stopping module and the patience is None, we load the best saved state
# to send the best checkpointed local model's parameters to the server
self._maybe_load_saved_best_local_model_state()
return self.parameter_exchanger.push_parameters(self.model, config=config)

def _maybe_load_saved_best_local_model_state(self) -> None:
if self.early_stopper is not None and self.early_stopper.patience is None:
log(INFO, "Loading saved best model's state before sending model to server.")
emersodb marked this conversation as resolved.
Show resolved Hide resolved
self.early_stopper.load_snapshot(["model"])

def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
"""
Sets the local model parameters transferred from the server using a parameter exchanger to coordinate how
Expand Down Expand Up @@ -612,6 +626,7 @@ def train_by_epochs(
self.model.train()
steps_this_round = 0 # Reset number of steps this round
report_data: dict[str, Any] = {"round": current_round}
continue_training = True
emersodb marked this conversation as resolved.
Show resolved Hide resolved
for local_epoch in range(epochs):
self.train_metric_manager.clear()
self.train_loss_meter.clear()
Expand Down Expand Up @@ -641,6 +656,11 @@ def train_by_epochs(
self.reports_manager.report(report_data, current_round, self.total_epochs, self.total_steps)
self.total_steps += 1
steps_this_round += 1
if self.early_stopper is not None and self.early_stopper.should_stop(steps_this_round):
log(INFO, "Early stopping criterion met. Stopping training.")
self.early_stopper.load_snapshot()
continue_training = False
break

# Log and report results
metrics = self.train_metric_manager.compute()
Expand All @@ -653,6 +673,9 @@ def train_by_epochs(
# Update internal epoch counter
self.total_epochs += 1

if not continue_training:
break

# Return final training metrics
return loss_dict, metrics

Expand Down Expand Up @@ -709,6 +732,10 @@ def train_by_steps(
report_data.update(self.get_client_specific_reports())
self.reports_manager.report(report_data, current_round, None, self.total_steps)
self.total_steps += 1
if self.early_stopper is not None and self.early_stopper.should_stop(step):
log(INFO, "Early stopping criterion met. Stopping training.")
self.early_stopper.load_snapshot()
break

loss_dict = self.train_loss_meter.compute().as_dict()
metrics = self.train_metric_manager.compute()
Expand Down Expand Up @@ -879,7 +906,6 @@ def setup_client(self, config: Config) -> None:
self.parameter_exchanger = self.get_parameter_exchanger(config)

self.reports_manager.report({"host_type": "client", "initialized": str(datetime.datetime.now())})

self.initialized = True

def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
Expand Down Expand Up @@ -1113,7 +1139,7 @@ def get_model(self, config: Config) -> nn.Module:
"""
raise NotImplementedError

def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler | None:
def get_lr_scheduler(self, optimizer_key: str, config: Config) -> LRScheduler | None:
"""
Optional user defined method that returns learning rate scheduler
to be used throughout training for the given optimizer. Defaults to None.
Expand All @@ -1125,7 +1151,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler |
config (Config): The config from the server.

Returns:
_LRScheduler | None: Client learning rate schedulers.
LRScheduler | None: Client learning rate schedulers.
"""
return None

Expand Down
188 changes: 188 additions & 0 deletions fl4health/utils/early_stopper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from __future__ import annotations

import copy
from collections.abc import Callable
from logging import INFO, WARNING
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch.nn as nn
from flwr.common.logger import log
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
from fl4health.reporting.reports_manager import ReportsManager
from fl4health.utils.logging import LoggingMode
from fl4health.utils.losses import TrainingLosses
from fl4health.utils.metrics import MetricManager
from fl4health.utils.snapshotter import (
AbstractSnapshotter,
LRSchedulerSnapshotter,
NumberSnapshotter,
OptimizerSnapshotter,
SerializableObjectSnapshotter,
T,
TorchModuleSnapshotter,
)

if TYPE_CHECKING:
from fl4health.clients.basic_client import BasicClient
sanaAyrml marked this conversation as resolved.
Show resolved Hide resolved


class EarlyStopper:
def __init__(
self,
client: BasicClient,
sanaAyrml marked this conversation as resolved.
Show resolved Hide resolved
patience: int | None = 1,
interval_steps: int = 5,
snapshot_dir: Path | None = None,
) -> None:
"""
Early stopping class is a plugin for the client that allows to stop local training based on the validation
loss. At each training step this class saves the best state of the client and restores it if the client is
stopped. If the client starts to overfit, the early stopper will stop the training process and restore the best
state of the client before sending the model to the server.

Args:
client (BasicClient): The client to be monitored.
patience (int, optional): Number of validation cycles to wait before stopping the training. If it is equal
to None client never stops, but still loads the best state before sending the model to the server.
Defaults to 1.
interval_steps (int): Specifies the frequency, in terms of training intervals, at which the early
stopping mechanism should evaluate the validation loss. Defaults to 5.
snapshot_dir (Path | None, optional): Rather than keeping best state in the memory we can checkpoint it to
the given directory. If it is not given, the best state is kept in the memory. Defaults to None.
"""

self.client = client

self.patience = patience
self.count_down = patience
self.interval_steps = interval_steps

self.best_score: float | None = None
self.snapshot_ckpt: dict[str, tuple[AbstractSnapshotter, Any]] = {}

self.snapshot_attrs: dict = {
"model": (TorchModuleSnapshotter(self.client), nn.Module),
"optimizers": (OptimizerSnapshotter(self.client), Optimizer),
"lr_schedulers": (
LRSchedulerSnapshotter(self.client),
LRScheduler,
),
"learning_rate": (NumberSnapshotter(self.client), float),
"total_steps": (NumberSnapshotter(self.client), int),
"total_epochs": (NumberSnapshotter(self.client), int),
"reports_manager": (
SerializableObjectSnapshotter(self.client),
ReportsManager,
),
"train_loss_meter": (
SerializableObjectSnapshotter(self.client),
TrainingLosses,
),
"train_metric_manager": (
SerializableObjectSnapshotter(self.client),
MetricManager,
),
}

if snapshot_dir is not None:
# TODO: Move to generic checkpointer
self.checkpointer = PerRoundStateCheckpointer(snapshot_dir)
sanaAyrml marked this conversation as resolved.
Show resolved Hide resolved
self.checkpoint_name = f"temp_{self.client.client_name}.pt"
else:
log(INFO, "Snapshot is being persisted in memory")

def add_default_snapshot_attr(
self, name: str, snapshot_class: Callable[[BasicClient], AbstractSnapshotter], input_type: type[T]
) -> None:
self.snapshot_attrs.update({name: (snapshot_class(self.client), input_type)})

def delete_default_snapshot_attr(self, name: str) -> None:
del self.snapshot_attrs[name]

def save_snapshot(self) -> None:
"""
Creates a snapshot of the client state and if snapshot_ckpt is given, saves it to the checkpoint.
"""
for attr, (snapshotter_function, expected_type) in self.snapshot_attrs.items():
self.snapshot_ckpt.update(snapshotter_function.save(attr, expected_type))

if self.checkpointer is not None:
log(
INFO,
f"Saving client best state to checkpoint at {self.checkpointer.checkpoint_dir} "
f"with name {self.checkpoint_name}.",
)
self.checkpointer.save_checkpoint(self.checkpoint_name, self.snapshot_ckpt)
self.snapshot_ckpt.clear()

else:
log(
WARNING,
"Checkpointing directory is not provided. Client best state will be kept in the memory.",
)
self.snapshot_ckpt = copy.deepcopy(self.snapshot_ckpt)

def load_snapshot(self, attributes: list[str] | None = None) -> None:
"""
Load checkpointed snapshot dict consisting to the respective model attributes.

Args:
attributes (list[str] | None): List of attributes to load from the checkpoint.
If None, all attributes are loaded. Defaults to None.
"""
assert (
self.checkpointer.checkpoint_exists(self.checkpoint_name) or self.snapshot_ckpt != {}
), "No checkpoint to load"

if attributes is None:
attributes = list(self.snapshot_attrs.keys())

log(INFO, f"Loading client best state {attributes} from checkpoint at {self.checkpointer.checkpoint_dir}")

if self.checkpointer.checkpoint_exists(self.checkpoint_name):
self.snapshot_ckpt = self.checkpointer.load_checkpoint(self.checkpoint_name)

for attr in attributes:
snapshotter, expected_type = self.snapshot_attrs[attr]
snapshotter.load(self.snapshot_ckpt, attr, expected_type)

def should_stop(self, steps: int) -> bool:
"""
Determine if the client should stop training based on early stopping criteria.

Args:
steps (int): Number of steps since the start of the training.

Returns:
bool: True if training should stop, otherwise False.
"""
if steps % self.interval_steps != 0:
return False

val_loss, _ = self.client._validate_or_test(
loader=self.client.val_loader,
loss_meter=self.client.val_loss_meter,
metric_manager=self.client.val_metric_manager,
logging_mode=LoggingMode.EARLY_STOP_VALIDATION,
include_losses_in_metrics=False,
)

if val_loss is None:
return False

if self.best_score is None or val_loss < self.best_score:
self.best_score = val_loss
self.count_down = self.patience
sanaAyrml marked this conversation as resolved.
Show resolved Hide resolved
self.save_snapshot()
return False

if self.count_down is not None:
self.count_down -= 1
if self.count_down <= 0:
return True

return False
3 changes: 2 additions & 1 deletion fl4health/utils/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import Enum


class LoggingMode(Enum):
class LoggingMode(str, Enum):
TRAIN = "Training"
EARLY_STOP_VALIDATION = "Early_Stop_Validation"
VALIDATION = "Validation"
TEST = "Testing"
Loading
Loading