PredictionWriter: optional gzip, use ThreadPoolExecutor (#286)
* Optional gzip, use ThreadPoolExecutor

* max_threadpool_workers default 8

* Bound the queue of the thread pool

* Add (untested) disk space check to fail fast

* linting

* linting again

* return_predictions=False as per PredictionWriter docstring

* Revert hard-coded return_predictions; issue warning

* Remove disk space checking

* Remove obsolete attribute

* Improve docstring

* Test demonstrating unintended warning with config file

* Fix problem with test

* Fix the UserWarning location

* Remove extra newline
sjfleming authored Jan 23, 2025
1 parent 090161e commit 8e7c817
Showing 3 changed files with 220 additions and 3 deletions.
76 changes: 73 additions & 3 deletions cellarium/ml/callbacks/prediction_writer.py
@@ -3,7 +3,9 @@

 import os
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
+from queue import Queue

 import lightning.pytorch as pl
 import numpy as np
@@ -16,6 +18,8 @@ def write_prediction(
     obs_names_n: np.ndarray,
     output_dir: Path | str,
     postfix: int | str,
+    gzip: bool = True,
+    executor: ThreadPoolExecutor | None = None,
 ) -> None:
     """
     Write prediction to a CSV file.
@@ -29,13 +33,51 @@ def write_prediction(
         The directory to write the prediction to.
     postfix:
         A postfix to add to the CSV file name.
+    gzip:
+        Whether to compress the CSV file using gzip.
+    executor:
+        The executor used to write the prediction. If ``None``, the file is written
+        synchronously in the calling thread.
     """
     if not os.path.exists(output_dir):
         os.makedirs(output_dir, exist_ok=True)
     df = pd.DataFrame(prediction.cpu())
     df.insert(0, "obs_names_n", obs_names_n)
-    output_path = os.path.join(output_dir, f"batch_{postfix}.csv")
-    df.to_csv(output_path, header=False, index=False)
+    output_path = os.path.join(output_dir, f"batch_{postfix}.csv" + (".gz" if gzip else ""))
+    to_csv_kwargs: dict[str, str | bool] = {"header": False, "index": False}
+    if gzip:
+        to_csv_kwargs |= {"compression": "gzip"}
+
+    def _write_csv(frame: pd.DataFrame, path: str) -> None:
+        frame.to_csv(path, **to_csv_kwargs)
+
+    if executor is None:
+        _write_csv(df, output_path)
+    else:
+        executor.submit(_write_csv, df, output_path)
+
+
+class BoundedThreadPoolExecutor(ThreadPoolExecutor):
+    """ThreadPoolExecutor with a bounded queue for task submissions.
+
+    This class is used to prevent the queue from growing indefinitely when tasks are submitted,
+    which can lead to an out-of-memory error.
+    """
+
+    def __init__(self, max_workers: int, max_queue_size: int):
+        # Use a bounded queue for task submissions
+        self._queue: Queue = Queue(max_queue_size)
+        super().__init__(max_workers=max_workers)
+
+    def submit(self, fn, /, *args, **kwargs):
+        # Block if the queue is full to prevent task overload
+        self._queue.put(None)
+        future = super().submit(fn, *args, **kwargs)
+
+        # When the task completes, remove a marker from the queue
+        def done_callback(_):
+            self._queue.get()
+
+        future.add_done_callback(done_callback)
+        return future


 class PredictionWriter(pl.callbacks.BasePredictionWriter):
@@ -46,7 +88,18 @@ class PredictionWriter(pl.callbacks.BasePredictionWriter):

     .. note::
         To prevent an out-of-memory error, set the ``return_predictions`` argument of the
-        :class:`~lightning.pytorch.Trainer` to ``False``.
+        :class:`~lightning.pytorch.Trainer` to ``False``. This is accomplished in the config
+        file by including ``return_predictions: false`` at indent level 0. For example:
+
+        .. code-block:: yaml
+
+            trainer:
+                ...
+            model:
+                ...
+            data:
+                ...
+            return_predictions: false

     Args:
         output_dir:
@@ -56,18 +109,33 @@ class PredictionWriter(pl.callbacks.BasePredictionWriter):
         written. If not ``None``, only the first ``prediction_size`` columns will be written.
     key:
         PredictionWriter will write this key from the output of `predict()`.
+    gzip:
+        Whether to compress the CSV file using gzip.
+    max_threadpool_workers:
+        The maximum number of worker threads used to write predictions via a ThreadPoolExecutor.
     """

     def __init__(
         self,
         output_dir: Path | str,
         prediction_size: int | None = None,
         key: str = "x_ng",
+        gzip: bool = True,
+        max_threadpool_workers: int = 8,
     ) -> None:
         super().__init__(write_interval="batch")
         self.output_dir = output_dir
         self.prediction_size = prediction_size
         self.key = key
+        self.executor = BoundedThreadPoolExecutor(
+            max_workers=max_threadpool_workers,
+            max_queue_size=max_threadpool_workers * 2,
+        )
+        self.gzip = gzip
+
+    def __del__(self):
+        """Ensure the executor shuts down on object deletion."""
+        self.executor.shutdown(wait=True)

     def write_on_batch_end(
         self,
@@ -99,4 +167,6 @@ def write_on_batch_end(
             obs_names_n=batch["obs_names_n"],
             output_dir=self.output_dir,
             postfix=batch_idx * trainer.world_size + trainer.global_rank,
+            gzip=self.gzip,
+            executor=self.executor,
         )
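To make the bounded queue concrete, here is a minimal standalone sketch (not part of the commit) that drives BoundedThreadPoolExecutor directly; the import path is assumed from the file path in this diff, and the sleep-based task stands in for a slow CSV write.

    # Sketch only: observe the executor's back-pressure.
    import time

    from cellarium.ml.callbacks.prediction_writer import BoundedThreadPoolExecutor  # path assumed

    def slow_task(i: int) -> int:
        time.sleep(0.5)  # stand-in for a slow df.to_csv(...) call
        return i

    executor = BoundedThreadPoolExecutor(max_workers=2, max_queue_size=4)
    start = time.monotonic()
    # With at most 4 marker slots, the 5th and later submit() calls block here
    # instead of queueing without bound (the OOM scenario the class avoids).
    futures = [executor.submit(slow_task, i) for i in range(10)]
    executor.shutdown(wait=True)  # wait for all pending tasks to finish
    print([f.result() for f in futures], f"{time.monotonic() - start:.1f}s")

Because the marker is only removed in a done callback, at most max_queue_size submitted tasks (and their DataFrames) can be pending at any moment, which bounds the writer's memory footprint.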
13 changes: 13 additions & 0 deletions cellarium/ml/cli.py
@@ -291,6 +291,19 @@ def _add_instantiators(self) -> None:
         # https://github.com/Lightning-AI/pytorch-lightning/pull/18105
         pass

+    def before_instantiate_classes(self):
+        # issue a UserWarning if the subcommand is predict and return_predictions is not set to False
+        if self.subcommand == "predict":
+            return_predictions: bool = self.config["predict"]["return_predictions"]
+            if return_predictions:
+                warnings.warn(
+                    "The `return_predictions` argument should be set to 'false' when running predict to avoid OOM. "
+                    "This can be set at indent level 0 in the config file. Example:\n"
+                    "model: ...\ndata: ...\ntrainer: ...\nreturn_predictions: false",
+                    UserWarning,
+                )
+        return super().before_instantiate_classes()
+
     def instantiate_classes(self) -> None:
         with torch.device("meta"):
             # skip the initialization of model parameters
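For orientation, a hedged sketch of wiring the updated callback into a Trainer by hand; the cellarium.ml.callbacks import path and the model/datamodule objects are assumptions for illustration, not part of this commit.

    # Sketch only: manual wiring of PredictionWriter outside the CLI.
    import lightning.pytorch as pl

    from cellarium.ml.callbacks import PredictionWriter  # import path assumed

    writer = PredictionWriter(
        output_dir="predictions",  # files land here as batch_<idx>.csv.gz
        prediction_size=None,      # keep every output column
        key="x_ng",                # key read from the predict() output
        gzip=True,                 # new in this commit: gzip each CSV
        max_threadpool_workers=8,  # new in this commit: bounded writer pool
    )

    trainer = pl.Trainer(accelerator="cpu", devices=1, callbacks=[writer])
    # `model` and `datamodule` are placeholders for the pipeline's
    # LightningModule and LightningDataModule; return_predictions=False is
    # the setting the new before_instantiate_classes() warning nudges toward.
    trainer.predict(model, datamodule=datamodule, return_predictions=False)

In normal use the CLI builds all of this from the YAML config; the sketch only makes the two new arguments and the return_predictions recommendation visible in one place.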
134 changes: 134 additions & 0 deletions tests/test_cli.py
@@ -1,7 +1,9 @@
 # Copyright Contributors to the Cellarium project.
 # SPDX-License-Identifier: BSD-3-Clause

+import copy
 import os
+import warnings
 from pathlib import Path
 from typing import Any

@@ -631,3 +633,135 @@ def test_compute_var_names_g(tmp_path: Path) -> None:
     with open(tmp_path / "onepass_config_with_cpu_filter.yaml", "w") as f:
         f.write(onepass_config_with_cpu_filter)
     main(["onepass_mean_var_std", "fit", "--config", str(tmp_path / "onepass_config_with_cpu_filter.yaml")])
+
+
+def test_return_predictions_userwarning(tmp_path: Path):
+    """Ensure a warning is emitted when return_predictions is set to true and the subcommand is predict."""
+
+    # using a pre-parsed config
+    config: dict[str, Any] = {
+        "model_name": "geneformer",
+        "subcommand": "predict",
+        "predict": {
+            "model": {
+                "model": {
+                    "class_path": "cellarium.ml.models.Geneformer",
+                    "init_args": {
+                        "hidden_size": "2",
+                        "num_hidden_layers": "1",
+                        "num_attention_heads": "1",
+                        "intermediate_size": "4",
+                        "max_position_embeddings": "2",
+                    },
+                },
+            },
+            "data": {
+                "dadc": {
+                    "class_path": "cellarium.ml.data.DistributedAnnDataCollection",
+                    "init_args": {
+                        "filenames": "https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad",
+                        "shard_size": "100",
+                        "max_cache_size": "2",
+                        "obs_columns_to_validate": [],
+                    },
+                },
+                "batch_keys": {
+                    "x_ng": {
+                        "attr": "X",
+                        "convert_fn": "cellarium.ml.utilities.data.densify",
+                    },
+                    "var_names_g": {"attr": "var_names"},
+                },
+                "batch_size": "5",
+                "num_workers": "1",
+            },
+            "trainer": {
+                "accelerator": "cpu",
+                "devices": devices,
+                "max_steps": "1",
+                "limit_predict_batches": "1",
+            },
+            "return_predictions": "true",
+        },
+    }
+
+    match_str = r"`return_predictions` argument should"
+
+    with pytest.warns(UserWarning, match=match_str):
+        main(copy.deepcopy(config))  # running main modifies the config dict
+
+    config["predict"]["return_predictions"] = "false"
+    with pytest.warns(UserWarning, match=match_str) as record:
+        # emit one matching sentinel warning so pytest.warns does not fail
+        # when no real warning fires; the count below must then stay at 1
+        warnings.warn("we need one warning: " + match_str, UserWarning)
+        main(copy.deepcopy(config))
+    n = 0
+    for r in record:
+        assert isinstance(r.message, Warning)
+        warning_message = r.message.args[0]
+        if match_str in warning_message:
+            n += 1
+    assert n < 2, "Unexpected UserWarning when running predict with return_predictions=false"
+
+    # using a config file
+    config_file_text = f"""
+# lightning.pytorch==2.5.0.post0
+seed_everything: true
+trainer:
+  accelerator: cpu
+  devices: 1
+  max_steps: 1
+  limit_predict_batches: 1
+  default_root_dir: {tmp_path}
+model:
+  cpu_transforms: null
+  transforms: null
+  model:
+    class_path: cellarium.ml.models.Geneformer
+    init_args:
+      hidden_size: 2
+      num_hidden_layers: 1
+      num_attention_heads: 1
+      intermediate_size: 4
+      max_position_embeddings: 2
+data:
+  dadc:
+    class_path: cellarium.ml.data.DistributedAnnDataCollection
+    init_args:
+      filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad
+      shard_size: 100
+      max_cache_size: 2
+      obs_columns_to_validate: []
+  batch_keys:
+    x_ng:
+      attr: X
+      convert_fn: cellarium.ml.utilities.data.densify
+    var_names_g:
+      attr: var_names
+  batch_size: 5
+  num_workers: 1
+return_predictions: RETURN_PREDICTIONS_VALUE
+ckpt_path: null
+"""
+
+    # there should be a warning with return_predictions=true
+    with open(config_file_path := tmp_path / "config.yaml", "w") as f:
+        f.write(config_file_text.replace("RETURN_PREDICTIONS_VALUE", "true"))
+
+    with pytest.warns(UserWarning, match=match_str):
+        main(["geneformer", "predict", "--config", str(config_file_path)])
+
+    # there should be no warning with return_predictions=false
+    with open(config_file_path, "w") as f:
+        f.write(config_file_text.replace("RETURN_PREDICTIONS_VALUE", "false"))
+
+    with pytest.warns(UserWarning, match=match_str) as record:
+        # sentinel warning again so pytest.warns has at least one match
+        warnings.warn("we need one warning: " + match_str, UserWarning)
+        main(["geneformer", "predict", "--config", str(config_file_path)])
+    n = 0
+    for r in record:
+        assert isinstance(r.message, Warning)
+        warning_message = r.message.args[0]
+        if match_str in warning_message:
+            n += 1
+    assert n < 2, "Unexpected UserWarning when running predict with return_predictions=false"
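As a closing usage note, a small sketch of reading the compressed outputs back; it assumes the default gzip=True and the batch_<idx>.csv.gz naming from write_prediction, and relies on pandas inferring gzip compression from the .gz suffix.

    # Sketch only: collect the per-batch CSVs written by PredictionWriter.
    import glob

    import pandas as pd

    frames = [
        pd.read_csv(path, header=None)  # compression="infer" handles .csv.gz
        for path in sorted(glob.glob("predictions/batch_*.csv.gz"))
    ]
    predictions = pd.concat(frames, ignore_index=True)
    # Column 0 is obs_names_n; the remaining columns are the prediction values.
    print(predictions.shape)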
