diff --git a/research/rxrx1/ditto_deep_mmd/client.py b/research/rxrx1/ditto_deep_mmd/client.py index 2f7a7a681..ed05ca623 100644 --- a/research/rxrx1/ditto_deep_mmd/client.py +++ b/research/rxrx1/ditto_deep_mmd/client.py @@ -57,7 +57,7 @@ def __init__( ) self.client_number = client_number self.learning_rate: float = learning_rate - + log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") def setup_client(self, config: Config) -> None: diff --git a/research/rxrx1/ditto_mkmmd/client.py b/research/rxrx1/ditto_mkmmd/client.py index a0da89264..2cb78924e 100644 --- a/research/rxrx1/ditto_mkmmd/client.py +++ b/research/rxrx1/ditto_mkmmd/client.py @@ -23,7 +23,7 @@ from fl4health.utils.random import set_all_random_seeds from research.rxrx1.data.data_utils import load_rxrx1_data, load_rxrx1_test_data -BASELINE_LAYERS = ["layer1","layer2", "layer3", "layer4", "avgpool"] +BASELINE_LAYERS = ["layer1", "layer2", "layer3", "layer4", "avgpool"] class Rxrx1DittoClient(DittoMkMmdClient): @@ -54,7 +54,7 @@ def __init__( ) self.client_number = client_number self.learning_rate: float = learning_rate - + log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") # Number of batches to accumulate before updating the global model