Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PredictionWriter: optional gzip, use ThreadPoolExecutor #286

Merged
merged 15 commits into from
Jan 23, 2025
Merged
76 changes: 73 additions & 3 deletions cellarium/ml/callbacks/prediction_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import os
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from queue import Queue

import lightning.pytorch as pl
import numpy as np
Expand All @@ -16,6 +18,8 @@ def write_prediction(
obs_names_n: np.ndarray,
output_dir: Path | str,
postfix: int | str,
gzip: bool = True,
executor: ThreadPoolExecutor | None = None,
) -> None:
"""
Write prediction to a CSV file.
Expand All @@ -29,13 +33,51 @@ def write_prediction(
The directory to write the prediction to.
postfix:
A postfix to add to the CSV file name.
gzip:
Whether to compress the CSV file using gzip.
executor:
The executor used to write the prediction. If ``None``, no executor will be used.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
df = pd.DataFrame(prediction.cpu())
df.insert(0, "obs_names_n", obs_names_n)
output_path = os.path.join(output_dir, f"batch_{postfix}.csv")
df.to_csv(output_path, header=False, index=False)
output_path = os.path.join(output_dir, f"batch_{postfix}.csv" + (".gz" if gzip else ""))
to_csv_kwargs: dict[str, str | bool] = {"header": False, "index": False}
if gzip:
to_csv_kwargs |= {"compression": "gzip"}

def _write_csv(frame: pd.DataFrame, path: str) -> None:
frame.to_csv(path, **to_csv_kwargs)

if executor is None:
_write_csv(df, output_path)
else:
executor.submit(_write_csv, df, output_path)


class BoundedThreadPoolExecutor(ThreadPoolExecutor):
"""ThreadPoolExecutor with a bounded queue for task submissions.
This class is used to prevent the queue from growing indefinitely when tasks are submitted,
which can lead to an out-of-memory error.
"""

def __init__(self, max_workers: int, max_queue_size: int):
# Use a bounded queue for task submissions
self._queue: Queue = Queue(max_queue_size)
super().__init__(max_workers=max_workers)

def submit(self, fn, /, *args, **kwargs):
# Block if the queue is full to prevent task overload
self._queue.put(None)
future = super().submit(fn, *args, **kwargs)

# When the task completes, remove a marker from the queue
def done_callback(_):
self._queue.get()

future.add_done_callback(done_callback)
return future


class PredictionWriter(pl.callbacks.BasePredictionWriter):
Expand All @@ -46,7 +88,18 @@ class PredictionWriter(pl.callbacks.BasePredictionWriter):

.. note::
To prevent an out-of-memory error, set the ``return_predictions`` argument of the
:class:`~lightning.pytorch.Trainer` to ``False``.
:class:`~lightning.pytorch.Trainer` to ``False``. This is accomplished in the config
file by including ``return_predictions: false`` at indent level 0. For example,

.. code-block:: yaml

trainer:
...
model:
...
data:
...
return_predictions: false

Args:
output_dir:
Expand All @@ -56,18 +109,33 @@ class PredictionWriter(pl.callbacks.BasePredictionWriter):
written. If not ``None``, only the first ``prediction_size`` columns will be written.
key:
PredictionWriter will write this key from the output of `predict()`.
gzip:
Whether to compress the CSV file using gzip.
max_threadpool_workers:
The maximum number of threads to use to write the predictions using a ThreadPoolExecutor.
"""

def __init__(
self,
output_dir: Path | str,
prediction_size: int | None = None,
key: str = "x_ng",
gzip: bool = True,
max_threadpool_workers: int = 8,
) -> None:
super().__init__(write_interval="batch")
self.output_dir = output_dir
self.prediction_size = prediction_size
self.key = key
self.executor = BoundedThreadPoolExecutor(
max_workers=max_threadpool_workers,
max_queue_size=max_threadpool_workers * 2,
)
self.gzip = gzip

def __del__(self):
"""Ensure the executor shuts down on object deletion."""
self.executor.shutdown(wait=True)

def write_on_batch_end(
self,
Expand Down Expand Up @@ -99,4 +167,6 @@ def write_on_batch_end(
obs_names_n=batch["obs_names_n"],
output_dir=self.output_dir,
postfix=batch_idx * trainer.world_size + trainer.global_rank,
gzip=self.gzip,
executor=self.executor,
)
13 changes: 13 additions & 0 deletions cellarium/ml/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,19 @@ def _add_instantiators(self) -> None:
# https://github.com/Lightning-AI/pytorch-lightning/pull/18105
pass

def before_instantiate_classes(self):
# issue a UserWarning if the subcommand is predict and return_predictions is not set to False
if self.subcommand == "predict":
return_predictions: bool = self.config["predict"]["return_predictions"]
if return_predictions:
warnings.warn(
"The `return_predictions` argument should be set to 'false' when running predict to avoid OOM. "
"This can be set at indent level 0 in the config file. Example:\n"
"model: ...\ndata: ...\ntrainer: ...\nreturn_predictions: false",
UserWarning,
)
return super().before_instantiate_classes()

def instantiate_classes(self) -> None:
with torch.device("meta"):
# skip the initialization of model parameters
Expand Down
134 changes: 134 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import copy
import os
import warnings
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -631,3 +633,135 @@ def test_compute_var_names_g(tmp_path: Path) -> None:
with open(tmp_path / "onepass_config_with_cpu_filter.yaml", "w") as f:
f.write(onepass_config_with_cpu_filter)
main(["onepass_mean_var_std", "fit", "--config", str(tmp_path / "onepass_config_with_cpu_filter.yaml")])


def test_return_predictions_userwarning(tmp_path: Path):
"""Ensure a warning is emitted when return_predictions is set to true and the subcommand is predict."""

# using a pre-parsed config
config: dict[str, Any] = {
"model_name": "geneformer",
"subcommand": "predict",
"predict": {
"model": {
"model": {
"class_path": "cellarium.ml.models.Geneformer",
"init_args": {
"hidden_size": "2",
"num_hidden_layers": "1",
"num_attention_heads": "1",
"intermediate_size": "4",
"max_position_embeddings": "2",
},
},
},
"data": {
"dadc": {
"class_path": "cellarium.ml.data.DistributedAnnDataCollection",
"init_args": {
"filenames": "https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad",
"shard_size": "100",
"max_cache_size": "2",
"obs_columns_to_validate": [],
},
},
"batch_keys": {
"x_ng": {
"attr": "X",
"convert_fn": "cellarium.ml.utilities.data.densify",
},
"var_names_g": {"attr": "var_names"},
},
"batch_size": "5",
"num_workers": "1",
},
"trainer": {
"accelerator": "cpu",
"devices": devices,
"max_steps": "1",
"limit_predict_batches": "1",
},
"return_predictions": "true",
},
}

match_str = r"`return_predictions` argument should"

with pytest.warns(UserWarning, match=match_str):
main(copy.deepcopy(config)) # running main modifies the config dict

config["predict"]["return_predictions"] = "false"
with pytest.warns(UserWarning, match=match_str) as record:
warnings.warn("we need one warning: " + match_str, UserWarning)
main(copy.deepcopy(config))
n = 0
for r in record:
assert isinstance(r.message, Warning)
warning_message = r.message.args[0]
if match_str in warning_message:
n += 1
assert n < 2, "Unexpected UserWarning when running predict with return_predictions=false"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this test do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah so this is asserting that the UserWarning is not emitted if running prediction with return_predictions: false.

I might be doing it in a weird way, but I'm not sure what the right way is. It's easy to assert that a warning is emitted, but not so easy to test that a warning is not emitted (from what I can tell). The only way I could figure was to count up the warnings matching a certain match string. (And I needed at least one such warning, or the counting mechanism would not work. Thus the assertion n < 2... there is one "fake" warning to enable counting, and then any further warning would be the real warning.)


# using a config file
config_file_text = f"""
# lightning.pytorch==2.5.0.post0
seed_everything: true
trainer:
accelerator: cpu
devices: 1
max_steps: 1
limit_predict_batches: 1
default_root_dir: {tmp_path}
model:
cpu_transforms: null
transforms: null
model:
class_path: cellarium.ml.models.Geneformer
init_args:
hidden_size: 2
num_hidden_layers: 1
num_attention_heads: 1
intermediate_size: 4
max_position_embeddings: 2
data:
dadc:
class_path: cellarium.ml.data.DistributedAnnDataCollection
init_args:
filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad
shard_size: 100
max_cache_size: 2
obs_columns_to_validate: []
batch_keys:
x_ng:
attr: X
convert_fn: cellarium.ml.utilities.data.densify
var_names_g:
attr: var_names
batch_size: 5
num_workers: 1
return_predictions: RETURN_PREDICTIONS_VALUE
ckpt_path: null
"""

# there should be a warning with return_predictions=true
with open(config_file_path := tmp_path / "config.yaml", "w") as f:
f.write(config_file_text.replace("RETURN_PREDICTIONS_VALUE", "true"))

with pytest.warns(UserWarning, match=match_str):
main(["geneformer", "predict", "--config", str(config_file_path)])

# there should be no warning with return_predictions=false
with open(config_file_path, "w") as f:
f.write(config_file_text.replace("RETURN_PREDICTIONS_VALUE", "false"))

with pytest.warns(UserWarning, match=match_str) as record:
warnings.warn("we need one warning: " + match_str, UserWarning)
main(["geneformer", "predict", "--config", str(config_file_path)])
n = 0
for r in record:
assert isinstance(r.message, Warning)
warning_message = r.message.args[0]
if match_str in warning_message:
print(warning_message)
n += 1
assert n < 2, "Unexpected UserWarning when running predict with return_predictions=false"
Loading