Commit: 19 changed files with 1,756 additions and 0 deletions.
@@ -0,0 +1,171 @@
import argparse
import os
from logging import INFO
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.logger import log
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision import models

from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.clients.ditto_client import DittoClient
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.random import set_all_random_seeds
from research.rxrx1.data.data_utils import load_rxrx1_data, load_rxrx1_test_data


class Rxrx1DittoClient(DittoClient):
    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        client_number: int,
        learning_rate: float,
        loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
        checkpointer: Optional[ClientCheckpointModule] = None,
    ) -> None:
        super().__init__(
            data_path=data_path,
            metrics=metrics,
            device=device,
            loss_meter_type=loss_meter_type,
            checkpointer=checkpointer,
        )
        self.client_number = client_number
        self.learning_rate: float = learning_rate

        log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}")

    def setup_client(self, config: Config) -> None:
        # Check if the client number is within the range of the total number of clients
        num_clients = narrow_dict_type(config, "n_clients", int)
        assert 0 <= self.client_number < num_clients
        super().setup_client(config)

    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
        batch_size = narrow_dict_type(config, "batch_size", int)
        train_loader, val_loader, _ = load_rxrx1_data(
            data_path=self.data_path, client_num=self.client_number, batch_size=batch_size, seed=self.client_number
        )

        return train_loader, val_loader

    def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
        batch_size = narrow_dict_type(config, "batch_size", int)
        test_loader, _ = load_rxrx1_test_data(
            data_path=self.data_path, client_num=self.client_number, batch_size=batch_size
        )

        return test_loader

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()
    def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
        # Following the implementation in pFL-Bench: A Comprehensive Benchmark for Personalized
        # Federated Learning (https://arxiv.org/pdf/2405.17724) for the CIFAR-10 dataset, we use the SGD optimizer.
        # Ditto trains two models per client, so we return one optimizer for the global model and one for the local.
        global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9)
        local_optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
        return {"global": global_optimizer, "local": local_optimizer}
    def get_model(self, config: Config) -> nn.Module:
        return models.resnet18(pretrained=True).to(self.device)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL Client Main")
    parser.add_argument(
        "--artifact_dir",
        action="store",
        type=str,
        help="Path to save client artifacts such as logs and model checkpoints",
        required=True,
    )
    parser.add_argument(
        "--dataset_dir",
        action="store",
        type=str,
        help="Path to the preprocessed Rxrx1 Dataset",
        required=True,
    )
    parser.add_argument(
        "--run_name",
        action="store",
        type=str,
        help="Name of the run. Model checkpoints will be saved under a subfolder with this name",
        required=True,
    )
    parser.add_argument(
        "--server_address",
        action="store",
        type=str,
        help="Server address through which the clients communicate with the server",
        default="0.0.0.0:8080",
    )
    parser.add_argument(
        "--client_number",
        action="store",
        type=int,
        help="Number of the client used for dataset loading (should be 0-3 for Rxrx1)",
        required=True,
    )
    parser.add_argument(
        "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1
    )
    parser.add_argument(
        "--seed",
        action="store",
        type=int,
        help="Seed for the random number generators across python, torch, and numpy",
        required=False,
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(INFO, f"Device to be used: {device}")
    log(INFO, f"Server Address: {args.server_address}")
    log(INFO, f"Learning Rate: {args.learning_rate}")

    # Set the random seed for reproducibility
    set_all_random_seeds(args.seed)

    # Add extensive checkpointing for the client: best and latest models, both pre- and post-aggregation
    checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
    pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl"
    pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
    post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
    post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
    checkpointer = ClientCheckpointModule(
        pre_aggregation=[
            BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
            LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
        ],
        post_aggregation=[
            BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
            LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
        ],
    )

    data_path = Path(args.dataset_dir)
    client = Rxrx1DittoClient(
        data_path=data_path,
        metrics=[Accuracy("accuracy")],
        device=device,
        client_number=args.client_number,
        learning_rate=args.learning_rate,
        checkpointer=checkpointer,
    )

    fl.client.start_client(server_address=args.server_address, client=client.to_client())

    # Shut down the client gracefully
    client.shutdown()
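To make the Ditto structure above concrete: get_optimizer returns a "global" and a "local" optimizer because Ditto trains two models per client. The following is a minimal sketch of the per-batch update this implies; it illustrates the Ditto algorithm itself, not FL4Health's DittoClient internals, and "lam" stands for the proximal weight the launch script passes to the server via --lam.

# Hypothetical sketch, not part of this commit: the Ditto update behind the
# "global"/"local" optimizer pair returned by get_optimizer above.
def ditto_step(global_model, local_model, global_opt, local_opt, criterion, x, y, lam):
    # Global model: plain task-loss step; these are the weights the server aggregates.
    global_opt.zero_grad()
    criterion(global_model(x), y).backward()
    global_opt.step()

    # Local (personalized) model: task loss plus a proximal pull toward the global weights.
    local_opt.zero_grad()
    task_loss = criterion(local_model(x), y)
    prox = sum(
        (lp - gp.detach()).pow(2).sum()
        for lp, gp in zip(local_model.parameters(), global_model.parameters())
    )
    (task_loss + (lam / 2.0) * prox).backward()
    local_opt.step()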
@@ -0,0 +1,7 @@
# Parameters that describe the server
n_server_rounds: 10 # The number of rounds to run FL

# Parameters that describe the clients
n_clients: 4 # The number of clients in the FL experiment
local_epochs: 5 # The number of local epochs each client completes per server round
batch_size: 32 # The batch size for client training
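For reference, here is a minimal sketch, not part of this commit, of reading this YAML on the Python side. Plain yaml.safe_load is assumed; the actual server entry point may use FL4Health's own config helpers instead.

# Hypothetical sketch of consuming config.yaml; the path and key names match the config above.
from pathlib import Path

import yaml

config = yaml.safe_load(Path("research/rxrx1/ditto/config.yaml").read_text())
assert isinstance(config["n_server_rounds"], int)  # FL rounds the server will run
assert isinstance(config["n_clients"], int)  # checked against client_number in setup_client
assert isinstance(config["batch_size"], int)  # consumed by get_data_loaders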
@@ -0,0 +1,163 @@
#!/bin/bash

#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --mem=32G
#SBATCH --partition=a40
#SBATCH --qos=m2
#SBATCH --job-name=fl_five_fold_exp
#SBATCH --output=%j_%x.out
#SBATCH --error=%j_%x.err
#SBATCH --time=4:00:00

###############################################
# Usage:
#
# sbatch research/rxrx1/ditto/run_fold_experiment.slrm \
#   path_to_config.yaml \
#   path_to_folder_for_artifacts/ \
#   path_to_folder_for_dataset/ \
#   path_to_desired_venv/ \
#   client_side_learning_rate_value \
#   lambda_value \
#   server_address
#
# Example:
# sbatch research/rxrx1/ditto/run_fold_experiment.slrm \
#   research/rxrx1/ditto/config.yaml \
#   research/rxrx1/ditto/hp_results/ \
#   /datasets/rxrx1 \
#   /h/demerson/vector_repositories/fl4health_env/ \
#   0.0001 \
#   0.01 \
#   0.0.0.0:8080
#
# Notes:
# 1) The sbatch command above should be run from the top-level directory of the repository.
# 2) This example runs Ditto. As such, the data paths and python launch commands are hardcoded.
#    If you want to change the example you run, you need to explicitly modify the code below.
# 3) The logging directories need to ALREADY EXIST. The script does not create them.
###############################################

# Note:
# ntasks: Total number of processes to use across all nodes
# ntasks-per-node: How many processes each node should create

# Set NCCL options
# export NCCL_DEBUG=INFO
# The InfiniBand transport NCCL uses to communicate between GPU workers is not provided
# on Vector's cluster, so disable it in slurm.
export NCCL_IB_DISABLE=1

if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \
    [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then
    echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}"
    export NCCL_SOCKET_IFNAME=bond0
fi

export CUBLAS_WORKSPACE_CONFIG=:4096:8
# Process inputs. Note that the usage above lists seven positional arguments: the lambda
# value is the sixth and the server address is the seventh.

SERVER_CONFIG_PATH=$1
ARTIFACT_DIR=$2
DATASET_DIR=$3
VENV_PATH=$4
CLIENT_LR=$5
LAM_VALUE=$6
SERVER_ADDRESS=$7

# Create the artifact directory
mkdir "${ARTIFACT_DIR}"

RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" )
SEEDS=( 2021 2022 2023 2024 2025 )

echo "Python Venv Path: ${VENV_PATH}"

echo "World size: ${SLURM_NTASKS}"
echo "Number of nodes: ${SLURM_NNODES}"
NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
echo "GPUs per node: ${NUM_GPUs}"

# Source the environment
source ${VENV_PATH}bin/activate
echo "Active Environment:"
which python
for ((i=0; i<${#RUN_NAMES[@]}; i++));
do
    RUN_NAME="${RUN_NAMES[i]}"
    SEED="${SEEDS[i]}"
    # Create the run directory
    RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/"
    echo "Starting Run and logging artifacts at ${RUN_DIR}"
    if [ -d "${RUN_DIR}" ]
    then
        # Directory already exists, so we check whether the done.out file exists
        if [ -f "${RUN_DIR}done.out" ]
        then
            # Done file already exists, so we skip this run
            echo "Run already completed. Skipping Run."
            continue
        else
            # Done file doesn't exist (assume pre-emption happened).
            # Delete the partially finished contents and start over.
            echo "Run did not finish correctly. Re-running."
            rm -r "${RUN_DIR}"
            mkdir "${RUN_DIR}"
        fi
    else
        # Directory doesn't exist yet, so we create it
        echo "Run directory does not exist. Creating it."
        mkdir "${RUN_DIR}"
    fi

    SERVER_OUTPUT_FILE="${RUN_DIR}server.out"

    # Start the server, diverting the outputs to a server file
    echo "Server logging at: ${SERVER_OUTPUT_FILE}"
    echo "Launching Server"

    nohup python -m research.rxrx1.ditto.server \
        --config_path ${SERVER_CONFIG_PATH} \
        --server_address ${SERVER_ADDRESS} \
        --seed ${SEED} \
        --lam ${LAM_VALUE} \
        > ${SERVER_OUTPUT_FILE} 2>&1 &

    # Sleep for 20 seconds to allow the server to come up
    sleep 20

    # Start n_clients clients and divert the outputs to their own files
    n_clients=4
    for (( c=0; c<${n_clients}; c++ ))
    do
        CLIENT_NAME="client_${c}"
        echo "Launching ${CLIENT_NAME}"

        CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out"
        echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}"
        nohup python -m research.rxrx1.ditto.client \
            --artifact_dir ${ARTIFACT_DIR} \
            --dataset_dir ${DATASET_DIR} \
            --run_name ${RUN_NAME} \
            --client_number ${c} \
            --learning_rate ${CLIENT_LR} \
            --server_address ${SERVER_ADDRESS} \
            --seed ${SEED} \
            > ${CLIENT_LOG_PATH} 2>&1 &
    done

    echo "FL Processes Running"

    wait

    # Create a file that verifies that the Run concluded properly
    touch "${RUN_DIR}done.out"
    echo "Finished FL Processes"

done