Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.compile neighbors without graph breaks #305

Merged
merged 9 commits into from
Apr 15, 2024

Conversation

RaulPPelaez
Copy link
Collaborator

Pytorch introduced a new API to handle extensions, it is "documented" here: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit

It makes it possible to write meta registrations for C++ extensions, which I could not make before. With a meta registration torch.compile is able to understand custom operations. A meta registration is an implementation of the operator for the "meta" device (akin to CPU or CUDA), in which tensors only have shapes and are refered to as FakeTensor.
It is used by pytorch to gather information about the input/output shapes of an operator for compilation purposes.

Makes this code possible:

    example_pos = 100 * torch.rand(
        50, 3, requires_grad=True, dtype=dtype, device=device
    )
    model = OptimizedDistance(
        return_vecs=True,
        loop=True,
        max_num_pairs=-50,
        include_transpose=True,
        resize_to_fit=False,
        check_errors=False,
    ).to(device)
    for _ in range(25):
        model(example_pos)
    edge_index, edge_vec, edge_distance = model(example_pos)
    model = torch.compile(
        model,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
    )
    edge_index, edge_vec, edge_distance = model(example_pos)

Prior to this PR torch.compile had to be instructed to exclude the nieghbor extension from the operation graph:

if int(torch.__version__.split(".")[0]) >= 2:
import torch._dynamo as dynamo
dynamo.disallow_in_graph(torch.ops.torchmdnet_extensions.get_neighbor_pairs)

So it could not be compiled with fullgraph=True.

The new API starts at version 2.2.1, which is not yet in conda-forge. I made it so that the current behavior is unchanged for versions prior to it.

Still compile is not able to handle code like this, in which a particular item from a tensor is accessed.

        if self.check_errors:
            assert (
                num_pairs[0] <= max_pairs
            ), f"Found num_pairs({num_pairs[0]}) > max_num_pairs({max_pairs})"

It can still be compiled, just not with fullgraph=True.
The general rule being "if you can capture it into a CUDA graph you can torch.compile it"

Comment on lines +165 to +168
static auto fwd =
torch::Dispatcher::singleton()
.findSchemaOrThrow("torchmdnet_extensions::get_neighbor_pairs_fwd", "")
.typed<decltype(forward_impl)>();
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@peastman, I think we could use this code here to allow any pytorch model to be used in OpenMM-Torch. Such as torch.compile, torch.jit.trace or even models not compatible with TorchScript.
The code above allows to call any pytorch extension from C++, regardless of where or when the extension was registered.
We could have a function that registers the user model as an Autograd extension python-side and simply sends TorchForce the name of the extension.
Serialization would not be possible with torch.save though, unless we use something like Pickle in those instances.
If we solve serialization, this would decouple OpenMM-Torch from TorchScript entirely.

@RaulPPelaez RaulPPelaez marked this pull request as ready for review March 14, 2024 12:06
@RaulPPelaez RaulPPelaez changed the title Compile neighbors torch.compile neighbors without graph breaks Mar 14, 2024
@RaulPPelaez RaulPPelaez merged commit 72d6e8e into torchmd:main Apr 15, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant