
PredictionWriter: optional gzip, use ThreadPoolExecutor #286

Merged: 15 commits merged into main from sf-predictwriter-gzip-threadpoolexecutor on Jan 23, 2025

Conversation

@sjfleming (Contributor) commented Jan 17, 2025

Closes #285

These changes add two init args to PredictionWriter: gzip (bool) and max_threadpool_workers (int), each of which has a default value.

PredictionWriter now gzips saved CSVs by default, and runs the saving-and-gzipping process in a background thread so that it does not block further Lightning compute.

NewLightningCLI in cli.py also now injects return_predictions=False into calls to trainer.predict().

Testing indicates the following outcomes for scVI reconstructions, which involve computing a dense output CSV with 250 columns:

|            | before changes | after changes |
|------------|----------------|---------------|
| file size  | 12 MB          | 5 MB          |
| total time | 18 hr          | 5 hr          |

(The 18 hr figure comes from gzipping the CSVs without the ThreadPool; if you don't gzip at all, the total time is about 7.5 hours.)
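For illustration, here is a minimal sketch of how the new arguments might be used. Everything other than gzip and max_threadpool_workers (the import path, the output_dir argument name, and the trainer setup) is assumed for the sketch rather than taken from this PR:

```python
from lightning.pytorch import Trainer

from cellarium.ml.callbacks import PredictionWriter  # import path assumed

# Hypothetical construction: only gzip and max_threadpool_workers are new in this PR.
prediction_writer = PredictionWriter(
    output_dir="./predictions",  # assumed argument name for the output location
    gzip=True,                   # write batch CSVs as .csv.gz (the new default)
    max_threadpool_workers=8,    # background threads used for writing/gzipping
)

trainer = Trainer(callbacks=[prediction_writer])
# trainer.predict(model, datamodule=datamodule, return_predictions=False)
```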

@sjfleming sjfleming changed the title PredictWriter: optional gzip, use ThreadPoolExecutor PredictionWriter: optional gzip, use ThreadPoolExecutor Jan 17, 2025

@sjfleming (Contributor, Author) commented Jan 17, 2025

Here begins my stream-of-consciousness during development:

The only undesirable thing about this is that the ThreadPoolExecutor can fall behind. In my current test run, for example, the prediction writing seems to be about 100 batches behind the actual Lightning compute. There might be a danger of OOM if max_threadpool_workers is so low that the output writing really lags behind.

@sjfleming (Contributor, Author)

Empirically, it seems that 8 thread workers do not fall behind systematically, so for now I will change the default to 8.

@sjfleming sjfleming requested a review from ordabayevy January 17, 2025 17:54
@sjfleming sjfleming marked this pull request as draft January 17, 2025 19:25

@sjfleming (Contributor, Author) commented Jan 17, 2025

Ah, dang it: it was killed with 8 workers due to OOM after 2.5 hours... That may be because I pushed the batch size too high, but perhaps this approach is a bit fragile.

Interested in your thoughts @ordabayevy

@sjfleming (Contributor, Author) commented Jan 17, 2025

Okay, I have now implemented a BoundedThreadPoolExecutor, which prevents the executor's queue (and thus memory usage) from growing without limit. Let's see if this works.

Projected total runtime now looks like about 5.5 hours instead of 5 (the projection with the unbounded queue).
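The PR's actual BoundedThreadPoolExecutor is not reproduced in this thread; below is a minimal sketch of the standard semaphore-bounded pattern it describes, with max_queued_tasks as an illustrative parameter name rather than the real one:

```python
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Semaphore
from typing import Any, Callable


class BoundedThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor whose backlog of pending tasks is capped, so that
    submitters block instead of letting the queue (and memory) grow without limit."""

    def __init__(self, max_workers: int, max_queued_tasks: int) -> None:
        super().__init__(max_workers=max_workers)
        # One slot per running or queued task.
        self._slots = Semaphore(max_workers + max_queued_tasks)

    def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future:
        self._slots.acquire()  # blocks the caller when the backlog is full
        try:
            future = super().submit(fn, *args, **kwargs)
        except Exception:
            self._slots.release()
            raise
        # Free the slot once the task finishes (successfully or not).
        future.add_done_callback(lambda _: self._slots.release())
        return future
```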

@sjfleming (Contributor, Author)

Added an (untested) check to fail fast if it can be projected that the total size of the prediction output files will not fit in the allocated disk space. (I ran into this problem myself, and it was only after several hours that I found out. :( )
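A minimal sketch of what such a fail-fast check could look like; the function name, the way the total size is projected, and the use of shutil.disk_usage are assumptions for illustration, not the PR's implementation (which was later removed anyway):

```python
import shutil


def check_projected_disk_usage(output_dir: str, bytes_per_batch: int, num_batches: int) -> None:
    """Fail fast if the projected total size of the prediction output files
    would exceed the free space on the target filesystem."""
    projected_bytes = bytes_per_batch * num_batches
    free_bytes = shutil.disk_usage(output_dir).free
    if projected_bytes > free_bytes:
        raise RuntimeError(
            f"Projected prediction output ({projected_bytes / 1e9:.1f} GB) exceeds "
            f"free disk space ({free_bytes / 1e9:.1f} GB) at {output_dir}."
        )
```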

@sjfleming (Contributor, Author) commented Jan 18, 2025

The OOM problem was delayed but never completely went away. I now think the issue has to do with the following...

I needed to reach into cli.py to implement the note in the docstring here:

.. note::
To prevent an out-of-memory error, set the ``return_predictions`` argument of the
:class:`~lightning.pytorch.Trainer` to ``False``.

i.e., the return_predictions=False kwarg needs to be passed to trainer.predict(). I think the changes to cli.py are the correct way to make this happen.
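For context, return_predictions is an argument of Trainer.predict itself. A tiny self-contained example (a toy model, not cellarium-ml code) showing the call and its effect:

```python
import torch
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, TensorDataset


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def predict_step(self, batch, batch_idx):
        (x,) = batch
        return self.linear(x)


dataloader = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8)
trainer = Trainer(logger=False, enable_checkpointing=False)

# Predictions are handled by callbacks (e.g. a prediction writer) instead of
# being accumulated in memory and returned.
out = trainer.predict(TinyModel(), dataloaders=dataloader, return_predictions=False)
assert out is None
```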

@sjfleming (Contributor, Author)

The above cli.py modification did fix the full scVI run.

@sjfleming sjfleming marked this pull request as ready for review January 18, 2025 18:43
Review thread on cellarium/ml/cli.py (outdated, resolved)

@sjfleming (Contributor, Author)

I can certainly get rid of the changes to cli.py and leave return_predictions: false to the config file, if you think that makes the most sense.

The only thing that bothers me about that solution is that somebody like me could come in, not know they have to include that in the config file (or where it goes), and waste a bunch of time hitting out-of-memory errors :)

What I like about modifying cli.py in the manner above is that return_predictions becomes False whenever predict is called, regardless of the config file. So it prevents users from making mistakes. My thinking is just that the simpler the config files can be, the better.

But I can see both sides... and I guess it would be okay with me if we left things as-is and just put in some kind of prominent UserWarning to check that someone hasn't forgotten to add return_predictions: false to their config file by mistake.

What do you think @ordabayevy ?

@ordabayevy (Contributor)

I prefer using the config file instead of hard-coding, and doing the following:

  1. Improve the documentation of PredictionWriter and add an example of how return_predictions should be added to the config file.
  2. If you think that is not enough, add a warning message if return_predictions=True.

... but if we're never really going to need return_predictions=True, maybe it makes sense to hard-code it, I don't know. What happens if your changes are applied and return_predictions=True is set?

@sjfleming (Contributor, Author)

Okay, I've implemented those changes: got rid of the hard-coded return_predictions=False, got rid of the fail-fast check for predictions that won't fit on disk (thinking ahead to #290), added a UserWarning when running predict with return_predictions=True, and included an explicit config file example in the docstring for PredictionWriter.

Snippet under review (the tail of the new warnings.warn call):

"This can be set at indent level 0 in the config file. Example:\n"
"model: ...\ndata: ...\ntrainer: ...\nreturn_predictions: false",
UserWarning,
)
@ordabayevy (Contributor)

I think this might not work if a config file is used. The config file is parsed later, in the init method of LightningCLI, so this logic probably needs to be moved there. It should happen after the line self.parse_arguments(self.parser, args). This hook https://github.com/Lightning-AI/pytorch-lightning/blob/a944e7744e57a5a2c13f3c73b9735edf2f71e329/src/lightning/pytorch/cli.py#L554 might be a good place.
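A rough sketch of what that suggestion could look like; the exact access pattern for the parsed config (a jsonargparse Namespace keyed by subcommand) and the warning text are approximations, not the code that was actually merged:

```python
import warnings

from lightning.pytorch.cli import LightningCLI


class NewLightningCLI(LightningCLI):
    def before_instantiate_classes(self) -> None:
        # This hook runs after parse_arguments(), so values coming from a
        # config file have already been merged into self.config.
        if self.subcommand == "predict":
            predict_config = self.config["predict"]
            if predict_config.get("return_predictions") is not False:
                warnings.warn(
                    "Set `return_predictions: false` in the config to avoid "
                    "holding all predictions in memory during predict.",
                    UserWarning,
                )
```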

@sjfleming (Contributor, Author)

Right you are! Thank you.

I have now included an explicit test to make sure I'm actually doing what I wanted. Indeed, you were right: it didn't work with config files. I've tried to implement what you suggested, and the tests seem to pass.

@ordabayevy (Contributor) left a review

Looks great!

Snippet under review (from the new test):

warning_message = r.message.args[0]
if match_str in warning_message:
    n += 1
assert n < 2, "Unexpected UserWarning when running predict with return_predictions=false"
@ordabayevy (Contributor)

What does this test do?

@sjfleming (Contributor, Author)

Yeah so this is asserting that the UserWarning is not emitted if running prediction with return_predictions: false.

I might be doing it in a weird way, but I'm not sure what the right way is. It's easy to assert that a warning is emitted, but not so easy to test that a warning is not emitted (from what I can tell). The only way I could figure out was to count the warnings matching a certain match string. (And I needed at least one such warning, or the counting mechanism would not work; hence the assertion n < 2. There is one "fake" warning to enable counting, and any further warning would be the real one.)
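A minimal, self-contained sketch of the counting trick described above; this is not the PR's actual test, and the sentinel message text is made up for illustration:

```python
import warnings

match_str = "return_predictions"  # substring that identifies the warning of interest

with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter("always")

    # ... run `predict` here with `return_predictions: false` in the config ...

    # One "fake" (sentinel) warning containing the match string is always emitted,
    # so a broken recording setup cannot be confused with "no warning was raised".
    warnings.warn(f"sentinel warning mentioning {match_str}", UserWarning)

# Count warnings whose message contains the match string; only the sentinel should match.
n = sum(1 for r in records if match_str in str(r.message))
assert n < 2, "Unexpected UserWarning when running predict with return_predictions=false"
```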

@ordabayevy merged commit 8e7c817 into main on Jan 23, 2025
8 checks passed
@ordabayevy deleted the sf-predictwriter-gzip-threadpoolexecutor branch on January 23, 2025 at 17:37
Linked issue: PredictionWriter: gzip csvs and run in background thread (#285)