diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index bc38c33d8..c4dc6e33d 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -11,7 +11,7 @@ from flwr.common.typing import Config, NDArrays, Scalar from torch.nn.modules.loss import _Loss from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule @@ -890,17 +890,28 @@ def setup_client(self, config: Config) -> None: self.parameter_exchanger = self.get_parameter_exchanger(config) self.reports_manager.report({"host_type": "client", "initialized": str(datetime.datetime.now())}) + try: + self.set_early_stopper() + except NotImplementedError: + log( + INFO, + """Early stopping not implemented for this client. + Override set_early_stopper to activate early stopper.""", + ) self.initialized = True - def setup_early_stopper( - self, - patience: int = -1, - interval_steps: int = 5, - snapshot_dir: Path | None = None, - ) -> None: - from fl4health.utils.early_stopper import EarlyStopper + def set_early_stopper(self) -> None: + """ + User defined method that sets the early stopper for the client. To override this method, the user must + set self.early_stopper to an instance of EarlyStopper. The EarlyStopper class is defined in + fl4health.early_stopping. Example implementation: - self.early_stopper = EarlyStopper(self, patience, interval_steps, snapshot_dir) + ⁠ python + from fl4health.utils.early_stopper import EarlyStopper + self.early_stopper = EarlyStopper(client=self, patience=3, interval_steps=100) +  ⁠ + """ + raise NotImplementedError def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: """ @@ -1133,7 +1144,7 @@ def get_model(self, config: Config) -> nn.Module: """ raise NotImplementedError - def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler | None: + def get_lr_scheduler(self, optimizer_key: str, config: Config) -> LRScheduler | None: """ Optional user defined method that returns learning rate scheduler to be used throughout training for the given optimizer. Defaults to None. @@ -1145,7 +1156,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler | config (Config): The config from the server. Returns: - _LRScheduler | None: Client learning rate schedulers. + LRScheduler | None: Client learning rate schedulers. """ return None diff --git a/fl4health/utils/early_stopper.py b/fl4health/utils/early_stopper.py index baa7eb2e5..02045cfa1 100644 --- a/fl4health/utils/early_stopper.py +++ b/fl4health/utils/early_stopper.py @@ -6,7 +6,7 @@ import torch.nn as nn from flwr.common.logger import log from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.lr_scheduler import LRScheduler from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer from fl4health.clients.basic_client import BasicClient @@ -63,7 +63,7 @@ def __init__( "optimizers": (OptimizerSnapshotter(self.client), Optimizer), # dict of optimizers we only need state_dict "lr_schedulers": ( LRSchedulerSnapshotter(self.client), - _LRScheduler, + LRScheduler, ), # dict of schedulers we only need state_dict "learning_rate": (NumberSnapshotter(self.client), float), # number we can copy "total_steps": (NumberSnapshotter(self.client), int), # number we can copy @@ -99,7 +99,7 @@ def save_snapshot(self) -> None: metrics reporter and optimizers state. Method can be overridden to augment saved checkpointed state. """ for arg, (snapshotter_function, expected_type) in self.default_snapshot_args.items(): - self.snapshot_ckpt[arg] = snapshotter_function.save(arg, expected_type) + self.snapshot_ckpt.update(snapshotter_function.save(arg, expected_type)) if self.checkpointer is not None: self.checkpointer.save_checkpoint(f"temp_{self.client.client_name}.pt", self.snapshot_ckpt) diff --git a/fl4health/utils/snapshotter.py b/fl4health/utils/snapshotter.py index 79daff685..5a5feb925 100644 --- a/fl4health/utils/snapshotter.py +++ b/fl4health/utils/snapshotter.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.lr_scheduler import LRScheduler from fl4health.clients.basic_client import BasicClient from fl4health.reporting.reports_manager import ReportsManager @@ -30,9 +30,9 @@ def dict_wrap_attr(self, name: str, expected_type: type[T]) -> dict[str, T]: else: raise ValueError(f"Uncompatible type of attribute {type(attribute)}") - def save(self, name: str, expected_type: type[T]) -> Any: + def save(self, name: str, expected_type: type[T]) -> dict[str, Any]: attribute = self.dict_wrap_attr(name, expected_type) - return self.save_attribute(attribute) + return {name: self.save_attribute(attribute)} def load(self, ckpt: dict[str, Any], name: str, expected_type: type[T]) -> None: attribute = self.dict_wrap_attr(name, expected_type) @@ -69,9 +69,9 @@ def load_attribute(self, attribute_ckpt: dict[str, Any], attribute: dict[str, Op optimizer.load_state_dict(optimizer_state_dict) -class LRSchedulerSnapshotter(Snapshotter[_LRScheduler]): +class LRSchedulerSnapshotter(Snapshotter[LRScheduler]): - def save_attribute(self, attribute: dict[str, _LRScheduler]) -> dict[str, Any]: + def save_attribute(self, attribute: dict[str, LRScheduler]) -> dict[str, Any]: """ Save the state of the optimizers (either single or dictionary of them). """ @@ -80,7 +80,7 @@ def save_attribute(self, attribute: dict[str, _LRScheduler]) -> dict[str, Any]: output[key] = lr_scheduler.state_dict() return output - def load_attribute(self, attribute_ckpt: dict[str, Any], attribute: dict[str, _LRScheduler]) -> None: + def load_attribute(self, attribute_ckpt: dict[str, Any], attribute: dict[str, LRScheduler]) -> None: for key, lr_scheduler in attribute.items(): lr_scheduler.load_state_dict(attribute_ckpt[key]) diff --git a/tests/utils/snapshotter_test.py b/tests/utils/snapshotter_test.py index 69c314ae7..d15676b9d 100644 --- a/tests/utils/snapshotter_test.py +++ b/tests/utils/snapshotter_test.py @@ -8,24 +8,100 @@ from fl4health.reporting.reports_manager import ReportsManager from fl4health.utils.losses import LossMeter, TrainingLosses from fl4health.utils.metrics import Accuracy, MetricManager -from fl4health.utils.snapshotter import SerizableObjectSnapshotter +from fl4health.utils.snapshotter import ( + LRSchedulerSnapshotter, + NumberSnapshotter, + OptimizerSnapshotter, + SerizableObjectSnapshotter, + TorchModuleSnapshotter, +) from fl4health.utils.typing import TorchPredType, TorchTargetType +from tests.test_utils.models_for_test import SingleLayerWithSeed + + +def test_number_snapshotter() -> None: + metrics = [Accuracy("accuracy")] + reporter = JsonReporter() + fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter]) + old_total_steps = fl_client.total_steps + number_snapshotter = NumberSnapshotter(fl_client) + sp = number_snapshotter.save("total_steps", int) + fl_client.total_steps += 1 + assert sp["total_steps"] == {"None": old_total_steps} + assert fl_client.total_steps != old_total_steps + number_snapshotter.load(sp, "total_steps", int) + assert fl_client.total_steps == old_total_steps + + +def test_optimizer_scheduler_model_snapshotter() -> None: + metrics = [Accuracy("accuracy")] + reporter = JsonReporter() + fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter]) + fl_client.model = SingleLayerWithSeed() + fl_client.criterion = torch.nn.CrossEntropyLoss() + + input_data = torch.randn(32, 100) # Batch size = 32, Input size = 10 + target_data = torch.randn(32, 2) # Batch size = 32, Target size = 1 + + fl_client.optimizers = {"global": torch.optim.SGD(fl_client.model.parameters(), lr=0.001)} + fl_client.lr_schedulers = { + "global": torch.optim.lr_scheduler.StepLR(fl_client.optimizers["global"], step_size=30, gamma=0.1) + } + old_optimizers = copy.deepcopy(fl_client.optimizers) + old_lr_schedulers = copy.deepcopy(fl_client.lr_schedulers) + old_model = copy.deepcopy(fl_client.model) + + optimizer_snapshotter = OptimizerSnapshotter(fl_client) + lr_scheduler_snapshotter = LRSchedulerSnapshotter(fl_client) + model_snapshotter = TorchModuleSnapshotter(fl_client) + + snapshots = {} + snapshots.update(optimizer_snapshotter.save("optimizers", torch.optim.Optimizer)) + snapshots.update(lr_scheduler_snapshotter.save("lr_schedulers", torch.optim.lr_scheduler.LRScheduler)) + snapshots.update(model_snapshotter.save("model", torch.nn.Module)) + + fl_client.train_step(input_data, target_data) + + fl_client.optimizers["global"].step() # Update model weights + fl_client.lr_schedulers["global"].step() + + for key, value in fl_client.model.state_dict().items(): + assert not torch.equal(value, old_model.state_dict()[key]) + + for key, optimizers in fl_client.optimizers.items(): + assert optimizers.state_dict()["state"] != old_optimizers[key].state_dict()["state"] + + for key, schedulers in fl_client.lr_schedulers.items(): + assert schedulers.state_dict() != old_lr_schedulers[key].state_dict() + + optimizer_snapshotter.load(snapshots, "optimizers", torch.optim.Optimizer) + lr_scheduler_snapshotter.load(snapshots, "lr_schedulers", torch.optim.lr_scheduler.LRScheduler) + model_snapshotter.load(snapshots, "model", torch.nn.Module) + + for key, value in fl_client.model.state_dict().items(): + assert torch.equal(value, old_model.state_dict()[key]) + + for key, optimizers in fl_client.optimizers.items(): + assert optimizers.state_dict()["state"] == old_optimizers[key].state_dict()["state"] + + for key, schedulers in fl_client.lr_schedulers.items(): + assert schedulers.state_dict() == old_lr_schedulers[key].state_dict() def test_loss_meter_snapshotter() -> None: metrics = [Accuracy("accuracy")] reporter = JsonReporter() fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter]) - ckpt = {} + snapshots = {} fl_client.train_loss_meter.update(TrainingLosses(backward=torch.Tensor([35]), additional_losses=None)) snapshotter = SerizableObjectSnapshotter(fl_client) - ckpt["train_loss_meter"] = snapshotter.save("train_loss_meter", LossMeter) + snapshots.update(snapshotter.save("train_loss_meter", LossMeter)) old_loss_meter = copy.deepcopy(fl_client.train_loss_meter) fl_client.train_loss_meter.update(TrainingLosses(backward=torch.Tensor([10]), additional_losses=None)) assert len(old_loss_meter.losses_list) != len(fl_client.train_loss_meter.losses_list) - snapshotter.load(ckpt, "train_loss_meter", LossMeter) + snapshotter.load(snapshots, "train_loss_meter", LossMeter) assert len(old_loss_meter.losses_list) == len(fl_client.train_loss_meter.losses_list) for i in range(len(fl_client.train_loss_meter.losses_list)): @@ -40,11 +116,11 @@ def test_reports_manager_snapshotter() -> None: metrics = [Accuracy("accuracy")] reporter = JsonReporter() fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter]) - ckpt = {} + snapshots = {} fl_client.reports_manager.report({"start": "2012-12-12 12:12:10"}) snapshotter = SerizableObjectSnapshotter(fl_client) - ckpt["reports_manager"] = snapshotter.save("reports_manager", ReportsManager) + snapshots.update(snapshotter.save("reports_manager", ReportsManager)) old_reports_manager = copy.deepcopy(fl_client.reports_manager) fl_client.reports_manager.report({"shutdown": "2012-12-12 12:12:12"}) @@ -54,7 +130,7 @@ def test_reports_manager_snapshotter() -> None: assert old_reports_manager.reporters[0].metrics != fl_client.reports_manager.reporters[0].metrics - snapshotter.load(ckpt, "reports_manager", ReportsManager) + snapshotter.load(snapshots, "reports_manager", ReportsManager) assert old_reports_manager.reporters[0].metrics == fl_client.reports_manager.reporters[0].metrics @@ -62,7 +138,7 @@ def test_metric_manager_snapshotter() -> None: metrics = [Accuracy("accuracy")] reporter = JsonReporter() fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter]) - ckpt = {} + snapshots = {} preds: TorchPredType = { "1": torch.tensor([0.7369, 0.5121, 0.2674, 0.5847, 0.4032, 0.7458, 0.9274, 0.3258, 0.7095, 0.0513]) } @@ -70,7 +146,7 @@ def test_metric_manager_snapshotter() -> None: fl_client.train_metric_manager.update(preds, target) snapshotter = SerizableObjectSnapshotter(fl_client) - ckpt["train_metric_manager"] = snapshotter.save("train_metric_manager", MetricManager) + snapshots.update(snapshotter.save("train_metric_manager", MetricManager)) old_train_metric_manager = copy.deepcopy(fl_client.train_metric_manager) fl_client.train_metric_manager.update(preds, target) assert isinstance(fl_client.train_metric_manager.metrics_per_prediction_type["1"][0], Accuracy) and isinstance( @@ -83,7 +159,7 @@ def test_metric_manager_snapshotter() -> None: old_train_metric_manager.metrics_per_prediction_type["1"][0].accumulated_targets ) - snapshotter.load(ckpt, "train_metric_manager", MetricManager) + snapshotter.load(snapshots, "train_metric_manager", MetricManager) assert len(fl_client.train_metric_manager.metrics_per_prediction_type["1"][0].accumulated_inputs) == len( old_train_metric_manager.metrics_per_prediction_type["1"][0].accumulated_inputs )