Add other snapshotter tests
sanaAyrml committed Jan 9, 2025
1 parent 9d09429 commit 00320fc
Showing 4 changed files with 117 additions and 30 deletions.
33 changes: 22 additions & 11 deletions fl4health/clients/basic_client.py
@@ -11,7 +11,7 @@
from flwr.common.typing import Config, NDArrays, Scalar
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule
@@ -890,17 +890,28 @@ def setup_client(self, config: Config) -> None:
self.parameter_exchanger = self.get_parameter_exchanger(config)

self.reports_manager.report({"host_type": "client", "initialized": str(datetime.datetime.now())})
try:
self.set_early_stopper()
except NotImplementedError:
log(
INFO,
"""Early stopping not implemented for this client.
Override set_early_stopper to activate early stopper.""",
)
self.initialized = True

def setup_early_stopper(
self,
patience: int = -1,
interval_steps: int = 5,
snapshot_dir: Path | None = None,
) -> None:
from fl4health.utils.early_stopper import EarlyStopper
def set_early_stopper(self) -> None:
"""
User defined method that sets the early stopper for the client. To override this method, the user must
set self.early_stopper to an instance of EarlyStopper. The EarlyStopper class is defined in
fl4health.early_stopping. Example implementation:
self.early_stopper = EarlyStopper(self, patience, interval_steps, snapshot_dir)
⁠ python
from fl4health.utils.early_stopper import EarlyStopper
self.early_stopper = EarlyStopper(client=self, patience=3, interval_steps=100)
 ⁠
"""
raise NotImplementedError

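For reference, a minimal sketch of what an override could look like in a `BasicClient` subclass; `MyClient` and the argument values are illustrative, not part of this commit:

```python
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.early_stopper import EarlyStopper


class MyClient(BasicClient):
    def set_early_stopper(self) -> None:
        # Evaluate the stopping criterion every 100 steps and stop after
        # 3 non-improving checks; omitting snapshot_dir keeps snapshots in memory.
        self.early_stopper = EarlyStopper(client=self, patience=3, interval_steps=100)
```

With such an override in place, `setup_client` no longer hits the `NotImplementedError` branch above and early stopping is active.
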
def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
"""
@@ -1133,7 +1144,7 @@ def get_model(self, config: Config) -> nn.Module:
"""
raise NotImplementedError

def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler | None:
def get_lr_scheduler(self, optimizer_key: str, config: Config) -> LRScheduler | None:
"""
Optional user defined method that returns a learning rate scheduler
to be used throughout training for the given optimizer. Defaults to None.
@@ -1145,7 +1156,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler | None:
config (Config): The config from the server.
Returns:
_LRScheduler | None: Client learning rate schedulers.
LRScheduler | None: Client learning rate schedulers.
"""
return None

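As a usage sketch for the updated signature (the subclass name and scheduler settings here are assumptions, not from this commit), an override might return a standard PyTorch scheduler for the keyed optimizer:

```python
from flwr.common.typing import Config
from torch.optim.lr_scheduler import LRScheduler, StepLR

from fl4health.clients.basic_client import BasicClient


class ClientWithScheduler(BasicClient):
    def get_lr_scheduler(self, optimizer_key: str, config: Config) -> LRScheduler | None:
        # Decay the learning rate of the keyed optimizer by 10x every 30 steps.
        return StepLR(self.optimizers[optimizer_key], step_size=30, gamma=0.1)
```

Any `torch.optim.lr_scheduler.LRScheduler` subclass works here, which is why the annotation moves from the private `_LRScheduler` alias to the public `LRScheduler` name.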
6 changes: 3 additions & 3 deletions fl4health/utils/early_stopper.py
@@ -6,7 +6,7 @@
import torch.nn as nn
from flwr.common.logger import log
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import LRScheduler

from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
from fl4health.clients.basic_client import BasicClient
@@ -63,7 +63,7 @@ def __init__(
"optimizers": (OptimizerSnapshotter(self.client), Optimizer), # dict of optimizers we only need state_dict
"lr_schedulers": (
LRSchedulerSnapshotter(self.client),
_LRScheduler,
LRScheduler,
), # dict of schedulers; we only need the state_dict
"learning_rate": (NumberSnapshotter(self.client), float), # number we can copy
"total_steps": (NumberSnapshotter(self.client), int), # number we can copy
@@ -99,7 +99,7 @@ def save_snapshot(self) -> None:
metrics reporter and optimizers state. Method can be overridden to augment saved checkpointed state.
"""
for arg, (snapshotter_function, expected_type) in self.default_snapshot_args.items():
self.snapshot_ckpt[arg] = snapshotter_function.save(arg, expected_type)
self.snapshot_ckpt.update(snapshotter_function.save(arg, expected_type))

if self.checkpointer is not None:
self.checkpointer.save_checkpoint(f"temp_{self.client.client_name}.pt", self.snapshot_ckpt)
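The switch from item assignment to `update` above tracks the new `Snapshotter.save` contract (see the `snapshotter.py` changes in the next file): `save` now returns a dict already keyed by the attribute name, so merging it keeps `snapshot_ckpt` flat. A minimal round-trip sketch, assuming `fl_client` is an already-initialized `BasicClient`:

```python
from fl4health.utils.snapshotter import NumberSnapshotter

snapshot_ckpt: dict = {}
number_snapshotter = NumberSnapshotter(fl_client)

# save() returns {"total_steps": <wrapped state>} rather than the bare state,
# so the caller merges it into the flat checkpoint dict.
snapshot_ckpt.update(number_snapshotter.save("total_steps", int))

# load() looks the attribute up under the same top-level key.
number_snapshotter.load(snapshot_ckpt, "total_steps", int)
```

The new tests below exercise exactly this pattern for numbers, optimizers, LR schedulers, and modules.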
12 changes: 6 additions & 6 deletions fl4health/utils/snapshotter.py
@@ -4,7 +4,7 @@

import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import LRScheduler

from fl4health.clients.basic_client import BasicClient
from fl4health.reporting.reports_manager import ReportsManager
@@ -30,9 +30,9 @@ def dict_wrap_attr(self, name: str, expected_type: type[T]) -> dict[str, T]:
else:
raise ValueError(f"Uncompatible type of attribute {type(attribute)}")

def save(self, name: str, expected_type: type[T]) -> Any:
def save(self, name: str, expected_type: type[T]) -> dict[str, Any]:
attribute = self.dict_wrap_attr(name, expected_type)
return self.save_attribute(attribute)
return {name: self.save_attribute(attribute)}

def load(self, ckpt: dict[str, Any], name: str, expected_type: type[T]) -> None:
attribute = self.dict_wrap_attr(name, expected_type)
@@ -69,9 +69,9 @@ def load_attribute(self, attribute_ckpt: dict[str, Any], attribute: dict[str, Optimizer]) -> None:
optimizer.load_state_dict(optimizer_state_dict)


class LRSchedulerSnapshotter(Snapshotter[_LRScheduler]):
class LRSchedulerSnapshotter(Snapshotter[LRScheduler]):

def save_attribute(self, attribute: dict[str, _LRScheduler]) -> dict[str, Any]:
def save_attribute(self, attribute: dict[str, LRScheduler]) -> dict[str, Any]:
"""
Save the state of the LR schedulers (either a single one or a dictionary of them).
"""
@@ -80,7 +80,7 @@ def save_attribute(self, attribute: dict[str, _LRScheduler]) -> dict[str, Any]:
output[key] = lr_scheduler.state_dict()
return output

def load_attribute(self, attribute_ckpt: dict[str, Any], attribute: dict[str, _LRScheduler]) -> None:
def load_attribute(self, attribute_ckpt: dict[str, Any], attribute: dict[str, LRScheduler]) -> None:
for key, lr_scheduler in attribute.items():
lr_scheduler.load_state_dict(attribute_ckpt[key])

96 changes: 86 additions & 10 deletions tests/utils/snapshotter_test.py
@@ -8,24 +8,100 @@
from fl4health.reporting.reports_manager import ReportsManager
from fl4health.utils.losses import LossMeter, TrainingLosses
from fl4health.utils.metrics import Accuracy, MetricManager
from fl4health.utils.snapshotter import SerizableObjectSnapshotter
from fl4health.utils.snapshotter import (
LRSchedulerSnapshotter,
NumberSnapshotter,
OptimizerSnapshotter,
SerizableObjectSnapshotter,
TorchModuleSnapshotter,
)
from fl4health.utils.typing import TorchPredType, TorchTargetType
from tests.test_utils.models_for_test import SingleLayerWithSeed


def test_number_snapshotter() -> None:
metrics = [Accuracy("accuracy")]
reporter = JsonReporter()
fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter])
old_total_steps = fl_client.total_steps
number_snapshotter = NumberSnapshotter(fl_client)
sp = number_snapshotter.save("total_steps", int)
fl_client.total_steps += 1
assert sp["total_steps"] == {"None": old_total_steps}
assert fl_client.total_steps != old_total_steps
number_snapshotter.load(sp, "total_steps", int)
assert fl_client.total_steps == old_total_steps


def test_optimizer_scheduler_model_snapshotter() -> None:
metrics = [Accuracy("accuracy")]
reporter = JsonReporter()
fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter])
fl_client.model = SingleLayerWithSeed()
fl_client.criterion = torch.nn.CrossEntropyLoss()

input_data = torch.randn(32, 100) # Batch size = 32, input size = 100
target_data = torch.randn(32, 2) # Batch size = 32, target size = 2

fl_client.optimizers = {"global": torch.optim.SGD(fl_client.model.parameters(), lr=0.001)}
fl_client.lr_schedulers = {
"global": torch.optim.lr_scheduler.StepLR(fl_client.optimizers["global"], step_size=30, gamma=0.1)
}
old_optimizers = copy.deepcopy(fl_client.optimizers)
old_lr_schedulers = copy.deepcopy(fl_client.lr_schedulers)
old_model = copy.deepcopy(fl_client.model)

optimizer_snapshotter = OptimizerSnapshotter(fl_client)
lr_scheduler_snapshotter = LRSchedulerSnapshotter(fl_client)
model_snapshotter = TorchModuleSnapshotter(fl_client)

snapshots = {}
snapshots.update(optimizer_snapshotter.save("optimizers", torch.optim.Optimizer))
snapshots.update(lr_scheduler_snapshotter.save("lr_schedulers", torch.optim.lr_scheduler.LRScheduler))
snapshots.update(model_snapshotter.save("model", torch.nn.Module))

fl_client.train_step(input_data, target_data)

fl_client.optimizers["global"].step() # Update model weights
fl_client.lr_schedulers["global"].step()

for key, value in fl_client.model.state_dict().items():
assert not torch.equal(value, old_model.state_dict()[key])

for key, optimizers in fl_client.optimizers.items():
assert optimizers.state_dict()["state"] != old_optimizers[key].state_dict()["state"]

for key, schedulers in fl_client.lr_schedulers.items():
assert schedulers.state_dict() != old_lr_schedulers[key].state_dict()

optimizer_snapshotter.load(snapshots, "optimizers", torch.optim.Optimizer)
lr_scheduler_snapshotter.load(snapshots, "lr_schedulers", torch.optim.lr_scheduler.LRScheduler)
model_snapshotter.load(snapshots, "model", torch.nn.Module)

for key, value in fl_client.model.state_dict().items():
assert torch.equal(value, old_model.state_dict()[key])

for key, optimizers in fl_client.optimizers.items():
assert optimizers.state_dict()["state"] == old_optimizers[key].state_dict()["state"]

for key, schedulers in fl_client.lr_schedulers.items():
assert schedulers.state_dict() == old_lr_schedulers[key].state_dict()


def test_loss_meter_snapshotter() -> None:
metrics = [Accuracy("accuracy")]
reporter = JsonReporter()
fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter])
ckpt = {}
snapshots = {}

fl_client.train_loss_meter.update(TrainingLosses(backward=torch.Tensor([35]), additional_losses=None))
snapshotter = SerizableObjectSnapshotter(fl_client)
ckpt["train_loss_meter"] = snapshotter.save("train_loss_meter", LossMeter)
snapshots.update(snapshotter.save("train_loss_meter", LossMeter))
old_loss_meter = copy.deepcopy(fl_client.train_loss_meter)
fl_client.train_loss_meter.update(TrainingLosses(backward=torch.Tensor([10]), additional_losses=None))
assert len(old_loss_meter.losses_list) != len(fl_client.train_loss_meter.losses_list)

snapshotter.load(ckpt, "train_loss_meter", LossMeter)
snapshotter.load(snapshots, "train_loss_meter", LossMeter)

assert len(old_loss_meter.losses_list) == len(fl_client.train_loss_meter.losses_list)
for i in range(len(fl_client.train_loss_meter.losses_list)):
@@ -40,11 +40,11 @@ def test_reports_manager_snapshotter() -> None:
metrics = [Accuracy("accuracy")]
reporter = JsonReporter()
fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter])
ckpt = {}
snapshots = {}

fl_client.reports_manager.report({"start": "2012-12-12 12:12:10"})
snapshotter = SerizableObjectSnapshotter(fl_client)
ckpt["reports_manager"] = snapshotter.save("reports_manager", ReportsManager)
snapshots.update(snapshotter.save("reports_manager", ReportsManager))
old_reports_manager = copy.deepcopy(fl_client.reports_manager)
fl_client.reports_manager.report({"shutdown": "2012-12-12 12:12:12"})

@@ -54,23 +54,23 @@

assert old_reports_manager.reporters[0].metrics != fl_client.reports_manager.reporters[0].metrics

snapshotter.load(ckpt, "reports_manager", ReportsManager)
snapshotter.load(snapshots, "reports_manager", ReportsManager)
assert old_reports_manager.reporters[0].metrics == fl_client.reports_manager.reporters[0].metrics


def test_metric_manager_snapshotter() -> None:
metrics = [Accuracy("accuracy")]
reporter = JsonReporter()
fl_client = BasicClient(data_path=Path(""), metrics=metrics, device=torch.device(0), reporters=[reporter])
ckpt = {}
snapshots = {}
preds: TorchPredType = {
"1": torch.tensor([0.7369, 0.5121, 0.2674, 0.5847, 0.4032, 0.7458, 0.9274, 0.3258, 0.7095, 0.0513])
}
target: TorchTargetType = {"1": torch.tensor([0, 1, 0, 1, 1, 0, 1, 1, 0, 1])}

fl_client.train_metric_manager.update(preds, target)
snapshotter = SerizableObjectSnapshotter(fl_client)
ckpt["train_metric_manager"] = snapshotter.save("train_metric_manager", MetricManager)
snapshots.update(snapshotter.save("train_metric_manager", MetricManager))
old_train_metric_manager = copy.deepcopy(fl_client.train_metric_manager)
fl_client.train_metric_manager.update(preds, target)
assert isinstance(fl_client.train_metric_manager.metrics_per_prediction_type["1"][0], Accuracy) and isinstance(
@@ -83,7 +159,7 @@ def test_metric_manager_snapshotter() -> None:
old_train_metric_manager.metrics_per_prediction_type["1"][0].accumulated_targets
)

snapshotter.load(ckpt, "train_metric_manager", MetricManager)
snapshotter.load(snapshots, "train_metric_manager", MetricManager)
assert len(fl_client.train_metric_manager.metrics_per_prediction_type["1"][0].accumulated_inputs) == len(
old_train_metric_manager.metrics_per_prediction_type["1"][0].accumulated_inputs
)
