Add rxrx1 ditto experiments
sanaAyrml committed Jan 8, 2025
1 parent 3d03729 commit d306951
Showing 19 changed files with 1,756 additions and 0 deletions.
Empty file.
171 changes: 171 additions & 0 deletions research/rxrx1/ditto/client.py
@@ -0,0 +1,171 @@
import argparse
import os
from logging import INFO
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.logger import log
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision import models

from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.clients.ditto_client import DittoClient
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.random import set_all_random_seeds
from research.rxrx1.data.data_utils import load_rxrx1_data, load_rxrx1_test_data


class Rxrx1DittoClient(DittoClient):
def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
client_number: int,
learning_rate: float,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[ClientCheckpointModule] = None,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
checkpointer=checkpointer,
)
self.client_number = client_number
self.learning_rate: float = learning_rate

log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}")

def setup_client(self, config: Config) -> None:
# Check if the client number is within the range of the total number of clients
num_clients = narrow_dict_type(config, "n_clients", int)
assert 0 <= self.client_number < num_clients
super().setup_client(config)

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, val_loader, _ = load_rxrx1_data(
data_path=self.data_path, client_num=self.client_number, batch_size=batch_size, seed=self.client_number
)

return train_loader, val_loader

def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
test_loader, _ = load_rxrx1_test_data(
data_path=self.data_path, client_num=self.client_number, batch_size=batch_size
)

return test_loader

def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()

def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
        # Following the implementation in pFL-Bench: A Comprehensive Benchmark for Personalized
        # Federated Learning (https://arxiv.org/pdf/2405.17724) for the CIFAR-10 dataset, we use the SGD optimizer.
global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9)
local_optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
return {"global": global_optimizer, "local": local_optimizer}

def get_model(self, config: Config) -> nn.Module:
return models.resnet18(pretrained=True).to(self.device)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FL Client Main")
parser.add_argument(
"--artifact_dir",
action="store",
type=str,
help="Path to save client artifacts such as logs and model checkpoints",
required=True,
)
parser.add_argument(
"--dataset_dir",
action="store",
type=str,
help="Path to the preprocessed Rxrx1 Dataset",
required=True,
)
parser.add_argument(
"--run_name",
action="store",
help="Name of the run, model checkpoints will be saved under a subfolder with this name",
required=True,
)
parser.add_argument(
"--server_address",
action="store",
type=str,
help="Server Address for the clients to communicate with the server through",
default="0.0.0.0:8080",
)
parser.add_argument(
"--client_number",
action="store",
type=int,
help="Number of the client for dataset loading (should be 0-3 for Rxrx1)",
required=True,
)
parser.add_argument(
"--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1
)
parser.add_argument(
"--seed",
action="store",
type=int,
help="Seed for the random number generators across python, torch, and numpy",
required=False,
)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Device to be used: {device}")
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Learning Rate: {args.learning_rate}")

# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

# Adding extensive checkpointing for the client
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl"
pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
checkpointer = ClientCheckpointModule(
pre_aggregation=[
BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
],
post_aggregation=[
BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
],
)

data_path = Path(args.dataset_dir)
client = Rxrx1DittoClient(
data_path=data_path,
metrics=[Accuracy("accuracy")],
device=device,
client_number=args.client_number,
learning_rate=args.learning_rate,
checkpointer=checkpointer,
)

fl.client.start_client(server_address=args.server_address, client=client.to_client())
# Shutdown the client gracefully
client.shutdown()
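A minimal sketch, not included in the diff above, for sanity-checking the RxRx1 data loaders used by Rxrx1DittoClient before launching a full FL run. The dataset path, client number, batch size, and seed below are placeholder assumptions, and the loader is assumed to yield (image, label) batches.

from pathlib import Path

from research.rxrx1.data.data_utils import load_rxrx1_data

# Hypothetical dataset path and client 0 with the batch size from config.yaml.
train_loader, val_loader, _ = load_rxrx1_data(
    data_path=Path("/datasets/rxrx1"), client_num=0, batch_size=32, seed=0
)
# Inspect a single batch to confirm tensor shapes before federated training.
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)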
7 changes: 7 additions & 0 deletions research/rxrx1/ditto/config.yaml
@@ -0,0 +1,7 @@
# Parameters that describe the server
n_server_rounds: 10 # The number of FL rounds to run

# Parameters that describe the clients
n_clients: 4 # The number of clients in the FL experiment
local_epochs: 5 # The number of local epochs each client completes per server round
batch_size: 32 # The batch size for client training
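A minimal sketch of how the client code above consumes these values, assuming the server forwards the YAML keys to clients in the Flower Config dictionary and that narrow_dict_type accepts any string-keyed dictionary. The config path is the file added in this commit.

import yaml

from fl4health.utils.config import narrow_dict_type

# Read the YAML file and narrow the values to the types the client expects.
with open("research/rxrx1/ditto/config.yaml") as f:
    config = yaml.safe_load(f)

n_clients = narrow_dict_type(config, "n_clients", int)    # 4
batch_size = narrow_dict_type(config, "batch_size", int)  # 32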
163 changes: 163 additions & 0 deletions research/rxrx1/ditto/run_fold_experiment.slrm
@@ -0,0 +1,163 @@
#!/bin/bash

#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --mem=32G
#SBATCH --partition=a40
#SBATCH --qos=m2
#SBATCH --job-name=fl_five_fold_exp
#SBATCH --output=%j_%x.out
#SBATCH --error=%j_%x.err
#SBATCH --time=4:00:00

###############################################
# Usage:
#
# sbatch research/rxrx1/ditto/run_fold_experiment.slrm \
# path_to_config.yaml \
# path_to_folder_for_artifacts/ \
# path_to_folder_for_dataset/ \
# path_to_desired_venv/ \
# client_side_learning_rate_value \
# lambda value \
# server_address
#
# Example:
# sbatch research/rxrx1/ditto/run_fold_experiment.slrm \
# research/rxrx1/ditto/config.yaml \
# research/rxrx1/ditto/hp_results/ \
# /datasets/rxrx1 \
# /h/demerson/vector_repositories/fl4health_env/ \
# 0.0001 \
# 0.01 \
# 0.0.0.0:8080
#
# Notes:
# 1) The sbatch command above should be run from the top level directory of the repository.
# 2) This example runs ditto. As such the data paths and python launch commands are hardcoded. If you want to change
# the example you run, you need to explicitly modify the code below.
# 3) The logging directories need to ALREADY EXIST. The script does not create them.
###############################################

# Note:
# ntasks: Total number of processes to use across world
# ntasks-per-node: How many processes each node should create

# Set NCCL options
# export NCCL_DEBUG=INFO
# The InfiniBand transport NCCL uses to communicate between GPU workers is not available on Vector's cluster.
# Disable it for Slurm jobs.
export NCCL_IB_DISABLE=1

if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \
[[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then
echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}"
export NCCL_SOCKET_IFNAME=bond0
fi


export CUBLAS_WORKSPACE_CONFIG=:4096:8
# Process Inputs

SERVER_CONFIG_PATH=$1
ARTIFACT_DIR=$2
DATASET_DIR=$3
VENV_PATH=$4
CLIENT_LR=$5
LAM_VALUE=$6
SERVER_ADDRESS=$7

# Create the artifact directory
mkdir "${ARTIFACT_DIR}"

RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" )
SEEDS=( 2021 2022 2023 2024 2025 )

echo "Python Venv Path: ${VENV_PATH}"

echo "World size: ${SLURM_NTASKS}"
echo "Number of nodes: ${SLURM_NNODES}"
NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
echo "GPUs per node: ${NUM_GPUs}"

# Source the environment
source ${VENV_PATH}bin/activate
echo "Active Environment:"
which python

for ((i=0; i<${#RUN_NAMES[@]}; i++));
do
RUN_NAME="${RUN_NAMES[i]}"
SEED="${SEEDS[i]}"
# create the run directory
RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/"
echo "Starting Run and logging artifcats at ${RUN_DIR}"
if [ -d "${RUN_DIR}" ]
then
# Directory already exists, we check if the done.out file exists
if [ -f "${RUN_DIR}done.out" ]
then
# Done file already exists so we skip this run
echo "Run already completed. Skipping Run."
continue
else
            # Done file doesn't exist (assume pre-emption happened)
            # Delete the partially finished contents and start over
            echo "Run did not finish correctly. Re-running."
rm -r "${RUN_DIR}"
mkdir "${RUN_DIR}"
fi
else
# Directory doesn't exist yet, so we create it.
echo "Run directory does not exist. Creating it."
mkdir "${RUN_DIR}"
fi

SERVER_OUTPUT_FILE="${RUN_DIR}server.out"

# Start the server, divert the outputs to a server file

echo "Server logging at: ${SERVER_OUTPUT_FILE}"
echo "Launching Server"

nohup python -m research.rxrx1.ditto.server \
--config_path ${SERVER_CONFIG_PATH} \
--server_address ${SERVER_ADDRESS} \
--seed ${SEED} \
--lam ${LAM_VALUE} \
> ${SERVER_OUTPUT_FILE} 2>&1 &

# Sleep for 20 seconds to allow the server to come up.
sleep 20

# Start n number of clients and divert the outputs to their own files
n_clients=4
for (( c=0; c<${n_clients}; c++ ))
do
CLIENT_NAME="client_${c}"
echo "Launching ${CLIENT_NAME}"

CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out"
echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}"
nohup python -m research.rxrx1.ditto.client \
--artifact_dir ${ARTIFACT_DIR} \
--dataset_dir ${DATASET_DIR} \
--run_name ${RUN_NAME} \
--client_number ${c} \
--learning_rate ${CLIENT_LR} \
--server_address ${SERVER_ADDRESS} \
--seed ${SEED} \
> ${CLIENT_LOG_PATH} 2>&1 &
done

echo "FL Processes Running"

wait

# Create a file that verifies that the Run concluded properly
touch "${RUN_DIR}done.out"
echo "Finished FL Processes"

done
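For a quick post-hoc check of which runs wrote the done.out marker produced by the loop above, a small Python helper like the following sketch could be used. It assumes the artifact directory layout created by the script; the hp_results path is a placeholder taken from the usage example.

import os

# Placeholder artifact directory from the usage example above.
artifact_dir = "research/rxrx1/ditto/hp_results/"
# Run1 through Run5 mirror the RUN_NAMES array in the Slurm script.
for run_name in ["Run1", "Run2", "Run3", "Run4", "Run5"]:
    done_marker = os.path.join(artifact_dir, run_name, "done.out")
    status = "complete" if os.path.exists(done_marker) else "incomplete"
    print(f"{run_name}: {status}")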