Skip to content

Commit

Permalink
PredictWriter can choose predict() key (#267)
Browse files: browse the repository at this point in the history
* Configurable PredictWriter key and error messages

* Lint

* Add key to docstring
  • Loading branch information
sjfleming authored Jan 3, 2025
1 parent 7277d76 commit 3677c22
Showing 1 changed file with 26 additions and 8 deletions.
34 changes: 26 additions & 8 deletions cellarium/ml/callbacks/prediction_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def write_prediction(
prediction: torch.Tensor,
ids: np.ndarray,
obs_names_n: np.ndarray,
output_dir: Path | str,
postfix: int | str,
) -> None:
Expand All @@ -23,7 +23,7 @@ def write_prediction(
Args:
prediction:
The prediction to write.
ids:
obs_names_n:
The IDs of the cells.
output_dir:
The directory to write the prediction to.
Expand All @@ -33,7 +33,7 @@ def write_prediction(
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
df = pd.DataFrame(prediction.cpu())
df.insert(0, "db_ids", ids)
df.insert(0, "obs_names_n", obs_names_n)
output_path = os.path.join(output_dir, f"batch_{postfix}.csv")
df.to_csv(output_path, header=False, index=False)

Expand All @@ -54,12 +54,20 @@ class PredictionWriter(pl.callbacks.BasePredictionWriter):
prediction_size:
The size of the prediction. If ``None``, the entire prediction will be
written. If not ``None``, only the first ``prediction_size`` columns will be written.
key:
PredictionWriter will write this key from the output of `predict()`.
"""

def __init__(self, output_dir: Path | str, prediction_size: int | None = None) -> None:
def __init__(
self,
output_dir: Path | str,
prediction_size: int | None = None,
key: str = "x_ng",
) -> None:
super().__init__(write_interval="batch")
self.output_dir = output_dir
self.prediction_size = prediction_size
self.key = key

def write_on_batch_end(
self,
Expand All @@ -71,14 +79,24 @@ def write_on_batch_end(
batch_idx: int,
dataloader_idx: int,
) -> None:
x_ng = prediction["x_ng"]
if self.key not in batch.keys():
raise ValueError(
f"PredictionWriter callback specified the key '{self.key}' as the relevant output of `predict()`,"
" but the key is not present. Specify a different key as an input argument to the callback, or"
" modify the output keys of `predict()`."
)
prediction_np = prediction[self.key]
if self.prediction_size is not None:
x_ng = x_ng[:, : self.prediction_size]
prediction_np = prediction_np[:, : self.prediction_size]

if "obs_names_n" not in batch.keys():
raise ValueError(
"PredictionWriter callback requires the batch_key 'obs_names_n'. Add this to the YAML config."
)
assert isinstance(batch["obs_names_n"], np.ndarray)
write_prediction(
prediction=x_ng,
ids=batch["obs_names_n"],
prediction=prediction_np,
obs_names_n=batch["obs_names_n"],
output_dir=self.output_dir,
postfix=batch_idx * trainer.world_size + trainer.global_rank,
)

0 comments on commit 3677c22

Please sign in to comment.