Skip to content

Commit

Permalink
Merge branch 'main' into sa_rxrx1_research
Browse files Browse the repository at this point in the history
  • Loading branch information
sanaAyrml committed Jan 9, 2025
2 parents f80ca11 + 4f3cefd commit a2da4b1
Show file tree
Hide file tree
Showing 529 changed files with 9,237 additions and 7,668 deletions.
10 changes: 9 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
- id: isort

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.1
hooks:
- id: mypy
name: mypy
Expand All @@ -46,6 +46,14 @@ repos:
- id: nbqa-flake8
- id: nbqa-mypy

- repo: local
hooks:
- id: mypy legacy type check
name: mypy legacy type check
entry: python mypy_disallow_legacy_types.py
language: python
pass_filenames: true

ci:
autofix_commit_msg: |
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
Expand Down
11 changes: 11 additions & 0 deletions CONTRIBUTING.MD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ The settings for `mypy` are in the `mypy.ini`, settings for `flake8` are contain

All of these checks and formatters are invoked by pre-commit hooks. These hooks are run remotely on GitHub. In order to ensure that your code conforms to these standards, and, therefore, passes the remote checks, you can install the pre-commit hooks to be run locally. This is done by running (with your environment active)

**Note**: We use the modern mypy types introduced in Python 3.10 and above. See some of the [documentation here](https://mypy.readthedocs.io/en/stable/builtin_types.html)

For example, this means that we're using `list[str], tuple[int, int], tuple[int, ...], dict[str, int], type[C]` as built-in types and `Iterable[int], Sequence[bool], Mapping[str, int], Callable[[...], ...]` from collections.abc (as now recommended by mypy).

We are also moving to the new Optional and Union specification style:
```python
Optional[typing_stuff] -> typing_stuff | None
Union[typing1, typing2] -> typing1 | typing2
Optional[Union[typing1, typing2]] -> typing1 | typing2 | None
```

```bash
pre-commit install
```
Expand Down
2 changes: 1 addition & 1 deletion examples/ae_examples/cvae_dim_example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from the FL4Health directory. The following arguments must be present in the spe
* `n_server_rounds`: The number of rounds to run FL
* `checkpoint_path`: path to save the best server model
* `latent_dim`: size of the latent vector in the CVAE or VAE model
* `cvae_model_path`: path to the saved CVAE model for dimesionality reduction
* `cvae_model_path`: path to the saved CVAE model for dimensionality reduction

**NOTE**: Instead of using a global CVAE for all the clients, you can pass personalized CVAE models to each client, but make sure that these models are previously trained in an FL setting, and are not very different, otherwise, that can lead the dimensionality reduction to map the data samples into different latent spaces which might increase the heterogeneity.

Expand Down
8 changes: 4 additions & 4 deletions examples/ae_examples/cvae_dim_example/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from collections.abc import Sequence
from pathlib import Path
from typing import Sequence, Tuple

import flwr as fl
import torch
Expand All @@ -15,7 +15,7 @@
from fl4health.clients.basic_client import BasicClient
from fl4health.preprocessing.autoencoders.dim_reduction import CvaeFixedConditionProcessor
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.load_data import ToNumpy, load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.random import set_all_random_seeds
from fl4health.utils.sampler import DirichletLabelBasedSampler
Expand All @@ -26,11 +26,11 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.dev
super().__init__(data_path, metrics, device)
self.condition = condition

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
cvae_model_path = Path(narrow_dict_type(config, "cvae_model_path", str))
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])
transform = transforms.Compose([ToNumpy(), transforms.ToTensor(), transforms.Lambda(torch.flatten)])
# CvaeFixedConditionProcessor is added to the data transform pipeline to encode the data samples
train_loader, val_loader, _ = load_mnist_data(
data_dir=self.data_path,
Expand Down
17 changes: 10 additions & 7 deletions examples/ae_examples/cvae_dim_example/server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import argparse
from functools import partial
from typing import Any, Dict
from typing import Any

import flwr as fl
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.models.mnist_model import MnistNet
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
Expand All @@ -32,7 +33,7 @@ def fit_config(
}


def main(config: Dict[str, Any]) -> None:
def main(config: dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand All @@ -47,7 +48,10 @@ def main(config: Dict[str, Any]) -> None:
model = MnistNet(int(config["latent_dim"]) * 2)
# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
Expand All @@ -66,10 +70,9 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
10 changes: 5 additions & 5 deletions examples/ae_examples/cvae_examples/conv_cvae_example/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from collections.abc import Sequence
from pathlib import Path
from typing import Sequence, Tuple

import flwr as fl
import torch
Expand All @@ -17,15 +17,15 @@
from fl4health.preprocessing.autoencoders.loss import VaeLoss
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.load_data import ToNumpy, load_mnist_data
from fl4health.utils.metrics import Metric
from fl4health.utils.random import set_all_random_seeds
from fl4health.utils.sampler import DirichletLabelBasedSampler


def binary_class_condition_data_converter(
data: torch.Tensor, target: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Create a condition for each data sample.
# Condition is the binary representation of the target.
binary_representation = bin(int(target))[2:] # Convert to binary and remove the '0b' prefix
Expand Down Expand Up @@ -56,11 +56,11 @@ def setup_client(self, config: Config) -> None:
assert isinstance(self.model, ConditionalVae)
self.model.unpack_input_condition = self.autoencoder_converter.get_unpacking_function()

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# To make sure pixels stay in the range [0.0, 1.0].
transform = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([ToNumpy(), transforms.ToTensor()])
# To train an autoencoder-based model we need to set the data converter.
train_loader, val_loader, _ = load_mnist_data(
data_dir=self.data_path,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -23,7 +21,7 @@ def __init__(
self.fc_mu = nn.Linear(64, latent_dim)
self.fc_logvar = nn.Linear(64, latent_dim)

def forward(self, input: torch.Tensor, condition: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x = self.conv(input)
# Flatten the tensor
x = x.view(x.size(0), -1)
Expand Down
17 changes: 10 additions & 7 deletions examples/ae_examples/cvae_examples/conv_cvae_example/server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import argparse
from functools import partial
from typing import Any, Dict
from typing import Any

import flwr as fl
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.ae_examples.cvae_examples.conv_cvae_example.models import ConvConditionalDecoder, ConvConditionalEncoder
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
Expand All @@ -31,7 +32,7 @@ def fit_config(
}


def main(config: Dict[str, Any]) -> None:
def main(config: dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand All @@ -48,7 +49,10 @@ def main(config: Dict[str, Any]) -> None:

# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
Expand All @@ -67,10 +71,9 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
8 changes: 4 additions & 4 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from collections.abc import Sequence
from pathlib import Path
from typing import Sequence, Tuple

import flwr as fl
import torch
Expand All @@ -17,7 +17,7 @@
from fl4health.preprocessing.autoencoders.loss import VaeLoss
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.load_data import ToNumpy, load_mnist_data
from fl4health.utils.metrics import Metric
from fl4health.utils.random import set_all_random_seeds
from fl4health.utils.sampler import DirichletLabelBasedSampler
Expand All @@ -44,13 +44,13 @@ def setup_client(self, config: Config) -> None:
assert isinstance(self.model, ConditionalVae)
self.model.unpack_input_condition = self.autoencoder_converter.get_unpacking_function()

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# ToTensor transform is used to make sure pixels stay in the range [0.0, 1.0].
# Flattening the image data to match the input shape of the model.
flatten_transform = transforms.Lambda(lambda x: torch.flatten(x))
transform = transforms.Compose([transforms.ToTensor(), flatten_transform])
transform = transforms.Compose([ToNumpy(), transforms.ToTensor(), flatten_transform])
train_loader, val_loader, _ = load_mnist_data(
data_dir=self.data_path,
batch_size=batch_size,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -19,7 +17,7 @@ def __init__(
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_logvar = nn.Linear(256, latent_dim)

def forward(self, input: torch.Tensor, condition: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
input = torch.cat((input, condition), dim=-1)
x = F.relu(self.fc1(input))
x = F.relu(self.fc2(x))
Expand Down
17 changes: 10 additions & 7 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import argparse
from functools import partial
from typing import Any, Dict
from typing import Any

import flwr as fl
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.ae_examples.cvae_examples.mlp_cvae_example.models import MnistConditionalDecoder, MnistConditionalEncoder
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
Expand All @@ -31,7 +32,7 @@ def fit_config(
}


def main(config: Dict[str, Any]) -> None:
def main(config: dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand All @@ -48,7 +49,10 @@ def main(config: Dict[str, Any]) -> None:

# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
Expand All @@ -67,10 +71,9 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
7 changes: 3 additions & 4 deletions examples/ae_examples/fedprox_vae_example/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
from pathlib import Path
from typing import Tuple

import flwr as fl
import torch
Expand All @@ -17,16 +16,16 @@
from fl4health.preprocessing.autoencoders.loss import VaeLoss
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.load_data import ToNumpy, load_mnist_data
from fl4health.utils.sampler import DirichletLabelBasedSampler


class VaeFedProxClient(FedProxClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# Flattening the input images to use an MLP-based variational autoencoder.
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])
transform = transforms.Compose([ToNumpy(), transforms.ToTensor(), transforms.Lambda(torch.flatten)])
# Create and pass the autoencoder data converter to the data loader.
self.autoencoder_converter = AutoEncoderDatasetConverter(condition=None)
train_loader, val_loader, _ = load_mnist_data(
Expand Down
2 changes: 1 addition & 1 deletion examples/ae_examples/fedprox_vae_example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ batch_size: 32 # The batch size for client training

# FedProx variables
adaptive_proximal_weight: False # Whether to use adaptive proximal weight or not
proximal_weight : 0.1 # The proximal weight
initial_proximal_weight : 0.1 # The proximal weight

# Checkpointing
checkpoint_path: "examples/ae_examples/fedprox_vae_example"
Expand Down
4 changes: 1 addition & 3 deletions examples/ae_examples/fedprox_vae_example/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -18,7 +16,7 @@ def __init__(
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_logvar = nn.Linear(256, latent_dim)

def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x = F.relu(self.fc1(input))
x = F.relu(self.fc2(x))
return self.fc_mu(x), self.fc_logvar(x)
Expand Down
Loading

0 comments on commit a2da4b1

Please sign in to comment.