diff --git a/examples/ae_examples/cvae_dim_example/client.py b/examples/ae_examples/cvae_dim_example/client.py index 5e5577849..c7d21fff5 100644 --- a/examples/ae_examples/cvae_dim_example/client.py +++ b/examples/ae_examples/cvae_dim_example/client.py @@ -22,8 +22,8 @@ class CvaeDimClient(BasicClient): - def __init__(self, data_path: Path, metrics: Sequence[Metric], DEVICE: torch.device, condition: torch.Tensor): - super().__init__(data_path, metrics, DEVICE) + def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.device, condition: torch.Tensor): + super().__init__(data_path, metrics, device) self.condition = condition def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: @@ -64,11 +64,11 @@ def get_model(self, config: Config) -> nn.Module: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) set_all_random_seeds(42) # Creating the condition vector used for training this CVAE. 
condition_vector = torch.nn.functional.one_hot(torch.tensor(args.condition), num_classes=args.num_conditions) - client = CvaeDimClient(data_path, [Accuracy("accuracy")], DEVICE, condition_vector) + client = CvaeDimClient(data_path, [Accuracy("accuracy")], device, condition_vector) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/ae_examples/cvae_examples/conv_cvae_example/client.py b/examples/ae_examples/cvae_examples/conv_cvae_example/client.py index ec9dc85c0..aa2aed303 100644 --- a/examples/ae_examples/cvae_examples/conv_cvae_example/client.py +++ b/examples/ae_examples/cvae_examples/conv_cvae_example/client.py @@ -36,8 +36,8 @@ def binary_class_condition_data_converter( class CondConvAutoEncoderClient(BasicClient): - def __init__(self, data_path: Path, metrics: Sequence[Metric], DEVICE: torch.device) -> None: - super().__init__(data_path, metrics, DEVICE) + def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.device) -> None: + super().__init__(data_path, metrics, device) # To train an autoencoder-based model we need to define a data converter that prepares the data # for self-supervised learning, concatenates the inputs and condition (packing) to let the data # fit into the training pipeline, and unpacks the input from condition for the model inference. 
@@ -93,8 +93,8 @@ def get_model(self, config: Config) -> nn.Module: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() set_all_random_seeds(42) - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = CondConvAutoEncoderClient(data_path=data_path, metrics=[], DEVICE=DEVICE) + client = CondConvAutoEncoderClient(data_path=data_path, metrics=[], device=device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py b/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py index 5840d9e28..866197d54 100644 --- a/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py +++ b/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py @@ -25,9 +25,9 @@ class CondAutoEncoderClient(BasicClient): def __init__( - self, data_path: Path, metrics: Sequence[Metric], DEVICE: torch.device, condition: torch.Tensor + self, data_path: Path, metrics: Sequence[Metric], device: torch.device, condition: torch.Tensor ) -> None: - super().__init__(data_path, metrics, DEVICE) + super().__init__(data_path, metrics, device) # In this example, condition is based on client ID. self.condition_vector = condition # To train an autoencoder-based model we need to define a data converter that prepares the data @@ -96,13 +96,13 @@ def get_model(self, config: Config) -> nn.Module: ) args = parser.parse_args() set_all_random_seeds(42) - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) # Create the condition vector. This creation needs to be "consistent" across clients. # In this example, condition is based on client ID. 
# Client should decide how they want to create their condition vector. # Here we use simple one_hot_encoding but it can be any vector. condition_vector = torch.nn.functional.one_hot(torch.tensor(args.condition), num_classes=args.num_conditions) - client = CondAutoEncoderClient(data_path, [], DEVICE, condition_vector) + client = CondAutoEncoderClient(data_path, [], device, condition_vector) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/ae_examples/fedprox_vae_example/client.py b/examples/ae_examples/fedprox_vae_example/client.py index 9567fff37..69a5e3bc3 100644 --- a/examples/ae_examples/fedprox_vae_example/client.py +++ b/examples/ae_examples/fedprox_vae_example/client.py @@ -60,8 +60,8 @@ def get_model(self, config: Config) -> nn.Module: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = VaeFedProxClient(data_path, [], DEVICE) + client = VaeFedProxClient(data_path, [], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/apfl_example/client.py b/examples/apfl_example/client.py index 2a5eebdd8..75818a3d9 100644 --- a/examples/apfl_example/client.py +++ b/examples/apfl_example/client.py @@ -53,12 +53,12 @@ def get_criterion(self, config: Config) -> _Loss: args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) + client = 
MnistApflClient(data_path, [Accuracy()], device, reporters=[JsonReporter()]) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() # This will tell the JsonReporter to dump data diff --git a/examples/basic_example/client.py b/examples/basic_example/client.py index 61c801df8..5d44e5db8 100644 --- a/examples/basic_example/client.py +++ b/examples/basic_example/client.py @@ -43,8 +43,8 @@ def get_model(self, config: Config) -> nn.Module: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE) + client = CifarClient(data_path, [Accuracy("accuracy")], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/ditto_example/client.py b/examples/ditto_example/client.py index 5c22b32db..b91c572d3 100644 --- a/examples/ditto_example/client.py +++ b/examples/ditto_example/client.py @@ -61,15 +61,15 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistDittoClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) + client = MnistDittoClient(data_path, [Accuracy()], device, reporters=[JsonReporter()]) fl.client.start_client(server_address=args.server_address, client=client.to_client()) # Shutdown the client 
gracefully diff --git a/examples/docker_basic_example/fl_client/client.py b/examples/docker_basic_example/fl_client/client.py index f2bbe7a93..5bf59948c 100644 --- a/examples/docker_basic_example/fl_client/client.py +++ b/examples/docker_basic_example/fl_client/client.py @@ -39,7 +39,7 @@ def setup_client(self, config: Config) -> None: args = parser.parse_args() # Load model and data - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE) + client = CifarClient(data_path, [Accuracy("accuracy")], device) fl.client.start_client(server_address="fl_server:8080", client=client.to_client()) diff --git a/examples/dp_fed_examples/client_level_dp/client.py b/examples/dp_fed_examples/client_level_dp/client.py index a6935a95e..090ad6013 100644 --- a/examples/dp_fed_examples/client_level_dp/client.py +++ b/examples/dp_fed_examples/client_level_dp/client.py @@ -40,7 +40,7 @@ def get_criterion(self, config: Config) -> _Loss: # Load model and data data_path = Path(args.dataset_path) - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + client = CifarClient(data_path, [Accuracy("accuracy")], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/dp_fed_examples/client_level_dp_weighted/client.py b/examples/dp_fed_examples/client_level_dp_weighted/client.py index 7db76386d..8d2130894 100644 --- a/examples/dp_fed_examples/client_level_dp_weighted/client.py +++ b/examples/dp_fed_examples/client_level_dp_weighted/client.py @@ -40,8 +40,8 @@ def get_criterion(self, config: Config) -> _Loss: args = parser.parse_args() # Load model and data - DEVICE = 
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = HospitalClient(data_path, [Accuracy("accuracy")], DEVICE) + client = HospitalClient(data_path, [Accuracy("accuracy")], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/dp_fed_examples/instance_level_dp/client.py b/examples/dp_fed_examples/instance_level_dp/client.py index 9d82dcb71..f66450125 100644 --- a/examples/dp_fed_examples/instance_level_dp/client.py +++ b/examples/dp_fed_examples/instance_level_dp/client.py @@ -52,8 +52,8 @@ def get_criterion(self, config: Config) -> _Loss: # Load model and data data_path = Path(args.dataset_path) - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE, checkpointer=checkpointer) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + client = CifarClient(data_path, [Accuracy("accuracy")], device, checkpointer=checkpointer) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/dp_scaffold_example/client.py b/examples/dp_scaffold_example/client.py index 34411faaf..f9e6f072d 100644 --- a/examples/dp_scaffold_example/client.py +++ b/examples/dp_scaffold_example/client.py @@ -41,10 +41,10 @@ def get_criterion(self, config: Config) -> _Loss: args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = MnistDPScaffoldClient(data_path=data_path, metrics=[Accuracy()], device=DEVICE) + client = MnistDPScaffoldClient(data_path=data_path, metrics=[Accuracy()], device=device) fl.client.start_client(server_address="0.0.0.0:8080", 
client=client.to_client()) client.shutdown() diff --git a/examples/dynamic_layer_exchange_example/client.py b/examples/dynamic_layer_exchange_example/client.py index 73b8fad7d..9f2ff697d 100644 --- a/examples/dynamic_layer_exchange_example/client.py +++ b/examples/dynamic_layer_exchange_example/client.py @@ -67,8 +67,8 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = CifarDynamicLayerClient(data_path, [Accuracy("accuracy")], DEVICE, store_initial_model=True) + client = CifarDynamicLayerClient(data_path, [Accuracy("accuracy")], device, store_initial_model=True) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/ensemble_example/client.py b/examples/ensemble_example/client.py index ad28d8104..def6af758 100644 --- a/examples/ensemble_example/client.py +++ b/examples/ensemble_example/client.py @@ -52,8 +52,8 @@ def get_criterion(self, config: Config) -> _Loss: args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = MnistEnsembleClient(data_path, [Accuracy()], DEVICE) + client = MnistEnsembleClient(data_path, [Accuracy()], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) diff --git a/examples/feature_alignment_example/client.py b/examples/feature_alignment_example/client.py index d0bbebc28..cd2e88b85 100644 --- a/examples/feature_alignment_example/client.py +++ b/examples/feature_alignment_example/client.py @@ -90,15 +90,15 @@ def get_data_frame(self, config: 
Config) -> pd.DataFrame: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") # ham_id is the id column and LOSgroupNum is the target column. - client = Mimic3TabularDataClient(data_path, [Accuracy("accuracy")], DEVICE, "hadm_id", ["LOSgroupNum"]) + client = Mimic3TabularDataClient(data_path, [Accuracy("accuracy")], device, "hadm_id", ["LOSgroupNum"]) # This call demonstrates how the user may specify a particular sklearn pipeline for a specific feature. client.preset_specific_pipeline("NumNotes", MaxAbsScaler()) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/examples/fedbn_example/client.py b/examples/fedbn_example/client.py index 815581559..30658a1a8 100644 --- a/examples/fedbn_example/client.py +++ b/examples/fedbn_example/client.py @@ -88,15 +88,15 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") if args.dataset_name in ["Barcelona", "Rosendahl", "Vienna", "UFES", "Canada"]: - client: BasicClient = SkinCancerFedBNClient(data_path, [Accuracy()], DEVICE, args.dataset_name) + client: BasicClient = SkinCancerFedBNClient(data_path, [Accuracy()], device, args.dataset_name) elif args.dataset_name == "mnist": - client = MnistFedBNClient(data_path, [Accuracy()], DEVICE) + client = MnistFedBNClient(data_path, [Accuracy()], device) else: raise ValueError( 
"Unsupported dataset name. Please choose from 'Barcelona', 'Rosendahl', \ diff --git a/examples/feddg_ga_example/client.py b/examples/feddg_ga_example/client.py index 81023fd69..1a84fbee1 100644 --- a/examples/feddg_ga_example/client.py +++ b/examples/feddg_ga_example/client.py @@ -53,12 +53,12 @@ def get_criterion(self, config: Config) -> _Loss: args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) + client = MnistApflClient(data_path, [Accuracy()], device, reporters=[JsonReporter()]) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/federated_eval_example/client.py b/examples/federated_eval_example/client.py index 53bf742b1..a3f1190c3 100644 --- a/examples/federated_eval_example/client.py +++ b/examples/federated_eval_example/client.py @@ -63,12 +63,12 @@ def get_criterion(self, config: Config) -> _Loss: data_path = Path(args.dataset_path) client_checkpoint_path = Path(args.checkpoint_path) if args.checkpoint_path else None - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") client = CifarClient( data_path=data_path, metrics=[Accuracy("accuracy")], - device=DEVICE, + device=device, model_checkpoint_path=client_checkpoint_path, ) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) diff --git a/examples/fedopt_example/client.py b/examples/fedopt_example/client.py index 6fff4dd65..a96b173c8 100644 --- a/examples/fedopt_example/client.py +++ b/examples/fedopt_example/client.py @@ -95,7 +95,7 @@ def predict( args = parser.parse_args() # Load model and data - 
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = NewsClassifierClient(data_path, [CompoundMetric("Compound Metric")], DEVICE) + client = NewsClassifierClient(data_path, [CompoundMetric("Compound Metric")], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) diff --git a/examples/fedpca_examples/dim_reduction/client.py b/examples/fedpca_examples/dim_reduction/client.py index e289331e6..f049e1777 100644 --- a/examples/fedpca_examples/dim_reduction/client.py +++ b/examples/fedpca_examples/dim_reduction/client.py @@ -66,7 +66,7 @@ def get_model(self, config: Config) -> nn.Module: parser.add_argument("--seed", action="store", type=int, help="Random seed for this client.") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) seed = args.seed @@ -74,6 +74,6 @@ def get_model(self, config: Config) -> nn.Module: # the data used in the perform_pca example, then both examples # should use the same random seed. 
set_all_random_seeds(seed) - client = MnistFedPcaClient(data_path, [Accuracy("accuracy")], DEVICE) + client = MnistFedPcaClient(data_path, [Accuracy("accuracy")], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/fedpca_examples/perform_pca/client.py b/examples/fedpca_examples/perform_pca/client.py index 26a2f596c..fa08fd882 100644 --- a/examples/fedpca_examples/perform_pca/client.py +++ b/examples/fedpca_examples/perform_pca/client.py @@ -36,7 +36,7 @@ def get_data_tensor(self, data_loader: DataLoader) -> Tensor: parser.add_argument("--seed", action="store", type=int, help="Random seed for this client.") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) components_save_path = Path(args.components_save_path) seed = args.seed @@ -45,5 +45,5 @@ def get_data_tensor(self, data_loader: DataLoader) -> Tensor: # the data used in the dim_reduction example, then both examples # should use the same random seed. 
set_all_random_seeds(seed) - client = MnistFedPCAClient(data_path=data_path, device=DEVICE, model_save_path=components_save_path) + client = MnistFedPCAClient(data_path=data_path, device=device, model_save_path=components_save_path) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) diff --git a/examples/fedper_example/client.py b/examples/fedper_example/client.py index 921f0c72e..990e4ea14 100644 --- a/examples/fedper_example/client.py +++ b/examples/fedper_example/client.py @@ -62,9 +62,9 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) minority_numbers = {int(number) for number in args.minority_numbers} - client = MnistFedPerClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers) + client = MnistFedPerClient(data_path, [Accuracy("accuracy")], device, minority_numbers) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/fedpm_example/client.py b/examples/fedpm_example/client.py index 52d1b5bfa..3e59bd56c 100644 --- a/examples/fedpm_example/client.py +++ b/examples/fedpm_example/client.py @@ -54,9 +54,9 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) minority_numbers = {int(number) for number in args.minority_numbers} - client = MnistFedPmClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers) + client = MnistFedPmClient(data_path, [Accuracy("accuracy")], device, minority_numbers) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git 
a/examples/fedprox_example/client.py b/examples/fedprox_example/client.py index a8871d03a..fe67cc53f 100644 --- a/examples/fedprox_example/client.py +++ b/examples/fedprox_example/client.py @@ -57,15 +57,15 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistFedProxClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) + client = MnistFedProxClient(data_path, [Accuracy()], device, reporters=[JsonReporter()]) fl.client.start_client(server_address=args.server_address, client=client.to_client()) # Shutdown the client gracefully diff --git a/examples/fedrep_example/client.py b/examples/fedrep_example/client.py index 5fb94daab..bb5d4479b 100644 --- a/examples/fedrep_example/client.py +++ b/examples/fedrep_example/client.py @@ -55,8 +55,8 @@ def get_criterion(self, config: Config) -> _Loss: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_dir = Path(args.dataset_path) - client = CifarFedRepClient(data_dir, [Accuracy("accuracy")], DEVICE) + client = CifarFedRepClient(data_dir, [Accuracy("accuracy")], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py b/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py index 6ba51c839..8a777b39e 100644 --- 
a/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py +++ b/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py @@ -68,8 +68,8 @@ def get_model(self, config: Config) -> nn.Module: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE) + client = CifarClient(data_path, [Accuracy("accuracy")], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py b/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py index 6ba4d00c6..07ab98aeb 100644 --- a/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py +++ b/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py @@ -91,8 +91,8 @@ def transform_target(self, target: TorchTargetType) -> TorchTargetType: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = SslCifarClient(data_path, [], DEVICE) + client = SslCifarClient(data_path, [], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/fenda_ditto_example/client.py b/examples/fenda_ditto_example/client.py index 5d555282b..8c7da0221 100644 --- a/examples/fenda_ditto_example/client.py +++ b/examples/fenda_ditto_example/client.py @@ -94,9 +94,9 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = 
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") # Set the random seed for reproducibility @@ -121,7 +121,7 @@ def get_criterion(self, config: Config) -> _Loss: client = MnistFendaDittoClient( data_path, [Accuracy()], - DEVICE, + device, args.checkpoint_path, checkpointer=checkpointer, reporters=[JsonReporter()], diff --git a/examples/fenda_example/client.py b/examples/fenda_example/client.py index ef8683357..fbcd84f87 100644 --- a/examples/fenda_example/client.py +++ b/examples/fenda_example/client.py @@ -59,9 +59,9 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) minority_numbers = {int(number) for number in args.minority_numbers} - client = MnistFendaClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers) + client = MnistFendaClient(data_path, [Accuracy("accuracy")], device, minority_numbers) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/fl_plus_local_ft_example/client.py b/examples/fl_plus_local_ft_example/client.py index 346129340..1e74ef5c4 100644 --- a/examples/fl_plus_local_ft_example/client.py +++ b/examples/fl_plus_local_ft_example/client.py @@ -40,10 +40,10 @@ def get_criterion(self, config: Config) -> _Loss: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = 
Path(args.dataset_path) metrics = [Accuracy("accuracy")] - client = CifarClient(data_path, metrics, DEVICE) + client = CifarClient(data_path, metrics, device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) # Run further local training after the federated learning has finished diff --git a/examples/flash_example/client.py b/examples/flash_example/client.py index 14822e625..ea47fbe3f 100644 --- a/examples/flash_example/client.py +++ b/examples/flash_example/client.py @@ -40,8 +40,8 @@ def get_criterion(self, config: Config) -> _Loss: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = CifarFlashClient(data_path, [Accuracy("accuracy")], DEVICE) + client = CifarFlashClient(data_path, [Accuracy("accuracy")], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/moon_example/client.py b/examples/moon_example/client.py index b3d5d73ea..4813ce48a 100644 --- a/examples/moon_example/client.py +++ b/examples/moon_example/client.py @@ -57,9 +57,9 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) minority_numbers = {int(number) for number in args.minority_numbers} - client = MnistMoonClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers) + client = MnistMoonClient(data_path, [Accuracy("accuracy")], device, minority_numbers) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/mr_mtl_example/client.py 
b/examples/mr_mtl_example/client.py index c8eda73b9..833b9ee97 100644 --- a/examples/mr_mtl_example/client.py +++ b/examples/mr_mtl_example/client.py @@ -57,15 +57,15 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistMrMtlClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) + client = MnistMrMtlClient(data_path, [Accuracy()], device, reporters=[JsonReporter()]) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/examples/nnunet_example/client.py b/examples/nnunet_example/client.py index 5df51f513..9cdd4d513 100644 --- a/examples/nnunet_example/client.py +++ b/examples/nnunet_example/client.py @@ -37,8 +37,8 @@ def main( client_name: Optional[str] = None, ) -> None: # Log device and server address - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Using device: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Using device: {device}") log(INFO, f"Using server address: {server_address}") # Load the dataset if necessary @@ -63,7 +63,7 @@ def main( name="Pseudo DICE", metric=GeneralizedDiceScore( num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False - ).to(DEVICE), + ).to(device), ), pred_transforms=[torch.sigmoid, get_segs_from_probs], ) @@ -77,7 +77,7 @@ def main( verbose=verbose, compile=compile, # BaseClient Args - device=DEVICE, + device=device, metrics=[dice], progress_bar=verbose, intermediate_client_state_dir=( diff --git 
a/examples/perfcl_example/client.py b/examples/perfcl_example/client.py index f36486f09..c659f453a 100644 --- a/examples/perfcl_example/client.py +++ b/examples/perfcl_example/client.py @@ -65,9 +65,9 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) minority_numbers = {int(number) for number in args.minority_numbers} - client = MnistPerFclClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers) + client = MnistPerFclClient(data_path, [Accuracy("accuracy")], device, minority_numbers) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/scaffold_example/client.py b/examples/scaffold_example/client.py index 19a9f4eb8..2974b1961 100644 --- a/examples/scaffold_example/client.py +++ b/examples/scaffold_example/client.py @@ -50,12 +50,12 @@ def get_criterion(self, config: Config) -> _Loss: args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistScaffoldClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) + client = MnistScaffoldClient(data_path, [Accuracy()], device, reporters=[JsonReporter()]) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/sparse_tensor_partial_exchange_example/client.py b/examples/sparse_tensor_partial_exchange_example/client.py index a961d2748..cff31dd85 100644 --- a/examples/sparse_tensor_partial_exchange_example/client.py +++ b/examples/sparse_tensor_partial_exchange_example/client.py @@ -51,8 +51,8 @@ def 
get_parameter_exchanger(self, config: Config) -> ParameterExchanger: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - client = CifarSparseCooTensorClient(data_path, [Accuracy("accuracy")], DEVICE) + client = CifarSparseCooTensorClient(data_path, [Accuracy("accuracy")], device) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/warm_up_example/fedavg_warm_up/client.py b/examples/warm_up_example/fedavg_warm_up/client.py index d3bee96f5..243ef2fb1 100644 --- a/examples/warm_up_example/fedavg_warm_up/client.py +++ b/examples/warm_up_example/fedavg_warm_up/client.py @@ -84,16 +84,16 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") # Set the random seed for reproducibility set_all_random_seeds(args.seed) # Start the client - client = MnistFedAvgClient(data_path, [Accuracy()], DEVICE, checkpoint_dir=args.checkpoint_dir) + client = MnistFedAvgClient(data_path, [Accuracy()], device, checkpoint_dir=args.checkpoint_dir) fl.client.start_client(server_address=args.server_address, client=client.to_client()) # Shutdown the client gracefully diff --git a/examples/warm_up_example/warmed_up_fedprox/client.py b/examples/warm_up_example/warmed_up_fedprox/client.py index fd2116028..8ecd22345 100644 --- a/examples/warm_up_example/warmed_up_fedprox/client.py +++ 
b/examples/warm_up_example/warmed_up_fedprox/client.py @@ -100,11 +100,11 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) pretrained_model_dir = Path(args.pretrained_model_dir) weights_mapping_path = Path(args.weights_mapping_path) if args.weights_mapping_path else None - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") # Set the random seed for reproducibility @@ -114,7 +114,7 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> client = MnistFedProxClient( data_path, [Accuracy()], - DEVICE, + device, pretrained_model_dir, weights_mapping_path, ) diff --git a/examples/warm_up_example/warmed_up_fenda/client.py b/examples/warm_up_example/warmed_up_fenda/client.py index c81e90dca..3d65a641d 100644 --- a/examples/warm_up_example/warmed_up_fenda/client.py +++ b/examples/warm_up_example/warmed_up_fenda/client.py @@ -103,11 +103,11 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) pretrained_model_dir = Path(args.pretrained_model_dir) weights_mapping_path = Path(args.weights_mapping_path) if args.weights_mapping_path else None - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") # Set the random seed for reproducibility @@ -117,7 +117,7 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> client = MnistFendaClient( data_path, [Accuracy("accuracy")], - 
DEVICE, + device, pretrained_model_dir, weights_mapping_path, ) diff --git a/fl4health/strategies/aggregate_utils.py b/fl4health/strategies/aggregate_utils.py index bbe866cbc..ecef5eaeb 100644 --- a/fl4health/strategies/aggregate_utils.py +++ b/fl4health/strategies/aggregate_utils.py @@ -47,6 +47,8 @@ def aggregate_losses(results: List[Tuple[int, float]], weighted: bool = True) -> Returns: float: the weighted or unweighted average of the loss values in the results list. """ + # Sorting the results by the loss values for numerical fluctuation determinism of the sum + results = sorted(results, key=lambda x: x[1]) if weighted: # uses flwr implementation of weighted loss averaging return weighted_loss_avg(results) diff --git a/fl4health/strategies/basic_fedavg.py b/fl4health/strategies/basic_fedavg.py index c45a03320..461de952c 100644 --- a/fl4health/strategies/basic_fedavg.py +++ b/fl4health/strategies/basic_fedavg.py @@ -12,7 +12,6 @@ Parameters, Scalar, ndarrays_to_parameters, - parameters_to_ndarrays, ) from flwr.common.logger import log from flwr.server.client_manager import ClientManager @@ -23,6 +22,7 @@ from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager from fl4health.strategies.aggregate_utils import aggregate_losses, aggregate_results from fl4health.strategies.strategy_with_poll import StrategyWithPolling +from fl4health.utils.functions import decode_and_pseudo_sort_results from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -248,12 +248,15 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. 
This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) ] + # Aggregate them in a weighted or unweighted fashion based on settings. - aggregated_arrays = aggregate_results(weights_results, self.weighted_aggregation) + aggregated_arrays = aggregate_results(decoded_and_sorted_results, self.weighted_aggregation) # Convert back to parameters parameters_aggregated = ndarrays_to_parameters(aggregated_arrays) diff --git a/fl4health/strategies/client_dp_fedavgm.py b/fl4health/strategies/client_dp_fedavgm.py index 48276c160..ae1066a5f 100644 --- a/fl4health/strategies/client_dp_fedavgm.py +++ b/fl4health/strategies/client_dp_fedavgm.py @@ -27,6 +27,7 @@ gaussian_noisy_unweighted_aggregate, gaussian_noisy_weighted_aggregate, ) +from fl4health.utils.functions import decode_and_pseudo_sort_results class ClientLevelDPFedAvgM(BasicFedAvg): @@ -195,13 +196,17 @@ def split_model_weights_and_clipping_bits( Tuple[List[Tuple[NDArrays, int]], NDArrays]: The first tuple is the set of (weights, training counts) per client. The second is a set of clipping bits, one for each client. """ + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. 
+ decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) + ] + weights_and_counts: List[Tuple[NDArrays, int]] = [] clipping_bits: NDArrays = [] - for _, fit_res in results: - sample_count = fit_res.num_examples - updated_weights, clipping_bit = self.parameter_packer.unpack_parameters( - parameters_to_ndarrays(fit_res.parameters) - ) + for weights, sample_count in decoded_and_sorted_results: + updated_weights, clipping_bit = self.parameter_packer.unpack_parameters(weights) weights_and_counts.append((updated_weights, sample_count)) clipping_bits.append(np.array(clipping_bit)) diff --git a/fl4health/strategies/fedavg_dynamic_layer.py b/fl4health/strategies/fedavg_dynamic_layer.py index 08f0999bf..c6ab977b7 100644 --- a/fl4health/strategies/fedavg_dynamic_layer.py +++ b/fl4health/strategies/fedavg_dynamic_layer.py @@ -4,20 +4,14 @@ from typing import Callable, DefaultDict, Dict, List, Optional, Tuple, Union import numpy as np -from flwr.common import ( - MetricsAggregationFn, - NDArray, - NDArrays, - Parameters, - ndarrays_to_parameters, - parameters_to_ndarrays, -) +from flwr.common import MetricsAggregationFn, NDArray, NDArrays, Parameters, ndarrays_to_parameters from flwr.common.logger import log from flwr.common.typing import FitRes, Scalar from flwr.server.client_proxy import ClientProxy from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithLayerNames from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results class FedAvgDynamicLayer(BasicFedAvg): @@ -124,13 +118,17 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. 
+ # Convert client layer weights and names into ndarrays - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) ] # For each layer of the model, perform weighted average of all received weights from clients - aggregated_params = self.aggregate(weights_results) + aggregated_params = self.aggregate(decoded_and_sorted_results) weights_names = [] weights = [] diff --git a/fl4health/strategies/fedavg_sparse_coo_tensor.py b/fl4health/strategies/fedavg_sparse_coo_tensor.py index d11d9d295..be8c5d922 100644 --- a/fl4health/strategies/fedavg_sparse_coo_tensor.py +++ b/fl4health/strategies/fedavg_sparse_coo_tensor.py @@ -4,7 +4,7 @@ from typing import Callable, DefaultDict, Dict, List, Optional, Tuple, Union import torch -from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters from flwr.common.logger import log from flwr.common.typing import FitRes, Scalar from flwr.server.client_proxy import ClientProxy @@ -12,6 +12,7 @@ from fl4health.parameter_exchange.parameter_packer import SparseCooParameterPacker from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results class FedAvgSparseCooTensor(BasicFedAvg): @@ -137,13 +138,17 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. 
+ # Convert client tensor weights and names into ndarrays - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) ] # For each tensor of the model, perform weighted average of all received weights from clients - aggregated_tensors = self.aggregate(weights_results) + aggregated_tensors = self.aggregate(decoded_and_sorted_results) tensor_names = [] selected_parameters_all_tensors = [] diff --git a/fl4health/strategies/fedavg_with_adaptive_constraint.py b/fl4health/strategies/fedavg_with_adaptive_constraint.py index 735fcdd62..b438f97e4 100644 --- a/fl4health/strategies/fedavg_with_adaptive_constraint.py +++ b/fl4health/strategies/fedavg_with_adaptive_constraint.py @@ -10,6 +10,7 @@ from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint from fl4health.strategies.aggregate_utils import aggregate_losses, aggregate_results from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results class FedAvgWithAdaptiveConstraint(BasicFedAvg): @@ -157,14 +158,18 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. 
+ decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) + ] + # Convert results with packed params of model weights and training loss weights_and_counts: List[Tuple[NDArrays, int]] = [] train_losses_and_counts: List[Tuple[int, float]] = [] - for _, fit_res in results: - sample_count = fit_res.num_examples - updated_weights, train_loss = self.parameter_packer.unpack_parameters( - parameters_to_ndarrays(fit_res.parameters) - ) + for weights, sample_count in decoded_and_sorted_results: + updated_weights, train_loss = self.parameter_packer.unpack_parameters(weights) weights_and_counts.append((updated_weights, sample_count)) train_losses_and_counts.append((sample_count, train_loss)) diff --git a/fl4health/strategies/feddg_ga.py b/fl4health/strategies/feddg_ga.py index 97e90f16a..108758707 100644 --- a/fl4health/strategies/feddg_ga.py +++ b/fl4health/strategies/feddg_ga.py @@ -1,16 +1,9 @@ from enum import Enum -from logging import WARNING +from logging import INFO, WARNING from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np -from flwr.common import ( - EvaluateIns, - MetricsAggregationFn, - NDArrays, - Parameters, - ndarrays_to_parameters, - parameters_to_ndarrays, -) +from flwr.common import EvaluateIns, MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters from flwr.common.logger import log from flwr.common.typing import EvaluateRes, FitIns, FitRes, Scalar from flwr.server.client_manager import ClientManager @@ -18,6 +11,7 @@ from flwr.server.strategy import FedAvg from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager +from fl4health.utils.functions import decode_and_pseudo_sort_results class SignalForTypeException(Exception): @@ -90,6 +84,9 @@ def __init__( if signal is None: self.signal = FairnessMetricType.signal_for_type(metric_type) + def __str__(self) -> str: + return f"Metric Type: {self.metric_type}, 
Metric Name: '{self.metric_name}', Signal: {self.signal}" + class FedDgGa(FedAvg): def __init__( @@ -184,6 +181,9 @@ def __init__( 0 < self.adjustment_weight_step_size < 1 ), f"adjustment_weight_step_size has to be between 0 and 1 ({self.adjustment_weight_step_size})" + log(INFO, f"FedDG-GA Strategy initialized with weight_step_size of {self.adjustment_weight_step_size}") + log(INFO, f"FedDG-GA Strategy initialized with FairnessMetric {self.fairness_metric}") + self.train_metrics: Dict[str, Dict[str, Scalar]] = {} self.evaluation_metrics: Dict[str, Dict[str, Scalar]] = {} self.num_rounds: Optional[int] = None @@ -323,6 +323,7 @@ def aggregate_evaluate( (Tuple[Optional[float], Dict[str, Scalar]]) A tuple containing the aggregated evaluation loss and the aggregated evaluation metrics. """ + loss_aggregated, metrics_aggregated = super().aggregate_evaluate(server_round, results, failures) self.evaluation_metrics = {} @@ -334,6 +335,7 @@ def aggregate_evaluate( # Updating the weights at the end of the training round cids = [client_proxy.cid for client_proxy, _ in results] + log(INFO, "Updating the Generalization Adjustment Weights") self.update_weights_by_ga(server_round, cids) return loss_aggregated, metrics_aggregated @@ -349,8 +351,20 @@ def weight_and_aggregate_results(self, results: List[Tuple[ClientProxy, FitRes]] (NDArrays) the weighted and aggregated results. """ + if self.adjustment_weights: + log(INFO, f"Current adjustment weights by Client ID (CID) are {self.adjustment_weights}") + else: + # If the adjustment weights dictionary doesn't exist, it means that it hasn't been initialized + # and will be below. + log(INFO, f"Current adjustment weights are all initialized to {self.initial_adjustment_weight}") + + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. 
+ decoded_and_sorted_results = decode_and_pseudo_sort_results(results) + aggregated_results: Optional[NDArrays] = None - for client_proxy, fit_res in results: + for client_proxy, weights, _ in decoded_and_sorted_results: cid = client_proxy.cid # initializing adjustment weights for this client if they don't exist yet @@ -359,12 +373,14 @@ def weight_and_aggregate_results(self, results: List[Tuple[ClientProxy, FitRes]] self.adjustment_weights[cid] = self.initial_adjustment_weight # apply adjustment weights - weighted_client_parameters = parameters_to_ndarrays(fit_res.parameters) + weighted_client_parameters = weights for i in range(len(weighted_client_parameters)): weighted_client_parameters[i] = weighted_client_parameters[i] * self.adjustment_weights[cid] # sum weighted parameters if aggregated_results is None: + # If this is the first client we're applying adjustment to, we set the results to those parameters. + # Remaining client parameters will be subsequently added to these. aggregated_results = weighted_client_parameters else: assert len(weighted_client_parameters) == len(aggregated_results) @@ -398,11 +414,19 @@ def update_weights_by_ga(self, server_round: int, cids: List[str]) -> None: generalization_gaps.append(global_model_metric_value - local_model_metric_value) + log( + INFO, + "Client ID (CID) and Generalization Gaps (G_{{hat{{D_i}}}}(theta^r)): " + f"{list(zip(cids, generalization_gaps))}", + ) + # Calculating the normalized generalization gaps generalization_gaps_ndarray = np.array(generalization_gaps) mean_generalization_gap = np.mean(generalization_gaps_ndarray) var_generalization_gaps = generalization_gaps_ndarray - mean_generalization_gap max_var_generalization_gap = np.max(np.abs(var_generalization_gaps)) + log(INFO, f"Mean Generalization Gap (mu): {mean_generalization_gap}") + log(INFO, f"Max Absolute Deviation of Generalization Gaps: {max_var_generalization_gap}") if max_var_generalization_gap == 0: log( @@ -435,6 +459,7 @@ def 
update_weights_by_ga(self, server_round: int, cids: List[str]) -> None: for cid in cids: self.adjustment_weights[cid] /= new_total_weight + log(INFO, f"New Generalization Adjustment Weights by Client ID (CID) are {self.adjustment_weights}") def get_current_weight_step_size(self, server_round: int) -> float: """ @@ -451,6 +476,9 @@ def get_current_weight_step_size(self, server_round: int) -> float: assert self.num_rounds is not None weight_step_size_decay = self.adjustment_weight_step_size / self.num_rounds weight_step_size_for_round = self.adjustment_weight_step_size - ((server_round - 1) * weight_step_size_decay) + log( + INFO, f"Step size for round: {weight_step_size_for_round}, original was {self.adjustment_weight_step_size}" + ) # Omitting an additional scaler here that is present in the reference # implementation but not in the paper: diff --git a/fl4health/strategies/fedpca.py b/fl4health/strategies/fedpca.py index a0e4c5979..97ba6f8b5 100644 --- a/fl4health/strategies/fedpca.py +++ b/fl4health/strategies/fedpca.py @@ -2,19 +2,13 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np -from flwr.common import ( - MetricsAggregationFn, - NDArray, - NDArrays, - Parameters, - ndarrays_to_parameters, - parameters_to_ndarrays, -) +from flwr.common import MetricsAggregationFn, NDArray, NDArrays, Parameters, ndarrays_to_parameters from flwr.common.logger import log from flwr.common.typing import FitRes, Scalar from flwr.server.client_proxy import ClientProxy from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results class FedPCA(BasicFedAvg): @@ -122,10 +116,14 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. 
This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = [weights for _, weights, _ in decode_and_pseudo_sort_results(results)] + client_singular_values = [] client_singular_vectors = [] - for _, fit_res in results: - A = parameters_to_ndarrays(fit_res.parameters) + for A in decoded_and_sorted_results: singular_vectors, singular_values = A[0], A[1] client_singular_vectors.append(singular_vectors) client_singular_values.append(singular_values) diff --git a/fl4health/strategies/flash.py b/fl4health/strategies/flash.py index 789226c47..97fbe8dae 100644 --- a/fl4health/strategies/flash.py +++ b/fl4health/strategies/flash.py @@ -149,9 +149,11 @@ def aggregate_fit( failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: """Aggregate fit results using the Flash method.""" + fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit( server_round=server_round, results=results, failures=failures ) + if fedavg_parameters_aggregated is None: return None, {} diff --git a/fl4health/strategies/model_merge_strategy.py b/fl4health/strategies/model_merge_strategy.py index c30e723e9..cd5e06cf3 100644 --- a/fl4health/strategies/model_merge_strategy.py +++ b/fl4health/strategies/model_merge_strategy.py @@ -20,6 +20,7 @@ from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager from fl4health.strategies.aggregate_utils import aggregate_results +from fl4health.utils.functions import decode_and_pseudo_sort_results class ModelMergeStrategy(Strategy): @@ -188,12 +189,15 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results + # Sorting the results by elements and sample counts. 
This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) ] + # Aggregate them in an weighted or unweighted fashion based on self.weighted_aggregation. - aggregated_arrays = aggregate_results(weights_results, self.weighted_aggregation) + aggregated_arrays = aggregate_results(decoded_and_sorted_results, self.weighted_aggregation) # Convert back to parameters parameters_aggregated = ndarrays_to_parameters(aggregated_arrays) @@ -246,7 +250,7 @@ def aggregate_evaluate( def evaluate(self, server_round: int, parameters: Parameters) -> Optional[Tuple[float, Dict[str, Scalar]]]: """ - Evaluate the model parameters after the merging has occured. This function can be used to perform centralized + Evaluate the model parameters after the merging has occurred. This function can be used to perform centralized (i.e., server-side) evaluation of model parameters. 
Args: diff --git a/fl4health/strategies/scaffold.py b/fl4health/strategies/scaffold.py index 842f56630..df2f8edf2 100644 --- a/fl4health/strategies/scaffold.py +++ b/fl4health/strategies/scaffold.py @@ -21,6 +21,7 @@ from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -179,12 +180,14 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results with packed params of model weights and client control variate updates - updated_params = [parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results] + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. 
+ decoded_and_sorted_results = [weights for _, weights, _ in decode_and_pseudo_sort_results(results)] # x = 1 / |S| * sum(x_i) and c = 1 / |S| * sum(delta_c_i) # Aggregation operation over packed params (includes both weights and control variate updates) - aggregated_params = self.aggregate(updated_params) + aggregated_params = self.aggregate(decoded_and_sorted_results) weights, control_variates_update = self.parameter_packer.unpack_parameters(aggregated_params) diff --git a/fl4health/utils/functions.py b/fl4health/utils/functions.py index a546b42f9..460aa949e 100644 --- a/fl4health/utils/functions.py +++ b/fl4health/utils/functions.py @@ -1,6 +1,10 @@ -from typing import Any, Tuple +from typing import Any, List, Tuple +import numpy as np import torch +from flwr.common import parameters_to_ndarrays +from flwr.common.typing import FitRes, NDArrays +from flwr.server.client_proxy import ClientProxy class BernoulliSample(torch.autograd.Function): @@ -41,3 +45,65 @@ def backward(ctx: torch.Any, grad_output: torch.Tensor) -> torch.Tensor: # type def sigmoid_inverse(x: torch.Tensor) -> torch.Tensor: return -torch.log(1 / x - 1) + + +def select_zeroeth_element(array: np.ndarray) -> float: + """ + Helper function that simply selects the first element of an array (index 0 across all dimensions). + + Args: + array (np.ndarray): Array from which the very first element is selected + + Returns: + float: zeroeth element value. + """ + indices = tuple(0 for _ in array.shape) + return array[indices] + + +def pseudo_sort_scoring_function(client_result: Tuple[ClientProxy, NDArrays, int]) -> float: + """ + This function provides the "score" that is used to sort a list of Tuple[ClientProxy, NDArrays, int]. We select + the zeroeth (index 0 across all dimensions) element from each of the arrays in the NDArrays list, sum them, and + add the integer (client sample counts) to the sum to come up with a score for sorting. 
Note that + the underlying numpy arrays in NDArrays may not all be of numerical type. So we limit to selecting elements from + arrays of floats. + + Args: + client_result (Tuple[ClientProxy, NDArrays, int]): Elements to use to determine the score. + + Returns: + float: Sum of the zeroeth elements of each array in the NDArrays and the int of the tuple + """ + _, client_arrays, sample_count = client_result + zeroeth_params = [ + select_zeroeth_element(array) for array in client_arrays if np.issubdtype(array.dtype, np.floating) + ] + return np.sum(zeroeth_params) + sample_count + + +def decode_and_pseudo_sort_results( + results: List[Tuple[ClientProxy, FitRes]] +) -> List[Tuple[ClientProxy, NDArrays, int]]: + """ + This function is used to convert the results of client training into NDArrays and to apply a pseudo sort + based on the zeroeth elements in the weights and the sample counts. As long as the numpy seed has been set on the + server this process should be deterministic when repeatedly running the same server code leading to deterministic + sorting (assuming the clients are deterministically training their weights as well). This allows, for example, + for weights from the clients to be summed in a deterministic order during aggregation. + + NOTE: Client proxies would be nice to use for this task, but the CIDs are set by uuid deep in the flower library + and are, therefore, not pinnable without a ton of work. + + Args: + results (List[Tuple[ClientProxy, FitRes]]): Results from a federated training round. 
+ + Returns: + List[Tuple[ClientProxy, NDArrays, int]]: The ordered set of weights as NDarrays and the corresponding + number of examples + """ + ndarrays_results = [ + (client_proxy, parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for client_proxy, fit_res in results + ] + return sorted(ndarrays_results, key=lambda x: pseudo_sort_scoring_function(x)) diff --git a/research/ag_news/dynamic_layer_exchange/client.py b/research/ag_news/dynamic_layer_exchange/client.py index 130c941b6..d2b04f0b6 100644 --- a/research/ag_news/dynamic_layer_exchange/client.py +++ b/research/ag_news/dynamic_layer_exchange/client.py @@ -168,11 +168,11 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_dir) - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Exchange Percentage: {args.exchange_percentage}") @@ -185,7 +185,7 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ client = BertDynamicLayerExchangeClient( data_path, [Accuracy("accuracy")], - DEVICE, + device, learning_rate=args.learning_rate, exchange_percentage=args.exchange_percentage, norm_threshold=args.norm_threshold, diff --git a/research/ag_news/sparse_tensor_exchange/client.py b/research/ag_news/sparse_tensor_exchange/client.py index 003ca420e..97010afd5 100644 --- a/research/ag_news/sparse_tensor_exchange/client.py +++ b/research/ag_news/sparse_tensor_exchange/client.py @@ -137,11 +137,11 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_dir) - log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Sparsity Level: {args.sparsity_level}") @@ -154,7 +154,7 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ client = BertSparseTensorExchangeClient( data_path, [Accuracy("accuracy")], - DEVICE, + device, learning_rate=args.learning_rate, sparsity_level=args.sparsity_level, checkpointer=checkpointer, diff --git a/research/cifar10/adaptive_pfl/ditto/client.py b/research/cifar10/adaptive_pfl/ditto/client.py index a2e510a7c..7759d93ca 100644 --- a/research/cifar10/adaptive_pfl/ditto/client.py +++ b/research/cifar10/adaptive_pfl/ditto/client.py @@ -123,8 +123,8 @@ def get_model(self, config: Config) -> nn.Module: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Beta: {args.beta}") @@ -157,7 +157,7 @@ def get_model(self, config: Config) -> nn.Module: F1("f1_score_macro", average="macro"), F1("f1_score_weight", average="weighted"), ], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/adaptive_pfl/fedprox/client.py b/research/cifar10/adaptive_pfl/fedprox/client.py index 624e7e639..720cc0011 100644 --- a/research/cifar10/adaptive_pfl/fedprox/client.py +++ b/research/cifar10/adaptive_pfl/fedprox/client.py @@ -127,8 +127,8 @@ def get_model(self, config: Config) -> nn.Module: ) args = parser.parse_args() - DEVICE = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Beta: {args.beta}") @@ -161,7 +161,7 @@ def get_model(self, config: Config) -> nn.Module: F1("f1_score_macro", average="macro"), F1("f1_score_weight", average="weighted"), ], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/client.py b/research/cifar10/adaptive_pfl/fenda_ditto/client.py index 421fba250..947fad771 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/client.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/client.py @@ -135,8 +135,8 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Beta: {args.beta}") @@ -171,7 +171,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: F1("f1_score_macro", average="macro"), F1("f1_score_weight", average="weighted"), ], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/adaptive_pfl/mrmtl/client.py b/research/cifar10/adaptive_pfl/mrmtl/client.py index 0cc5e1939..f0b38fe0e 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/client.py +++ b/research/cifar10/adaptive_pfl/mrmtl/client.py @@ -121,8 +121,8 @@ def get_model(self, config: 
Config) -> nn.Module: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Beta: {args.beta}") @@ -155,7 +155,7 @@ def get_model(self, config: Config) -> nn.Module: F1("f1_score_macro", average="macro"), F1("f1_score_weight", average="weighted"), ], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/ditto/client.py b/research/cifar10/ditto/client.py index 8a276b444..f4b44d12c 100644 --- a/research/cifar10/ditto/client.py +++ b/research/cifar10/ditto/client.py @@ -184,8 +184,8 @@ def get_model(self, config: Config) -> nn.Module: if args.use_partitioned_data: log(INFO, "Using preprocessed partitioned data for training, validation and testing") - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Beta: {args.beta}") @@ -214,7 +214,7 @@ def get_model(self, config: Config) -> nn.Module: client = CifarDittoClient( data_path=data_path, metrics=[Accuracy("accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/ditto_deep_mmd/client.py b/research/cifar10/ditto_deep_mmd/client.py index 2bd8bb512..dca688a68 100644 --- a/research/cifar10/ditto_deep_mmd/client.py +++ b/research/cifar10/ditto_deep_mmd/client.py @@ -210,8 +210,8 @@ def 
get_model(self, config: Config) -> nn.Module: if args.use_partitioned_data: log(INFO, "Using preprocessed partitioned data for training, validation and testing") - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Mu: {args.mu}") @@ -241,7 +241,7 @@ def get_model(self, config: Config) -> nn.Module: client = CifarDittoClient( data_path=data_path, metrics=[Accuracy("accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/ditto_mkmmd/client.py b/research/cifar10/ditto_mkmmd/client.py index 1255fda6a..7f5d5f034 100644 --- a/research/cifar10/ditto_mkmmd/client.py +++ b/research/cifar10/ditto_mkmmd/client.py @@ -226,8 +226,8 @@ def get_model(self, config: Config) -> nn.Module: if args.use_partitioned_data: log(INFO, "Using preprocessed partitioned data for training, validation and testing") - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Mu: {args.mu}") @@ -259,7 +259,7 @@ def get_model(self, config: Config) -> nn.Module: client = CifarDittoClient( data_path=data_path, metrics=[Accuracy("accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/fed_dgga_pfl/ditto/client.py b/research/cifar10/fed_dgga_pfl/ditto/client.py index a2e510a7c..7759d93ca 100644 --- 
a/research/cifar10/fed_dgga_pfl/ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/ditto/client.py @@ -123,8 +123,8 @@ def get_model(self, config: Config) -> nn.Module: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Beta: {args.beta}") @@ -157,7 +157,7 @@ def get_model(self, config: Config) -> nn.Module: F1("f1_score_macro", average="macro"), F1("f1_score_weight", average="weighted"), ], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/fed_dgga_pfl/fenda/client.py b/research/cifar10/fed_dgga_pfl/fenda/client.py index 7c8828465..4e5549b69 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda/client.py @@ -121,8 +121,8 @@ def get_model(self, config: Config) -> FendaModel: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Beta: {args.beta}") @@ -155,7 +155,7 @@ def get_model(self, config: Config) -> FendaModel: F1("f1_score_macro", average="macro"), F1("f1_score_weight", average="weighted"), ], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py index 792178240..19f891ba3 
100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py @@ -135,8 +135,8 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Beta: {args.beta}") @@ -171,7 +171,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: F1("f1_score_macro", average="macro"), F1("f1_score_weight", average="weighted"), ], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/cifar10/fedavg/client.py b/research/cifar10/fedavg/client.py index 224b7f02e..c206f9c40 100644 --- a/research/cifar10/fedavg/client.py +++ b/research/cifar10/fedavg/client.py @@ -184,8 +184,8 @@ def get_model(self, config: Config) -> nn.Module: if args.use_partitioned_data: log(INFO, "Using preprocessed partitioned data for training, validation and testing") - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -213,7 +213,7 @@ def get_model(self, config: Config) -> nn.Module: client = CifarFedAvgClient( data_path=data_path, metrics=[Accuracy("accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, diff --git a/research/flamby/README.md b/research/flamby/README.md index 
391aa71a7..85068a10d 100644 --- a/research/flamby/README.md +++ b/research/flamby/README.md @@ -1,20 +1,4 @@ -### Installing the Flamby dependencies from Fixed Requirements File - -__NOTE__: The standard workflow discussed by FLamby is in the next section, but is currently broken due to dependency changes. - -Create a python environment with your preferred env manager. We'll use conda below -``` bash -conda create -n flamby_fl4health python=3.10 -conda activate flamby_fl4health -``` -Install the dependencies of both FLamby and FL4Health using the fixed requirements file at `research/flamby/flamby_requirements.txt`. -```bash -cd -pip install --src -r research/flamby/flamby_requirements.txt -``` -Note that this installation will clone the FLamby repo to the path provided to --src and then install FLamby as a package. - -### Installing the Flamby dependencies (Old Workflow) +### Installing the Flamby dependencies __NOTE__: The workflow below is normally the smoothest way to construct the FLamby + FL4Health environment required to run the FLamby experiments. However, with a recent upgrade to MonAI, some of the functionality that FLamby depends on are broken. Until this is fixed, the workflow below will not work. @@ -33,8 +17,56 @@ cd pip install --upgrade pip poetry poetry install --with "dev, dev-local, test, codestyle" cd -pip install -e ".[all_extra]" +pip install -e ".[cam16, heart, isic2019, ixi, lidc, tcga]" +``` +__NOTE__: We avoid installing Fed-KITS2019, as it requires a fairly old version of nnUnet, which we no longer support in our library. + +In addition, you'll have to edit the code for `FedIXITiny` in `flamby/datasets/fed_ixi/datasets.py` replacing the following + + + + + + + + + + + + + + + +
OldNew
+ +``` python +from monai.transforms import ( + AddChannel, + ... +``` + + +``` python +from monai.transforms import ( + EnsureChannelFirst, + ... ``` +
+ +```python +default_transform = Compose( + [ToTensor(), AddChannel(), Resize(self.common_shape)] +``` + + +```python +default_transform = Compose( + [ToTensor(), EnsureChannelFirst(channel_dim="no_channel"), Resize(self.common_shape)] +) +``` +
+ +This is because AddChannel was removed in Version 1.3 of MonAI ### Downloading the Fed ISIC 2019 Dataset diff --git a/research/flamby/fed_heart_disease/apfl/client.py b/research/flamby/fed_heart_disease/apfl/client.py index 4f038849a..6862ef25b 100644 --- a/research/flamby/fed_heart_disease/apfl/client.py +++ b/research/flamby/fed_heart_disease/apfl/client.py @@ -117,8 +117,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}") @@ -137,7 +137,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedHeartDiseaseApflClient( data_path=args.dataset_dir, metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, alpha_learning_rate=args.alpha_learning_rate, diff --git a/research/flamby/fed_heart_disease/central/train.py b/research/flamby/fed_heart_disease/central/train.py index f3a3c70d4..77a405632 100644 --- a/research/flamby/fed_heart_disease/central/train.py +++ b/research/flamby/fed_heart_disease/central/train.py @@ -61,11 +61,11 @@ def __init__( ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") trainer = FedHeartDiseaseCentralizedTrainer( - DEVICE, + device, args.artifact_dir, args.dataset_dir, args.run_name, diff --git a/research/flamby/fed_heart_disease/ditto/client.py b/research/flamby/fed_heart_disease/ditto/client.py index 
24feb844d..292cfdfdf 100644 --- a/research/flamby/fed_heart_disease/ditto/client.py +++ b/research/flamby/fed_heart_disease/ditto/client.py @@ -123,8 +123,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}") @@ -146,7 +146,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedHeartDiseaseDittoClient( data_path=Path(args.dataset_dir), metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_heart_disease/fedadam/client.py b/research/flamby/fed_heart_disease/fedadam/client.py index 54b8c5756..825049603 100644 --- a/research/flamby/fed_heart_disease/fedadam/client.py +++ b/research/flamby/fed_heart_disease/fedadam/client.py @@ -106,8 +106,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -118,7 +118,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedHeartDiseaseFedAdamClient( data_path=Path(args.dataset_dir), metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, 
checkpointer=checkpointer, diff --git a/research/flamby/fed_heart_disease/fedavg/client.py b/research/flamby/fed_heart_disease/fedavg/client.py index 52c24ad45..894157270 100644 --- a/research/flamby/fed_heart_disease/fedavg/client.py +++ b/research/flamby/fed_heart_disease/fedavg/client.py @@ -106,8 +106,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -118,7 +118,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedHeartDiseaseFedAvgClient( data_path=Path(args.dataset_dir), metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_heart_disease/fedper/client.py b/research/flamby/fed_heart_disease/fedper/client.py index 13d8da309..60b9b6897 100644 --- a/research/flamby/fed_heart_disease/fedper/client.py +++ b/research/flamby/fed_heart_disease/fedper/client.py @@ -127,8 +127,8 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}") @@ -149,7 +149,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: client = FedHeartDiseaseFedPerClient( 
data_path=Path(args.dataset_dir), metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_heart_disease/fedprox/client.py b/research/flamby/fed_heart_disease/fedprox/client.py index 4a3316445..ce31f8ab1 100644 --- a/research/flamby/fed_heart_disease/fedprox/client.py +++ b/research/flamby/fed_heart_disease/fedprox/client.py @@ -106,8 +106,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -118,7 +118,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedHeartDiseaseFedProxClient( data_path=Path(args.dataset_dir), metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_heart_disease/fenda/client.py b/research/flamby/fed_heart_disease/fenda/client.py index e217fd2dd..63787db6a 100644 --- a/research/flamby/fed_heart_disease/fenda/client.py +++ b/research/flamby/fed_heart_disease/fenda/client.py @@ -120,8 +120,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not 
args.no_federated_checkpointing}") @@ -142,7 +142,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedHeartDiseaseFendaClient( data_path=Path(args.dataset_dir), metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_heart_disease/local/train.py b/research/flamby/fed_heart_disease/local/train.py index 6120a663d..ebda25650 100644 --- a/research/flamby/fed_heart_disease/local/train.py +++ b/research/flamby/fed_heart_disease/local/train.py @@ -67,11 +67,11 @@ def __init__( ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") trainer = FedHeartDiseaseLocalTrainer( - DEVICE, + device, args.client_number, args.artifact_dir, args.dataset_dir, diff --git a/research/flamby/fed_heart_disease/moon/client.py b/research/flamby/fed_heart_disease/moon/client.py index 5563fb8ee..5da0aca18 100644 --- a/research/flamby/fed_heart_disease/moon/client.py +++ b/research/flamby/fed_heart_disease/moon/client.py @@ -124,8 +124,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -139,7 +139,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedHeartDiseaseMoonClient( data_path=Path(args.dataset_dir), metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, 
learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_heart_disease/perfcl/client.py b/research/flamby/fed_heart_disease/perfcl/client.py index 7c297e3f7..6862d142a 100644 --- a/research/flamby/fed_heart_disease/perfcl/client.py +++ b/research/flamby/fed_heart_disease/perfcl/client.py @@ -138,8 +138,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}") @@ -160,7 +160,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedHeartDiseasePerFclClient( data_path=Path(args.dataset_dir), metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_heart_disease/scaffold/client.py b/research/flamby/fed_heart_disease/scaffold/client.py index e375b133b..a5d7aff20 100644 --- a/research/flamby/fed_heart_disease/scaffold/client.py +++ b/research/flamby/fed_heart_disease/scaffold/client.py @@ -106,8 +106,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -118,7 +118,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedHeartDiseaseScaffoldClient( 
data_path=Path(args.dataset_dir), metrics=[Accuracy("FedHeartDisease_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/apfl/client.py b/research/flamby/fed_isic2019/apfl/client.py index 3b4d5021c..6c622ab72 100644 --- a/research/flamby/fed_isic2019/apfl/client.py +++ b/research/flamby/fed_isic2019/apfl/client.py @@ -118,8 +118,8 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}") @@ -131,7 +131,7 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: client = FedIsic2019ApflClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, alpha_learning_rate=args.alpha_learning_rate, diff --git a/research/flamby/fed_isic2019/central/train.py b/research/flamby/fed_isic2019/central/train.py index 2cdf6d3e2..79f86302d 100644 --- a/research/flamby/fed_isic2019/central/train.py +++ b/research/flamby/fed_isic2019/central/train.py @@ -60,11 +60,11 @@ def __init__( ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") trainer = FedIsic2019CentralizedTrainer( - DEVICE, + device, args.artifact_dir, args.dataset_dir, args.run_name, diff --git 
a/research/flamby/fed_isic2019/ditto/client.py b/research/flamby/fed_isic2019/ditto/client.py index 4e1f0dd63..7b1249eae 100644 --- a/research/flamby/fed_isic2019/ditto/client.py +++ b/research/flamby/fed_isic2019/ditto/client.py @@ -118,8 +118,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -133,7 +133,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019DittoClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/ditto_deep_mmd/client.py b/research/flamby/fed_isic2019/ditto_deep_mmd/client.py index 84ba06bca..de413e0f0 100644 --- a/research/flamby/fed_isic2019/ditto_deep_mmd/client.py +++ b/research/flamby/fed_isic2019/ditto_deep_mmd/client.py @@ -146,8 +146,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Mu: {args.mu}") @@ -163,7 +163,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019DittoClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, 
learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/ditto_mkmmd/client.py b/research/flamby/fed_isic2019/ditto_mkmmd/client.py index ceac45890..401afec2c 100644 --- a/research/flamby/fed_isic2019/ditto_mkmmd/client.py +++ b/research/flamby/fed_isic2019/ditto_mkmmd/client.py @@ -161,8 +161,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Mu: {args.mu}") @@ -180,7 +180,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019DittoClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/fedadam/client.py b/research/flamby/fed_isic2019/fedadam/client.py index 06d879516..51a00470c 100644 --- a/research/flamby/fed_isic2019/fedadam/client.py +++ b/research/flamby/fed_isic2019/fedadam/client.py @@ -107,8 +107,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -119,7 +119,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019FedAdamClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - 
device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/fedavg/client.py b/research/flamby/fed_isic2019/fedavg/client.py index afe4e48f6..a2ff5af6f 100644 --- a/research/flamby/fed_isic2019/fedavg/client.py +++ b/research/flamby/fed_isic2019/fedavg/client.py @@ -106,8 +106,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -118,7 +118,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019FedAvgClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/fedper/client.py b/research/flamby/fed_isic2019/fedper/client.py index fe312d8a3..1a8f3564d 100644 --- a/research/flamby/fed_isic2019/fedper/client.py +++ b/research/flamby/fed_isic2019/fedper/client.py @@ -122,8 +122,8 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -137,7 +137,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: client = FedIsic2019FedPerClient( data_path=Path(args.dataset_dir), 
metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/fedprox/client.py b/research/flamby/fed_isic2019/fedprox/client.py index 92bbe046c..de53f6a67 100644 --- a/research/flamby/fed_isic2019/fedprox/client.py +++ b/research/flamby/fed_isic2019/fedprox/client.py @@ -106,8 +106,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -118,7 +118,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019FedProxClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/fenda/client.py b/research/flamby/fed_isic2019/fenda/client.py index 45bff2fc7..c3c1c2656 100644 --- a/research/flamby/fed_isic2019/fenda/client.py +++ b/research/flamby/fed_isic2019/fenda/client.py @@ -120,8 +120,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}") @@ -142,7 +142,7 @@ def 
get_criterion(self, config: Config) -> _Loss: client = FedIsic2019FendaClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/local/train.py b/research/flamby/fed_isic2019/local/train.py index 8c1586a07..ffcee0737 100644 --- a/research/flamby/fed_isic2019/local/train.py +++ b/research/flamby/fed_isic2019/local/train.py @@ -67,11 +67,11 @@ def __init__( ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") trainer = FedIsic2019LocalTrainer( - DEVICE, + device, args.client_number, args.artifact_dir, args.dataset_dir, diff --git a/research/flamby/fed_isic2019/moon/client.py b/research/flamby/fed_isic2019/moon/client.py index 69aae3383..6a3e8dfdf 100644 --- a/research/flamby/fed_isic2019/moon/client.py +++ b/research/flamby/fed_isic2019/moon/client.py @@ -124,8 +124,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -139,7 +139,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019MoonClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git 
a/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py b/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py index f84d25069..01368c84f 100644 --- a/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py +++ b/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py @@ -157,8 +157,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Mu: {args.mu}") @@ -176,7 +176,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019MrMtlClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/perfcl/client.py b/research/flamby/fed_isic2019/perfcl/client.py index 570cd8068..a6f7d0e80 100644 --- a/research/flamby/fed_isic2019/perfcl/client.py +++ b/research/flamby/fed_isic2019/perfcl/client.py @@ -138,8 +138,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}") @@ -160,7 +160,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019PerFclClient( data_path=Path(args.dataset_dir), 
metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_isic2019/scaffold/client.py b/research/flamby/fed_isic2019/scaffold/client.py index 7730815f0..94a8ae8f7 100644 --- a/research/flamby/fed_isic2019/scaffold/client.py +++ b/research/flamby/fed_isic2019/scaffold/client.py @@ -106,8 +106,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -118,7 +118,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIsic2019ScaffoldClient( data_path=Path(args.dataset_dir), metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_ixi/apfl/client.py b/research/flamby/fed_ixi/apfl/client.py index b1deb8dd1..4f1b49357 100644 --- a/research/flamby/fed_ixi/apfl/client.py +++ b/research/flamby/fed_ixi/apfl/client.py @@ -119,8 +119,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}") @@ -139,7 +139,7 @@ def get_criterion(self, config: Config) -> _Loss: client = 
FedIxiApflClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, alpha_learning_rate=args.alpha_learning_rate, diff --git a/research/flamby/fed_ixi/central/train.py b/research/flamby/fed_ixi/central/train.py index 6e362d273..53e227067 100644 --- a/research/flamby/fed_ixi/central/train.py +++ b/research/flamby/fed_ixi/central/train.py @@ -64,11 +64,11 @@ def __init__( ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") trainer = FedIxiCentralizedTrainer( - DEVICE, + device, args.artifact_dir, args.dataset_dir, args.run_name, diff --git a/research/flamby/fed_ixi/ditto/client.py b/research/flamby/fed_ixi/ditto/client.py index efadaf46c..9656d7ef4 100644 --- a/research/flamby/fed_ixi/ditto/client.py +++ b/research/flamby/fed_ixi/ditto/client.py @@ -126,8 +126,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}") @@ -148,7 +148,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIxiDittoClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_ixi/fedadam/client.py 
b/research/flamby/fed_ixi/fedadam/client.py index 2e662c60f..e6c6f785e 100644 --- a/research/flamby/fed_ixi/fedadam/client.py +++ b/research/flamby/fed_ixi/fedadam/client.py @@ -107,8 +107,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -119,7 +119,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIxiFedAdamClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_ixi/fedavg/client.py b/research/flamby/fed_ixi/fedavg/client.py index 730937a6e..4fbd24695 100644 --- a/research/flamby/fed_ixi/fedavg/client.py +++ b/research/flamby/fed_ixi/fedavg/client.py @@ -109,8 +109,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -121,7 +121,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIxiFedAvgClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_ixi/fedper/client.py b/research/flamby/fed_ixi/fedper/client.py index 
bb831f2f1..fb2c3aa78 100644 --- a/research/flamby/fed_ixi/fedper/client.py +++ b/research/flamby/fed_ixi/fedper/client.py @@ -127,8 +127,8 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}") @@ -149,7 +149,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: client = FedIxiFedPerClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_ixi/fedprox/client.py b/research/flamby/fed_ixi/fedprox/client.py index 239caf7e4..26921012e 100644 --- a/research/flamby/fed_ixi/fedprox/client.py +++ b/research/flamby/fed_ixi/fedprox/client.py @@ -109,8 +109,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -121,7 +121,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIxiFedProxClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git 
a/research/flamby/fed_ixi/fenda/client.py b/research/flamby/fed_ixi/fenda/client.py index 1cee48c93..876fc600a 100644 --- a/research/flamby/fed_ixi/fenda/client.py +++ b/research/flamby/fed_ixi/fenda/client.py @@ -120,8 +120,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}") @@ -142,7 +142,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIxiFendaClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_ixi/local/train.py b/research/flamby/fed_ixi/local/train.py index c3ea74017..7716bc0b1 100644 --- a/research/flamby/fed_ixi/local/train.py +++ b/research/flamby/fed_ixi/local/train.py @@ -71,11 +71,11 @@ def __init__( ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") trainer = FedIxiLocalTrainer( - DEVICE, + device, args.client_number, args.artifact_dir, args.dataset_dir, diff --git a/research/flamby/fed_ixi/moon/client.py b/research/flamby/fed_ixi/moon/client.py index 6bdb163c1..a3b5ac5e9 100644 --- a/research/flamby/fed_ixi/moon/client.py +++ b/research/flamby/fed_ixi/moon/client.py @@ -124,8 +124,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -139,7 +139,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIxiMoonClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_ixi/perfcl/client.py b/research/flamby/fed_ixi/perfcl/client.py index 3779de776..e93098aa7 100644 --- a/research/flamby/fed_ixi/perfcl/client.py +++ b/research/flamby/fed_ixi/perfcl/client.py @@ -138,8 +138,8 @@ def get_criterion(self, config: Config) -> _Loss: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}") @@ -160,7 +160,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIxiPerFclClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/fed_ixi/scaffold/client.py b/research/flamby/fed_ixi/scaffold/client.py index 8c52e08b9..4cb8a6aa5 100644 --- a/research/flamby/fed_ixi/scaffold/client.py +++ b/research/flamby/fed_ixi/scaffold/client.py @@ -109,8 +109,8 @@ def get_criterion(self, config: Config) -> _Loss: ) 
args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") log(INFO, f"Learning Rate: {args.learning_rate}") @@ -121,7 +121,7 @@ def get_criterion(self, config: Config) -> _Loss: client = FedIxiScaffoldClient( data_path=Path(args.dataset_dir), metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], - device=DEVICE, + device=device, client_number=args.client_number, learning_rate=args.learning_rate, checkpointer=checkpointer, diff --git a/research/flamby/flamby_requirements.txt b/research/flamby/flamby_requirements.txt deleted file mode 100644 index cf6cb2bba..000000000 --- a/research/flamby/flamby_requirements.txt +++ /dev/null @@ -1,202 +0,0 @@ -absl-py==1.0.0 -alabaster==0.7.13 -albumentations==1.3.0 -appdirs==1.4.4 -array-record==0.2.0 -astor==0.8.1 -astunparse==1.6.3 -attrs==21.4.0 -autograd==1.5 -autograd-gamma==0.5.0 -Babel==2.12.1 -batchgenerators==0.25 -black==23.3.0 -cachetools==5.3.0 -certifi==2023.5.7 -cfgv==3.3.1 -charset-normalizer==3.1.0 -click==8.1.3 -cloudpickle==2.2.1 -contourpy==1.0.7 -coverage==7.2.5 -cycler==0.11.0 -dask==2023.5.0 -decorator==5.1.1 -dicom-numpy==0.6.5 -dicom2nifti==2.4.8 -distlib==0.3.6 -dm-tree==0.1.8 -docker-pycreds==0.4.0 -docutils==0.17.1 -dp-accounting==0.4.1 -efficientnet-pytorch==0.7.1 -etils==1.3.0 -exceptiongroup==1.1.1 -filelock==3.12.0 -flake8==5.0.4 --e git+https://github.com/owkin/FLamby.git@4dfc53479ec4141849d67a6adace1137819317a2#egg=flamby -flatbuffers==23.5.9 -flwr==1.4.0 -fonttools==4.39.4 -formulaic==0.6.1 -fsspec==2023.5.0 -future==0.18.3 -gast==0.4.0 -gitdb==4.0.10 -GitPython==3.1.31 -google-api-core==2.11.0 -google-api-python-client==2.87.0 -google-auth==2.18.1 -google-auth-httplib2==0.1.0 -google-auth-oauthlib==1.0.0 -google-pasta==0.2.0 
-googleapis-common-protos==1.59.0 -grpcio==1.54.2 -h5py==3.8.0 -histolab==0.6.0 -httplib2==0.22.0 -identify==2.5.24 -idna==3.4 -imageio==2.29.0 -imagesize==1.4.1 -immutabledict==2.2.4 -importlib-metadata==6.6.0 -importlib-resources==5.12.0 -iniconfig==2.0.0 -interface-meta==1.3.0 -isort==5.12.0 -iterators==0.0.2 -jax==0.4.10 -Jinja2==3.1.2 -joblib==1.2.0 -keras==2.12.0 -kiwisolver==1.4.4 -large-image==1.20.6 -large-image-source-openslide==1.20.6 -libclang==16.0.0 -lifelines==0.27.7 -linecache2==1.0.0 -llvmlite==0.40.0 -locket==1.0.0 -Markdown==3.4.3 -MarkupSafe==2.1.2 -matplotlib==3.7.1 -mccabe==0.7.0 -MedPy==0.4.0 -ml-dtypes==0.1.0 -monai==1.1.0 -mpmath==1.2.1 -mypy==1.3.0 -mypy-extensions==1.0.0 -networkx==3.1 -nibabel==3.2.2 -nltk==3.8.1 -nnunet==1.7.0 -nodeenv==1.8.0 -numba==0.57.0 -numpy==1.22.4 -oauthlib==3.2.2 -opacus==1.4.0 -opencv-python-headless==4.7.0.72 -openslide-python==1.2.0 -opt-einsum==3.3.0 -packaging==22.0 -palettable==3.3.3 -pandas==1.5.3 -parameterized==0.9.0 -partd==1.4.0 -pathspec==0.11.1 -pathtools==0.1.2 -patsy==0.5.3 -Pillow==9.5.0 -platformdirs==3.5.1 -pluggy==1.0.0 -portalocker==2.8.2 -pre-commit==3.3.2 -promise==2.3 -protobuf==3.20.3 -psutil==5.9.5 -pyasn1==0.5.0 -pyasn1-modules==0.3.0 -pycodestyle==2.9.1 -pydicom==2.3.1 -pyflakes==2.5.0 -Pygments==2.15.1 -pynndescent==0.5.10 -pyparsing==3.0.9 -pyproject-flake8==5.0.4 -pytest==7.3.1 -pytest-cov==4.0.0 -python-dateutil==2.8.2 -python-gdcm==3.0.22 -pytz==2023.3 -PyWavelets==1.4.1 -PyYAML==6.0 -qudida==0.0.4 -regex==2023.5.5 -requests==2.30.0 -requests-oauthlib==1.3.1 -rsa==4.9 -scikit-image==0.19.3 -scikit-learn==1.2.2 -scipy==1.7.3 -seaborn==0.12.2 -sentry-sdk==1.23.1 -setproctitle==1.3.2 -SimpleITK==2.2.1 -six==1.16.0 -sklearn==0.0.post5 -smmap==5.0.0 -snowballstemmer==2.2.0 -Sphinx==4.5.0 -sphinx-rtd-theme==1.0.0 -sphinxcontrib-applehelp==1.0.4 -sphinxcontrib-devhelp==1.0.2 -sphinxcontrib-htmlhelp==2.0.1 -sphinxcontrib-jsmath==1.0.1 -sphinxcontrib-qthelp==1.0.3 
-sphinxcontrib-serializinghtml==1.1.5 -statsmodels==0.14.0 -sympy==1.12 -tensorboard==2.12.3 -tensorboard-data-server==0.7.0 -tensorflow==2.12.0 -tensorflow-datasets==4.9.2 -tensorflow-estimator==2.12.0 -tensorflow-io-gcs-filesystem==0.32.0 -tensorflow-metadata==1.13.1 -tensorflow-privacy==0.8.9 -tensorflow-probability==0.15.0 -termcolor==2.3.0 -threadpoolctl==3.1.0 -tifffile==2023.4.12 -tifftools==1.3.9 -toml==0.10.2 -tomli==2.0.1 -toolz==0.12.0 -torch==2.0.1 -torchdata==0.6.1 -torcheval==0.0.7 -torchinfo==1.8.0 -torchtext==0.15.2 -torchvision==0.15.2 -tqdm==4.65.0 -traceback2==1.4.0 -types-protobuf==4.23.0.1 -types-PyYAML==6.0.12.10 -types-requests==2.30.0.0 -types-setuptools==67.7.0.3 -types-six==1.16.21.8 -types-tabulate==0.9.0.2 -types-urllib3==1.26.25.13 -typing_extensions==4.5.0 -umap-learn==0.5.3 -unittest2==1.1.0 -uritemplate==4.1.1 -urllib3==1.26.15 -virtualenv==20.23.0 -wandb==0.15.3 -Werkzeug==2.3.4 -wget==3.2 -wrapt==1.14.1 -zipp==3.15.0 diff --git a/research/flamby/single_node_trainer.py b/research/flamby/single_node_trainer.py index ac9e352a1..0b5bee501 100644 --- a/research/flamby/single_node_trainer.py +++ b/research/flamby/single_node_trainer.py @@ -95,7 +95,7 @@ def validate(self, val_metric_mngr: MetricManager) -> None: for input, target in self.val_loader: input, target = input.to(self.device), target.to(self.device) - preds = self.model(input) + preds = {"predictions": self.model(input)} batch_loss = self.criterion(preds["predictions"], target) running_loss += batch_loss.item() val_metric_mngr.update(preds, target) diff --git a/research/gemini/apfl/client.py b/research/gemini/apfl/client.py index 7417a1e33..b46fe9fca 100644 --- a/research/gemini/apfl/client.py +++ b/research/gemini/apfl/client.py @@ -277,8 +277,8 @@ def validate( elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") @@ -286,7 +286,7 @@ def validate( data_path, [Binary_ROC_AUC(), Binary_F1(), Accuracy()], args.hospital_id, - DEVICE, + device, args.task, args.learning_rate, args.alpha_learning_rate, diff --git a/research/gemini/central/train.py b/research/gemini/central/train.py index 4f4383bd4..2cca3f1df 100644 --- a/research/gemini/central/train.py +++ b/research/gemini/central/train.py @@ -121,14 +121,14 @@ def main( elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") main( data_path, [Binary_ROC_AUC(), Binary_F1(), Accuracy()], - DEVICE, + device, args.task, args.batch_size, args.num_epochs, diff --git a/research/gemini/ditto/client.py b/research/gemini/ditto/client.py index e21c3dfc0..4f7cb305b 100644 --- a/research/gemini/ditto/client.py +++ b/research/gemini/ditto/client.py @@ -154,8 +154,8 @@ def get_criterion(self, config: Config) -> _Loss: elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") @@ -164,7 +164,7 @@ def get_criterion(self, config: Config) -> _Loss: client = GeminiDittoClient( data_path=data_path, metrics=[Binary_ROC_AUC(), Binary_F1(), Accuracy()], - device=DEVICE, + device=device, hospital_id=args.hospital_id, learning_rate=args.learning_rate, learning_task=args.task, diff --git 
a/research/gemini/fedavg/client.py b/research/gemini/fedavg/client.py index c44c9f88b..ab4ba1154 100644 --- a/research/gemini/fedavg/client.py +++ b/research/gemini/fedavg/client.py @@ -185,8 +185,8 @@ def validate(self, current_server_round: int, meter: Meter) -> Tuple[float, Dict elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") @@ -194,7 +194,7 @@ def validate(self, current_server_round: int, meter: Meter) -> Tuple[float, Dict data_path, [Binary_ROC_AUC(), Binary_F1(), Accuracy()], args.hospital_id, - DEVICE, + device, args.task, args.learning_rate, args.artifact_dir, diff --git a/research/gemini/fedopt/client.py b/research/gemini/fedopt/client.py index 1473b7601..8043fb457 100644 --- a/research/gemini/fedopt/client.py +++ b/research/gemini/fedopt/client.py @@ -191,8 +191,8 @@ def validate(self, current_server_round: int, meter: Meter) -> Tuple[float, Dict elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") @@ -200,7 +200,7 @@ def validate(self, current_server_round: int, meter: Meter) -> Tuple[float, Dict data_path, [Binary_ROC_AUC(), Binary_F1(), Accuracy()], args.hospital_id, - DEVICE, + device, args.task, args.learning_rate, args.artifact_dir, diff --git a/research/gemini/fedper/client.py b/research/gemini/fedper/client.py index 996a94304..3d1b82117 100644 --- a/research/gemini/fedper/client.py +++ 
b/research/gemini/fedper/client.py @@ -151,8 +151,8 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") @@ -161,7 +161,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: client = GeminiFedPerClient( data_path=data_path, metrics=[Binary_ROC_AUC(), Binary_F1(), Accuracy()], - device=DEVICE, + device=device, hospital_id=args.hospital_id, learning_rate=args.learning_rate, learning_task=args.task, diff --git a/research/gemini/fedprox/client.py b/research/gemini/fedprox/client.py index 94df15bb7..98decf7eb 100644 --- a/research/gemini/fedprox/client.py +++ b/research/gemini/fedprox/client.py @@ -144,15 +144,15 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") client = GeminiFedProxClient( data_path, [Binary_ROC_AUC(), Binary_F1(), Accuracy()], args.hospital_id, - DEVICE, + device, args.task, args.learning_rate, args.mu, diff --git a/research/gemini/fenda/client.py b/research/gemini/fenda/client.py index 8670fe580..a5911bfe5 100644 --- a/research/gemini/fenda/client.py +++ b/research/gemini/fenda/client.py @@ -199,8 +199,8 @@ def validate(self, current_server_round: int, meter: Meter) -> Tuple[float, Dict elif args.task == "delirium": 
data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") @@ -208,7 +208,7 @@ def validate(self, current_server_round: int, meter: Meter) -> Tuple[float, Dict data_path, [Binary_ROC_AUC(), Binary_F1(), Accuracy()], args.hospital_id, - DEVICE, + device, args.task, args.learning_rate, args.artifact_dir, diff --git a/research/gemini/local/train.py b/research/gemini/local/train.py index 4abc12f3b..1b0ff09f0 100644 --- a/research/gemini/local/train.py +++ b/research/gemini/local/train.py @@ -125,14 +125,14 @@ def main( elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") main( data_path, [Binary_ROC_AUC(), Binary_F1(), Accuracy()], - DEVICE, + device, args.hospital_id, args.task, args.batch_size, diff --git a/research/gemini/moon/client.py b/research/gemini/moon/client.py index bc31c3d03..f6cb82d92 100644 --- a/research/gemini/moon/client.py +++ b/research/gemini/moon/client.py @@ -155,8 +155,8 @@ def get_criterion(self, config: Config) -> _Loss: elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") @@ -165,7 +165,7 @@ def get_criterion(self, config: Config) -> _Loss: client = 
GeminiMoonClient( data_path=data_path, metrics=[Binary_ROC_AUC(), Binary_F1(), Accuracy()], - device=DEVICE, + device=device, hospital_id=args.hospital_id, learning_rate=args.learning_rate, learning_task=args.task, diff --git a/research/gemini/perfcl/client.py b/research/gemini/perfcl/client.py index 5178326ad..b69e0bd9b 100644 --- a/research/gemini/perfcl/client.py +++ b/research/gemini/perfcl/client.py @@ -163,8 +163,8 @@ def get_criterion(self, config: Config) -> _Loss: # change based on the location of data data_path = Path("heterogeneous_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") @@ -173,7 +173,7 @@ def get_criterion(self, config: Config) -> _Loss: client = GeminiPerFclClient( data_path=data_path, metrics=[Binary_ROC_AUC(), Binary_F1(), Accuracy()], - device=DEVICE, + device=device, hospital_id=args.hospital_id, learning_rate=args.learning_rate, learning_task=args.task, diff --git a/research/gemini/scaffold/client.py b/research/gemini/scaffold/client.py index 8107e06df..d3c2b23c1 100644 --- a/research/gemini/scaffold/client.py +++ b/research/gemini/scaffold/client.py @@ -259,8 +259,8 @@ def validate(self, meter: Meter) -> Tuple[float, Dict[str, Scalar]]: elif args.task == "delirium": data_path = Path("delirium_data") - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Task: {args.task}") log(INFO, f"Server Address: {args.server_address}") @@ -268,7 +268,7 @@ def validate(self, meter: Meter) -> Tuple[float, Dict[str, Scalar]]: data_path, [Binary_ROC_AUC(), Binary_F1(), Accuracy()], 
args.hospital_id, - DEVICE, + device, args.task, args.learning_rate, args.artifact_dir, diff --git a/research/picai/fedavg/client.py b/research/picai/fedavg/client.py index 3fedbbf80..11a88c21b 100644 --- a/research/picai/fedavg/client.py +++ b/research/picai/fedavg/client.py @@ -150,21 +150,21 @@ def get_optimizer(self, config: Config) -> Optimizer: ) args = parser.parse_args() - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Device to be used: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") metrics = [ TorchMetric( name="MLAP", - metric=MultilabelAveragePrecision(average="macro", num_labels=2, thresholds=3).to(DEVICE), + metric=MultilabelAveragePrecision(average="macro", num_labels=2, thresholds=3).to(device), ) ] client = PicaiFedAvgClient( data_path=Path(args.base_dir), metrics=metrics, - device=DEVICE, + device=device, intermediate_client_state_dir=args.artifact_dir, overviews_dir=args.overviews_dir, data_partition=args.data_partition, diff --git a/research/picai/fl_nnunet/start_client.py b/research/picai/fl_nnunet/start_client.py index 4baa10f08..c60887ce4 100644 --- a/research/picai/fl_nnunet/start_client.py +++ b/research/picai/fl_nnunet/start_client.py @@ -37,21 +37,21 @@ def main( ) -> None: # Log device and server address - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - log(INFO, f"Using device: {DEVICE}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Using device: {device}") log(INFO, f"Using server address: {server_address}") # Define metrics dice1 = TransformsMetric( metric=TorchMetric( name="dice1", - metric=GeneralizedDiceScore(num_classes=2, weight_type="square", include_background=False).to(DEVICE), + metric=GeneralizedDiceScore(num_classes=2, weight_type="square", include_background=False).to(device), ), 
pred_transforms=[torch.sigmoid, get_segs_from_probs], ) # The Dice class requires preds to be ohe, but targets to not be ohe dice2 = TransformsMetric( - metric=TorchMetric(name="dice2", metric=Dice(num_classes=2, ignore_index=0).to(DEVICE)), + metric=TorchMetric(name="dice2", metric=Dice(num_classes=2, ignore_index=0).to(device)), pred_transforms=[torch.sigmoid], target_transforms=[partial(collapse_one_hot_tensor, dim=1)], ) @@ -69,7 +69,7 @@ def main( verbose=verbose, compile=compile, # BaseClient Args - device=DEVICE, + device=device, metrics=metrics, progress_bar=verbose, intermediate_client_state_dir=( diff --git a/research/picai/reporting/__init__.py b/research/picai/reporting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/picai/reporting/client.py b/research/picai/reporting/client.py index 5e0b33b33..f1fefda87 100644 --- a/research/picai/reporting/client.py +++ b/research/picai/reporting/client.py @@ -44,7 +44,7 @@ def get_model(self, config: Config) -> nn.Module: parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) reporter = WandBReporter( wandb_step_type="step", @@ -57,6 +57,6 @@ def get_model(self, config: Config) -> nn.Module: job_type="client", ) # reporter = JsonReporter() - client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE, reporters=[reporter]) + client = CifarClient(data_path, [Accuracy("accuracy")], device, reporters=[reporter]) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/tests/smoke_tests/load_from_checkpoint_example/client.py b/tests/smoke_tests/load_from_checkpoint_example/client.py index 9a7719432..132e0e000 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/client.py 
+++ b/tests/smoke_tests/load_from_checkpoint_example/client.py @@ -98,7 +98,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict ) args = parser.parse_args() - DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) # Set the random seed for reproducibility @@ -107,7 +107,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict client = CifarClient( data_path, [Accuracy("accuracy")], - DEVICE, + device, intermediate_client_state_dir=args.intermediate_client_state_dir, client_name=args.client_name, seed=args.seed, diff --git a/tests/utils/functions_test.py b/tests/utils/functions_test.py index 8da907ebd..1f8b024b9 100644 --- a/tests/utils/functions_test.py +++ b/tests/utils/functions_test.py @@ -1,6 +1,20 @@ +from typing import List, Tuple + +import numpy as np +import pytest import torch +from flwr.common import Code, Status, ndarrays_to_parameters +from flwr.common.typing import FitRes, NDArrays +from flwr.server.client_proxy import ClientProxy -from fl4health.utils.functions import bernoulli_sample, sigmoid_inverse +from fl4health.utils.functions import ( + bernoulli_sample, + decode_and_pseudo_sort_results, + pseudo_sort_scoring_function, + select_zeroeth_element, + sigmoid_inverse, +) +from tests.test_utils.custom_client_proxy import CustomClientProxy def test_bernoulli_gradient() -> None: @@ -21,3 +35,59 @@ def test_sigmoid_inverse() -> None: z = torch.sigmoid(x) assert torch.allclose(sigmoid_inverse(z), x) torch.seed() + + +def test_select_zeroeth_element() -> None: + np.random.seed(42) + array = np.random.rand(10, 10) + random_element = select_zeroeth_element(array) + assert pytest.approx(random_element, abs=1e-5) == 0.3745401188473625 + np.random.seed(None) + + +def test_pseudo_sort_scoring_function() -> None: + np.random.seed(42) + array_list = [np.random.rand(10, 10) for 
_ in range(2)] + [np.random.rand(5, 5) for _ in range(2)] + sort_value = pseudo_sort_scoring_function((CustomClientProxy("c0"), array_list, 13)) + assert pytest.approx(sort_value, abs=1e-5) == 14.291990594067467 + np.random.seed(None) + + +def test_pseudo_sort_scoring_function_with_mixed_types() -> None: + np.random.seed(42) + array_list = ( + [np.random.rand(10, 10) for _ in range(2)] + + [np.array(["Cat", "Dog"]), np.array([True, False])] + + [np.random.rand(5, 5) for _ in range(2)] + ) + sort_value = pseudo_sort_scoring_function((CustomClientProxy("c0"), array_list, 13)) + assert pytest.approx(sort_value, abs=1e-5) == 14.291990594067467 + np.random.seed(None) + + +def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> FitRes: + return FitRes( + status=Status(Code.OK, ""), + parameters=ndarrays_to_parameters(parameters), + num_examples=num_examples, + metrics={"metric": metric}, + ) + + +def test_decode_and_pseudo_sort_results() -> None: + np.random.seed(42) + client0_res = construct_fit_res([np.ones((3, 3)), np.ones((4, 4))], 0.1, 100) + client1_res = construct_fit_res([np.ones((3, 3)), np.full((4, 4), 2.0)], 0.2, 75) + client2_res = construct_fit_res([np.full((3, 3), 3.0), np.full((4, 4), 3.0)], 0.3, 50) + clients_res: List[Tuple[ClientProxy, FitRes]] = [ + (CustomClientProxy("c0"), client0_res), + (CustomClientProxy("c1"), client1_res), + (CustomClientProxy("c2"), client2_res), + ] + + sorted_results = decode_and_pseudo_sort_results(clients_res) + assert sorted_results[0][2] == 50 + assert sorted_results[1][2] == 75 + assert sorted_results[2][2] == 100 + + np.random.seed(None)