Commit 6674ba4

Merge pull request #281 from VectorInstitute/dbe/maybe_checkpoint_logging

Adding a touch of logging for checkpointing errors
emersodb authored Nov 11, 2024
2 parents 986350c + 8f46917 commit 6674ba4
Showing 2 changed files with 26 additions and 6 deletions.
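
Every hunk in checkpointer.py below makes the same change: a bare torch.save call is wrapped in a try/except that logs the destination path at INFO, logs any failure at ERROR, and re-raises. A minimal standalone sketch of that pattern, using stdlib logging in place of the project's log helper (the save_model name and the path are illustrative, not fl4health API):

    import logging

    import torch
    import torch.nn as nn

    logger = logging.getLogger(__name__)

    def save_model(model: nn.Module, checkpoint_path: str) -> None:
        # Announce the target path, attempt the save, and record any
        # failure in the logs before letting the exception propagate.
        try:
            logger.info(f"Saving checkpoint as {checkpoint_path}")
            torch.save(model, checkpoint_path)
        except Exception as e:
            logger.error(f"Encountered the following error while saving the checkpoint: {e}")
            raise e

    logging.basicConfig(level=logging.INFO)
    try:
        # A save into a nonexistent directory now leaves an ERROR line in
        # the logs before the exception reaches the caller.
        save_model(nn.Linear(2, 2), "/no/such/dir/model.pt")
    except Exception:
        pass  # already logged above

Since the exception is re-raised, behavior is unchanged; the only addition is the diagnostic trail.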
30 changes: 25 additions & 5 deletions fl4health/checkpointing/checkpointer.py
@@ -1,6 +1,6 @@
 import os
 from abc import ABC, abstractmethod
-from logging import INFO
+from logging import ERROR, INFO
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional

@@ -95,7 +95,12 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None:
                 f"{self.comparison_str} Best score ({self.best_score})",
             )
             self.best_score = comparison_score
-            torch.save(model, self.best_checkpoint_path)
+            try:
+                log(INFO, f"Saving checkpoint as {str(self.best_checkpoint_path)}")
+                torch.save(model, self.best_checkpoint_path)
+            except Exception as e:
+                log(ERROR, f"Encountered the following error while saving the checkpoint: {e}")
+                raise e
         else:
             log(
                 INFO,
@@ -115,7 +120,12 @@ def null_score_function(loss: float, _: Dict[str, Scalar]) -> float:
     def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None:
         # Always checkpoint the latest model
         log(INFO, "Saving latest checkpoint with LatestTorchCheckpointer")
-        torch.save(model, self.best_checkpoint_path)
+        try:
+            log(INFO, f"Saving checkpoint as {str(self.best_checkpoint_path)}")
+            torch.save(model, self.best_checkpoint_path)
+        except Exception as e:
+            log(ERROR, f"Encountered the following error while saving the checkpoint: {e}")
+            raise e


 class BestLossTorchCheckpointer(FunctionTorchCheckpointer):
@@ -149,7 +159,12 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None:
                 f"{self.comparison_str} Best Loss ({self.best_score})",
             )
             self.best_score = comparison_score
-            torch.save(model, self.best_checkpoint_path)
+            try:
+                log(INFO, f"Saving checkpoint as {str(self.best_checkpoint_path)}")
+                torch.save(model, self.best_checkpoint_path)
+            except Exception as e:
+                log(ERROR, f"Encountered the following error while saving the checkpoint: {e}")
+                raise e
         else:
             log(
                 INFO,
@@ -178,7 +193,12 @@ def save_checkpoint(self, checkpoint_dict: Dict[str, Any]) -> None:
             checkpoint_dict (Dict[str, Any]): A dictionary with string keys and values of type
                 Any representing the state to checkpoint.
         """
-        torch.save(checkpoint_dict, self.checkpoint_path)
+        try:
+            log(INFO, f"Saving checkpoint as {self.checkpoint_path}")
+            torch.save(checkpoint_dict, self.checkpoint_path)
+        except Exception as e:
+            log(ERROR, f"Encountered the following error while saving the checkpoint: {e}")
+            raise e

     def load_checkpoint(self) -> Dict[str, Any]:
         """
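
The last hunk applies the same guard to the dict-based save_checkpoint, whose counterpart load_checkpoint appears just below it. A hedged sketch of the round trip those two methods imply (the module-level checkpoint_path and the dict contents are invented for illustration; this class's constructor is not shown in the diff):

    import logging
    from typing import Any, Dict

    import torch

    logger = logging.getLogger(__name__)
    checkpoint_path = "per_round_checkpoint.pt"  # illustrative path

    def save_checkpoint(checkpoint_dict: Dict[str, Any]) -> None:
        # Same guard as above: log the path, save, and log any failure
        # before re-raising.
        try:
            logger.info(f"Saving checkpoint as {checkpoint_path}")
            torch.save(checkpoint_dict, checkpoint_path)
        except Exception as e:
            logger.error(f"Encountered the following error while saving the checkpoint: {e}")
            raise e

    def load_checkpoint() -> Dict[str, Any]:
        # Restore the state dictionary saved above.
        return torch.load(checkpoint_path)

    save_checkpoint({"round": 3, "loss": 0.42})
    state = load_checkpoint()
    assert state["round"] == 3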
2 changes: 1 addition & 1 deletion tests/utils/sampler_test.py
@@ -177,7 +177,7 @@ def test_dirichlet_sampler_without_hash_key() -> None:

     _, p_val_1 = chisquare(f_obs=new_samples_per_class_1, f_exp=samples_per_class)
     _, p_val_2 = chisquare(f_obs=new_samples_per_class_2, f_exp=samples_per_class)
-    _, p_val_3 = chisquare(f_obs=new_samples_per_class_1, f_exp=new_samples_per_class_2)
+    _, p_val_3 = chisquare(f_obs=new_samples_per_class_2, f_exp=new_samples_per_class_1)
     # Assert that the new distribution with sampler_1 is different from the original distribution
     assert p_val_1 < 0.01
     # Assert that the new distribution with sampler_2 is different from the original distribution
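
The one-line test fix swaps which sample is treated as observed and which as expected. Pearson's chi-square statistic, sum((obs - exp)^2 / exp), is not symmetric in its two arguments, so the ordering changes the statistic and the p-value. A small self-contained illustration of the call (the counts are invented; scipy requires the two arrays to have matching totals):

    import numpy as np
    from scipy.stats import chisquare

    observed = np.array([30, 20, 50])  # e.g., per-class counts after resampling
    expected = np.array([34, 33, 33])  # e.g., per-class counts in the reference split

    stat, p_val = chisquare(f_obs=observed, f_exp=expected)
    # A small p-value (the surrounding asserts use < 0.01) is evidence that
    # the observed counts do not follow the expected distribution.
    print(stat, p_val)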
