torch.compile neighbors without graph breaks #305
Merged
PyTorch introduced a new API to handle extensions; it is "documented" here: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit
It makes it possible to write meta registrations for C++ extensions, which I could not do before. With a meta registration, torch.compile is able to understand custom operations. A meta registration is an implementation of the operator for the "meta" device (akin to CPU or CUDA), in which tensors only have shapes and carry no data; during tracing these are referred to as FakeTensor.
PyTorch uses it to gather information about the input/output shapes of an operator for compilation purposes.
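To make the idea of a meta registration concrete, here is a minimal sketch. It is written in Python with the long-standing `torch.library.Library` interface rather than the C++ extension API this PR uses, and the namespace `sketch_ext` and the operator `pair_distances` are hypothetical names invented for the illustration; they are not part of torchmd-net.

```python
import torch

# Hypothetical namespace and operator, defined only for this illustration.
lib = torch.library.Library("sketch_ext", "DEF")
lib.define("pair_distances(Tensor pos) -> Tensor")


def pair_distances_cpu(pos):
    # Real implementation: all pairwise distances, shape (N, N).
    return torch.cdist(pos, pos)


def pair_distances_meta(pos):
    # Meta implementation: no computation at all, it only reports the
    # shape, dtype and device the real kernel would produce.
    n = pos.shape[0]
    return pos.new_empty((n, n))


lib.impl("pair_distances", pair_distances_cpu, "CPU")
lib.impl("pair_distances", pair_distances_meta, "Meta")

# A tensor on the "meta" device has a shape but no storage, so this call
# dispatches to the meta implementation and runs no real kernel.
meta_out = torch.ops.sketch_ext.pair_distances(torch.empty(5, 3, device="meta"))
cpu_out = torch.ops.sketch_ext.pair_distances(torch.randn(5, 3))
```

Here `meta_out` lives on the meta device and only carries the output shape, which is the information a compiler needs to reason about the operator without executing it.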
The meta registration makes this code possible:
Prior to this PR, torch.compile had to be instructed to exclude the neighbor extension from the operation graph:
torchmd-net/torchmdnet/extensions/__init__.py
Lines 116 to 118 in fae79bd
So it could not be compiled with fullgraph=True.
The new API is available starting with PyTorch 2.2.1, which is not yet in conda-forge. I made it so that the current behavior is unchanged for versions prior to it.
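A version gate of this kind could be sketched as follows. This is not the check used in the PR; the helper names `parse_release` and `has_new_extension_api` are hypothetical, and the only value taken from the description above is the 2.2.1 threshold.

```python
def parse_release(version: str) -> tuple:
    """Turn a version string such as '2.2.1+cu121' into (2, 2, 1)."""
    core = version.split("+")[0]  # drop a local build suffix like "+cu121"
    parts = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at pre-release tags such as "0a0"
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # pad short versions such as "2.2"
    return tuple(parts)


def has_new_extension_api(version: str, minimum=(2, 2, 1)) -> bool:
    # Hypothetical helper: True when the installed version is new enough.
    return parse_release(version) >= minimum
```

In practice the string passed in would be `torch.__version__`, and the caller would fall back to the previous behavior whenever the function returns `False`.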
Still, torch.compile is not able to handle code like this, in which a particular item from a tensor is accessed.
It can still be compiled, just not with fullgraph=True.
The general rule is: "if you can capture it into a CUDA graph, you can torch.compile it".
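As an illustration of the item-access limitation, here is a small sketch. The function `first_neighbor_count` is a hypothetical example, not code from this repository, and the exact behavior of torch.compile on it depends on the PyTorch version and its configuration.

```python
import torch


def first_neighbor_count(counts: torch.Tensor) -> torch.Tensor:
    # .item() copies one value out of the tensor into Python. The output
    # shape then depends on tensor *data*, not only on input shapes.
    n = int(counts[0].item())
    return torch.zeros(n)


def compile_variants():
    # Default mode: the compiler may split the function at the .item()
    # call and run that part eagerly.
    relaxed = torch.compile(first_neighbor_count)
    # fullgraph=True: requests a single graph, so a spot that cannot be
    # traced is reported as an error instead of being split off.
    strict = torch.compile(first_neighbor_count, fullgraph=True)
    return relaxed, strict


# Eager execution works regardless of compilation settings.
eager_out = first_neighbor_count(torch.tensor([3, 1]))
```

The same pattern is a problem for CUDA graph capture, since reading a value back to the host requires a synchronization, which matches the rule of thumb quoted above.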