Commit
Update typing with new changes in main
sanaAyrml committed Jan 9, 2025
1 parent a2da4b1 commit 3b56645
Showing 14 changed files with 126 additions and 96 deletions.
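The recurring change across these files is a typing modernization: Optional[X] becomes X | None (PEP 604), Dict and Tuple become the built-in dict and tuple (PEP 585), and Callable and Sequence move from typing to collections.abc. A minimal sketch of the pattern, assuming Python 3.10+ (the union syntax in annotations requires 3.10, or from __future__ import annotations on 3.7+); the function names below are illustrative, not taken from the diff:

    from collections.abc import Callable

    # Before, in the style this commit removes:
    #   from typing import Callable, Dict, Optional, Tuple
    #   def make_split(seed: Optional[int] = None) -> Tuple[int, int]: ...
    #   def count_calls(fn: Optional[Callable] = None) -> Dict[str, int]: ...

    # After, with built-in generics and PEP 604 unions:
    def make_split(seed: int | None = None) -> tuple[int, int]:
        # Illustrative body only.
        return (seed or 0, 1)

    def count_calls(fn: Callable | None = None) -> dict[str, int]:
        # Illustrative body only.
        return {"calls": 0 if fn is None else 1}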
5 changes: 2 additions & 3 deletions research/rxrx1/data/data_utils.py
@@ -1,6 +1,5 @@
 from collections import defaultdict
 from pathlib import Path
-from typing import Optional

 import numpy as np
 import pandas as pd
@@ -32,7 +31,7 @@ def label_frequency(dataset: Rxrx1Dataset | Subset) -> None:


 def create_splits(
-    dataset: Rxrx1Dataset, seed: Optional[int] = None, train_fraction: float = 0.8
+    dataset: Rxrx1Dataset, seed: int | None = None, train_fraction: float = 0.8
 ) -> tuple[Subset, Subset]:
     """
     Splits the dataset into training and validation sets.
@@ -70,7 +69,7 @@ def create_splits(


 def load_rxrx1_data(
-    data_path: Path, client_num: int, batch_size: int, seed: Optional[int] = None, train_val_split: float = 0.8
+    data_path: Path, client_num: int, batch_size: int, seed: int | None = None, train_val_split: float = 0.8
 ) -> tuple[DataLoader, DataLoader, dict[str, int]]:

     # Read the CSV file
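For context, a usage sketch of the two updated signatures; the keyword names and return shapes come from the diff above, while the concrete path and values are placeholders and dataset is assumed to be an Rxrx1Dataset built elsewhere:

    from pathlib import Path

    # Placeholder values; only the signatures come from the diff above.
    train_subset, val_subset = create_splits(dataset, seed=42, train_fraction=0.8)

    train_loader, val_loader, label_map = load_rxrx1_data(
        data_path=Path("data/rxrx1"),  # hypothetical location
        client_num=0,
        batch_size=32,
        seed=None,  # seed is int | None, so None is allowed
    )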
4 changes: 2 additions & 2 deletions research/rxrx1/data/dataset.py
@@ -1,6 +1,6 @@
 import os
+from collections.abc import Callable
 from pathlib import Path
-from typing import Callable, Optional

 import pandas as pd
 import torch
@@ -10,7 +10,7 @@


 class Rxrx1Dataset(Dataset):
-    def __init__(self, metadata: pd.DataFrame, root: Path, dataset_type: str, transform: Optional[Callable] = None):
+    def __init__(self, metadata: pd.DataFrame, root: Path, dataset_type: str, transform: Callable | None = None):
         """
         Args:
             metadata (DataFrame): A DataFrame containing image metadata.
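A construction sketch under the new transform: Callable | None annotation; the metadata path, root directory, and dataset_type value are assumptions, only the __init__ signature is from the diff:

    from pathlib import Path

    import pandas as pd
    from torchvision import transforms

    # Hypothetical metadata file and root directory.
    metadata = pd.read_csv("data/rxrx1/metadata.csv")
    dataset = Rxrx1Dataset(
        metadata=metadata,
        root=Path("data/rxrx1"),
        dataset_type="train",  # assumed value
        transform=transforms.ToTensor(),  # any Callable, or None to skip transforms
    )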
35 changes: 21 additions & 14 deletions research/rxrx1/ditto/client.py
@@ -1,8 +1,8 @@
 import argparse
 import os
+from collections.abc import Sequence
 from logging import INFO
 from pathlib import Path
-from typing import Dict, Optional, Sequence, Tuple

 import flwr as fl
 import torch
@@ -14,9 +14,10 @@
 from torch.utils.data import DataLoader
 from torchvision import models

-from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
-from fl4health.checkpointing.client_module import ClientCheckpointModule
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
+from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
 from fl4health.clients.ditto_client import DittoClient
+from fl4health.reporting.base_reporter import BaseReporter
 from fl4health.utils.config import narrow_dict_type
 from fl4health.utils.losses import LossMeterType
 from fl4health.utils.metrics import Accuracy, Metric
@@ -33,14 +34,20 @@ def __init__(
         client_number: int,
         learning_rate: float,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
-        checkpointer: Optional[ClientCheckpointModule] = None,
+        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
+        reporters: Sequence[BaseReporter] | None = None,
+        progress_bar: bool = False,
+        client_name: str | None = None,
     ) -> None:
         super().__init__(
             data_path=data_path,
             metrics=metrics,
             device=device,
             loss_meter_type=loss_meter_type,
-            checkpointer=checkpointer,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            reporters=reporters,
+            progress_bar=progress_bar,
+            client_name=client_name,
         )
         self.client_number = client_number
         self.learning_rate: float = learning_rate
@@ -53,15 +60,15 @@ def setup_client(self, config: Config) -> None:
         assert 0 <= self.client_number < num_clients
         super().setup_client(config)

-    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
+    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         batch_size = narrow_dict_type(config, "batch_size", int)
         train_loader, val_loader, _ = load_rxrx1_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size, seed=self.client_number
         )

         return train_loader, val_loader

-    def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
+    def get_test_data_loader(self, config: Config) -> DataLoader | None:
         batch_size = narrow_dict_type(config, "batch_size", int)
         test_loader, _ = load_rxrx1_test_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size
@@ -72,7 +79,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
     def get_criterion(self, config: Config) -> _Loss:
         return torch.nn.CrossEntropyLoss()

-    def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
+    def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
         # Following the implementation in pFL-Bench : A Comprehensive Benchmark for Personalized
         # Federated Learning (https://arxiv.org/pdf/2405.17724) for cifar10 dataset we use SGD optimizer
         global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9)
@@ -145,14 +152,14 @@ def get_model(self, config: Config) -> nn.Module:
     pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
     post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
     post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
-    checkpointer = ClientCheckpointModule(
+    checkpoint_and_state_module = ClientCheckpointAndStateModule(
         pre_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
         ],
         post_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
         ],
     )

@@ -163,7 +170,7 @@ def get_model(self, config: Config) -> nn.Module:
         device=device,
         client_number=args.client_number,
         learning_rate=args.learning_rate,
-        checkpointer=checkpointer,
+        checkpoint_and_state_module=checkpoint_and_state_module,
     )

     fl.client.start_client(server_address=args.server_address, client=client.to_client())
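The client migration above has two parts: the checkpointer classes gain a TorchModule spelling (BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer), and the client keyword changes from checkpointer to checkpoint_and_state_module, alongside new optional reporters, progress_bar, and client_name arguments. A condensed before/after sketch; checkpoint_dir, the file names, and the client class name are placeholders, while the checkpointing class and keyword names are exactly those in the diff:

    # Before this commit:
    #   checkpointer = ClientCheckpointModule(
    #       pre_aggregation=[BestLossTorchCheckpointer(checkpoint_dir, "best.pkl")],
    #   )
    #   client = SomeDittoClient(..., checkpointer=checkpointer)  # hypothetical class name

    # After this commit:
    checkpoint_and_state_module = ClientCheckpointAndStateModule(
        pre_aggregation=[
            BestLossTorchModuleCheckpointer(checkpoint_dir, "pre_agg_best.pkl"),
            LatestTorchModuleCheckpointer(checkpoint_dir, "pre_agg_last.pkl"),
        ],
        post_aggregation=[
            BestLossTorchModuleCheckpointer(checkpoint_dir, "post_agg_best.pkl"),
            LatestTorchModuleCheckpointer(checkpoint_dir, "post_agg_last.pkl"),
        ],
    )
    # client = SomeDittoClient(..., checkpoint_and_state_module=checkpoint_and_state_module)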
4 changes: 2 additions & 2 deletions research/rxrx1/ditto/server.py
@@ -1,7 +1,7 @@
 import argparse
 from functools import partial
 from logging import INFO
-from typing import Any, Dict
+from typing import Any

 import flwr as fl
 from flwr.common.logger import log
@@ -33,7 +33,7 @@ def fit_config(
     }


-def main(config: Dict[str, Any], server_address: str, lam: float) -> None:
+def main(config: dict[str, Any], server_address: str, lam: float) -> None:
     # This function will be used to produce a config that is sent to each client to initialize their own environment
     fit_config_fn = partial(
         fit_config,
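A hypothetical driver for the modernized entry point; loading a YAML config and its keys are assumptions (the diff above only changes the annotation to dict[str, Any]):

    import yaml

    # Hypothetical config file; main()'s signature is from the diff above.
    with open("config.yaml") as f:
        config = yaml.safe_load(f)  # yields a plain dict[str, Any]

    main(config, server_address="0.0.0.0:8080", lam=1.0)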
35 changes: 21 additions & 14 deletions research/rxrx1/ditto_deep_mmd/client.py
@@ -1,9 +1,9 @@
 import argparse
 import os
 from collections import OrderedDict
+from collections.abc import Sequence
 from logging import INFO
 from pathlib import Path
-from typing import Dict, Optional, Sequence, Tuple

 import flwr as fl
 import torch
@@ -15,9 +15,10 @@
 from torch.utils.data import DataLoader
 from torchvision import models

-from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
-from fl4health.checkpointing.client_module import ClientCheckpointModule
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
+from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
 from fl4health.clients.deep_mmd_clients.ditto_deep_mmd_client import DittoDeepMmdClient
+from fl4health.reporting.base_reporter import BaseReporter
 from fl4health.utils.config import narrow_dict_type
 from fl4health.utils.losses import LossMeterType
 from fl4health.utils.metrics import Accuracy, Metric
@@ -41,17 +42,23 @@ def __init__(
         client_number: int,
         learning_rate: float,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
+        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
+        reporters: Sequence[BaseReporter] | None = None,
+        progress_bar: bool = False,
+        client_name: str | None = None,
         deep_mmd_loss_weight: float = 10,
         deep_mmd_loss_depth: int = 1,
-        checkpointer: Optional[ClientCheckpointModule] = None,
     ) -> None:
         feature_extraction_layers_with_size = OrderedDict(list(BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :])
         super().__init__(
             data_path=data_path,
             metrics=metrics,
             device=device,
             loss_meter_type=loss_meter_type,
-            checkpointer=checkpointer,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            reporters=reporters,
+            progress_bar=progress_bar,
+            client_name=client_name,
             deep_mmd_loss_weight=deep_mmd_loss_weight,
             feature_extraction_layers_with_size=feature_extraction_layers_with_size,
         )
@@ -66,15 +73,15 @@ def setup_client(self, config: Config) -> None:
         assert 0 <= self.client_number < num_clients
         super().setup_client(config)

-    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
+    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         batch_size = narrow_dict_type(config, "batch_size", int)
         train_loader, val_loader, _ = load_rxrx1_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size, seed=self.client_number
         )

         return train_loader, val_loader

-    def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
+    def get_test_data_loader(self, config: Config) -> DataLoader | None:
         batch_size = narrow_dict_type(config, "batch_size", int)
         test_loader, _ = load_rxrx1_test_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size
@@ -85,7 +92,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
     def get_criterion(self, config: Config) -> _Loss:
         return torch.nn.CrossEntropyLoss()

-    def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
+    def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
         # Following the implementation in pFL-Bench : A Comprehensive Benchmark for Personalized
         # Federated Learning (https://arxiv.org/pdf/2405.17724) for cifar10 dataset we use SGD optimizer
         global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9)
@@ -175,14 +182,14 @@ def get_model(self, config: Config) -> nn.Module:
     pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
     post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
     post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
-    checkpointer = ClientCheckpointModule(
+    checkpoint_and_state_module = ClientCheckpointAndStateModule(
         pre_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
         ],
         post_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
         ],
     )

@@ -193,7 +200,7 @@ def get_model(self, config: Config) -> nn.Module:
         device=device,
         client_number=args.client_number,
         learning_rate=args.learning_rate,
-        checkpointer=checkpointer,
+        checkpoint_and_state_module=checkpoint_and_state_module,
         deep_mmd_loss_depth=args.deep_mmd_loss_depth,
         deep_mmd_loss_weight=args.mu,
     )
4 changes: 2 additions & 2 deletions research/rxrx1/ditto_deep_mmd/server.py
@@ -1,7 +1,7 @@
 import argparse
 from functools import partial
 from logging import INFO
-from typing import Any, Dict
+from typing import Any

 import flwr as fl
 from flwr.common.logger import log
@@ -33,7 +33,7 @@ def fit_config(
     }


-def main(config: Dict[str, Any], server_address: str, lam: float) -> None:
+def main(config: dict[str, Any], server_address: str, lam: float) -> None:
     # This function will be used to produce a config that is sent to each client to initialize their own environment
     fit_config_fn = partial(
         fit_config,
35 changes: 21 additions & 14 deletions research/rxrx1/ditto_mkmmd/client.py
@@ -1,8 +1,8 @@
 import argparse
 import os
+from collections.abc import Sequence
 from logging import INFO
 from pathlib import Path
-from typing import Dict, Optional, Sequence, Tuple

 import flwr as fl
 import torch
@@ -14,9 +14,10 @@
 from torch.utils.data import DataLoader
 from torchvision import models

-from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
-from fl4health.checkpointing.client_module import ClientCheckpointModule
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
+from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
 from fl4health.clients.mkmmd_clients.ditto_mkmmd_client import DittoMkMmdClient
+from fl4health.reporting.base_reporter import BaseReporter
 from fl4health.utils.config import narrow_dict_type
 from fl4health.utils.losses import LossMeterType
 from fl4health.utils.metrics import Accuracy, Metric
@@ -39,14 +40,20 @@ def __init__(
         feature_l2_norm_weight: float = 1,
         mkmmd_loss_depth: int = 1,
         beta_global_update_interval: int = 20,
-        checkpointer: Optional[ClientCheckpointModule] = None,
+        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
+        reporters: Sequence[BaseReporter] | None = None,
+        progress_bar: bool = False,
+        client_name: str | None = None,
     ) -> None:
         super().__init__(
             data_path=data_path,
             metrics=metrics,
             device=device,
             loss_meter_type=loss_meter_type,
-            checkpointer=checkpointer,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            reporters=reporters,
+            progress_bar=progress_bar,
+            client_name=client_name,
             mkmmd_loss_weight=mkmmd_loss_weight,
             feature_extraction_layers=BASELINE_LAYERS[-1 * mkmmd_loss_depth :],
             feature_l2_norm_weight=feature_l2_norm_weight,
@@ -66,15 +73,15 @@ def setup_client(self, config: Config) -> None:
         assert 0 <= self.client_number < num_clients
         super().setup_client(config)

-    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
+    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         batch_size = narrow_dict_type(config, "batch_size", int)
         train_loader, val_loader, _ = load_rxrx1_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size, seed=self.client_number
         )

         return train_loader, val_loader

-    def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
+    def get_test_data_loader(self, config: Config) -> DataLoader | None:
         batch_size = narrow_dict_type(config, "batch_size", int)
         test_loader, _ = load_rxrx1_test_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size
@@ -85,7 +92,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
     def get_criterion(self, config: Config) -> _Loss:
         return torch.nn.CrossEntropyLoss()

-    def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
+    def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
         # Following the implementation in pFL-Bench : A Comprehensive Benchmark for Personalized
         # Federated Learning (https://arxiv.org/pdf/2405.17724) for cifar10 dataset we use SGD optimizer
         global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9)
@@ -192,14 +199,14 @@ def get_model(self, config: Config) -> nn.Module:
     pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
     post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
     post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
-    checkpointer = ClientCheckpointModule(
+    checkpoint_and_state_module = ClientCheckpointAndStateModule(
         pre_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
         ],
         post_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
         ],
     )

@@ -210,7 +217,7 @@ def get_model(self, config: Config) -> nn.Module:
         device=device,
         client_number=args.client_number,
         learning_rate=args.learning_rate,
-        checkpointer=checkpointer,
+        checkpoint_and_state_module=checkpoint_and_state_module,
         feature_l2_norm_weight=args.l2,
         mkmmd_loss_depth=args.mkmmd_loss_depth,
         mkmmd_loss_weight=args.mu,