Skip to content

Commit

Permalink
Merge pull request #263 from VectorInstitute/update-model-merge-repor…
Browse files Browse the repository at this point in the history
…ters

Make necessary changes to model merge server and client with new repo…
  • Loading branch information
emersodb authored Oct 25, 2024
2 parents 562199f + fbdc1f1 commit 53431b7
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/static_code_checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@ jobs:
# Fix is 43.0.1 but flwr 1.9 depends on < 43
ignore-vulns: |
GHSA-h4gh-qq45-vh27
GHSA-q34m-jh98-gwm2
GHSA-f9vj-2wh5-fj8j
36 changes: 17 additions & 19 deletions fl4health/clients/model_merge_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.reporting.metrics import MetricsReporter
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.reporting.reports_manager import ReportsManager
from fl4health.utils.metrics import Metric, MetricManager
from fl4health.utils.random import generate_hash
from fl4health.utils.typing import TorchInputType, TorchTargetType
Expand All @@ -24,7 +25,8 @@ def __init__(
model_path: Path,
metrics: Sequence[Metric],
device: torch.device,
metrics_reporter: Optional[MetricsReporter] = None,
reporters: Optional[Sequence[BaseReporter]] = None,
client_name: Optional[str] = None,
) -> None:
"""
ModelMergeClient to support functionality to simply perform model merging across client
Expand All @@ -36,23 +38,23 @@ def __init__(
metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model
device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or
'cuda'
metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
during the execution. Defaults to an instance of MetricsReporter with default init parameters.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health
reporters which the client should send data to.
client_name (str): An optional client name that uniquely identifies a client.
If not passed, a hash is randomly generated.
"""
self.data_path = data_path
self.model_path = model_path
self.metrics = metrics
self.device = device
self.metrics_reporter = metrics_reporter
self.client_name = client_name if client_name is not None else generate_hash()

self.initialized = False
self.client_name = generate_hash()
self.test_metric_manager = MetricManager(metrics=self.metrics, metric_manager_name="test")

if metrics_reporter is not None:
self.metrics_reporter = metrics_reporter
else:
self.metrics_reporter = MetricsReporter(run_id=self.client_name)
# Initialize reporters with client information.
self.reports_manager = ReportsManager(reporters)
self.reports_manager.initialize(id=self.client_name)

self.model: nn.Module
self.test_loader: DataLoader
Expand Down Expand Up @@ -133,19 +135,15 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict
"""
assert not self.initialized
self.setup_client(config)
assert self.metrics_reporter is not None
self.metrics_reporter.add_to_metrics_at_round(
1,
data={"fit_start": datetime.datetime.now()},

self.reports_manager.report(
data={"host_type": "client", "fit_start": datetime.datetime.now()},
)

val_metrics = self.validate()

self.metrics_reporter.add_to_metrics_at_round(
1,
data={
"fit_metrics": val_metrics,
},
self.reports_manager.report(
data={"fit_metrics": val_metrics, "host_type": "client", "fit_end": datetime.datetime.now()},
)

return self.get_parameters(config), self.num_test_samples, val_metrics
Expand Down
28 changes: 17 additions & 11 deletions fl4health/server/model_merge_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import timeit
from logging import INFO, WARNING
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Sequence, Tuple

import torch.nn as nn
from flwr.common.logger import log
Expand All @@ -14,8 +14,10 @@

from fl4health.checkpointing.checkpointer import LatestTorchCheckpointer
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.reporting.metrics import MetricsReporter
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.reporting.reports_manager import ReportsManager
from fl4health.strategies.model_merge_strategy import ModelMergeStrategy
from fl4health.utils.random import generate_hash


class ModelMergeServer(Server):
Expand All @@ -27,7 +29,8 @@ def __init__(
checkpointer: Optional[LatestTorchCheckpointer] = None,
server_model: Optional[nn.Module] = None,
parameter_exchanger: Optional[ParameterExchanger] = None,
metrics_reporter: Optional[MetricsReporter] = None,
reporters: Sequence[BaseReporter] | None = None,
server_name: Optional[str] = None,
) -> None:
"""
ModelMergeServer provides functionality to fetch client weights, perform a simple average,
Expand All @@ -44,8 +47,9 @@ def __init__(
server side checkpointing. Must only be provided if checkpointer is also provided. Defaults to None.
parameter_exchanger (Optional[ExchangerType]): Optional parameter exchanger to be used to hydrate the
model. Only used if checkpointer and model are also not None. Defaults to None.
metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
during the execution. Defaults to an instance of MetricsReporter with default init parameters.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
send data to before and after each round.
server_name (Optional[str]): An optional string name to uniquely identify server.
"""
assert isinstance(strategy, ModelMergeStrategy)
assert (server_model is None and checkpointer is None and parameter_exchanger is None) or (
Expand All @@ -56,11 +60,11 @@ def __init__(
self.checkpointer = checkpointer
self.server_model = server_model
self.parameter_exchanger = parameter_exchanger
self.server_name = server_name if server_name is not None else generate_hash()

if metrics_reporter is not None:
self.metrics_reporter = metrics_reporter
else:
self.metrics_reporter = MetricsReporter()
# Initialize reporters with server name information.
self.reports_manager = ReportsManager(reporters)
self.reports_manager.initialize(id=self.server_name)

def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
"""
Expand All @@ -77,7 +81,8 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float
Tuple[History, float]: The first element of the tuple is a History object containing the aggregated
metrics returned from the clients. Tuple also contains elapsed time in seconds for round.
"""
self.metrics_reporter.add_to_metrics({"type": "server", "fit_start": datetime.datetime.now()})

self.reports_manager.report({"host_type": "server", "fit_start": datetime.datetime.now()})

history = History()

Expand Down Expand Up @@ -117,11 +122,12 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float
# server_model, parameter_exchanger and checkpointer are not None
self._maybe_checkpoint(loss_aggregated=0.0, metrics_aggregated={}, server_round=1)

self.metrics_reporter.add_to_metrics(
self.reports_manager.report(
data={
"fit_end": datetime.datetime.now(),
"metrics_centralized": history.metrics_centralized,
"losses_centralized": history.losses_centralized,
"host_type": "server",
}
)

Expand Down

0 comments on commit 53431b7

Please sign in to comment.