Add check_errors option #253

Merged: 4 commits, Jan 18, 2024
11 changes: 11 additions & 0 deletions docs/source/usage.rst
@@ -62,6 +62,16 @@ These requirements correspond to a particular rotation of the system and reduced

.. note:: The box defined by the vectors :math:`\vec{a} = (L_x, 0, 0)`, :math:`\vec{b} = (0, L_y, 0)`, and :math:`\vec{c} = (0, 0, L_z)` corresponds to a rectangular box. In this case, the input option in the :ref:`configuration file <configuration-file>` would be ``box-vecs: [[L_x, 0, 0], [0, L_y, 0], [0, 0, L_z]]``.


CUDA Graphs
============

With the right options, TensorNet can be captured into a `CUDA graph <https://developer.nvidia.com/blog/cuda-graphs/>`_, which can dramatically increase performance during inference. Because training involves dynamically shaped tensors, CUDA graphs are not a practical option for training in most cases.

For TensorNet to be CUDA-graph compatible, ``check_errors`` must be ``False`` and ``static_shapes`` must be ``True``. Manually capturing a piece of code can be challenging; instead, to take advantage of CUDA graphs you can use :py:mod:`torchmdnet.calculators.External`, which helps integrate a TorchMD-Net model into another code, or `OpenMM-Torch <https://github.com/openmm/openmm-torch>`_ if you are using OpenMM.
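
If you do want to capture the model yourself, the sketch below is only a minimal illustration: it builds a bare ``TensorNet`` representation model with these two options (all other hyperparameters are left at their defaults, which is an assumption rather than a recommendation), feeds it placeholder inputs, and captures a single forward pass with ``torch.cuda.CUDAGraph`` following the usual PyTorch warm-up recipe.

.. code-block:: python

    import torch
    from torchmdnet.models.tensornet import TensorNet

    # CUDA-graph-friendly options; remaining hyperparameters use their defaults.
    model = TensorNet(check_errors=False, static_shapes=True).to("cuda").eval()

    # Static placeholder inputs; with static shapes these sizes must stay fixed
    # between graph replays (update the buffers in place instead of reallocating).
    z = torch.ones(32, dtype=torch.long, device="cuda")
    pos = torch.randn(32, 3, device="cuda")
    batch = torch.zeros(32, dtype=torch.long, device="cuda")

    # Warm up on a side stream before capturing, as recommended by PyTorch.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(3):
            model(z, pos, batch)
    torch.cuda.current_stream().wait_stream(stream)

    # Capture one forward pass and replay it using the same static buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        x, vec, z_out, pos_out, batch_out = model(z, pos, batch)
    graph.replay()

Keep in mind that ``check_errors=False`` disables the check that ``max_num_neighbors`` is large enough to hold all neighbor pairs, so that limit has to be chosen generously for the system at hand.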



Multi-Node Training
===================

@@ -85,6 +95,7 @@ In order to train models on multiple nodes some environment variables have to be
- Due to the way PyTorch Lightning calculates the number of required DDP processes, all nodes must use the same number of GPUs. Otherwise, training will either not start or will crash.
- We observe a 50x decrease in performance when mixing nodes with different GPU architectures (tested with RTX 2080 Ti and RTX 3090).


Developer Guide
---------------

2 changes: 2 additions & 0 deletions tests/test_model.py
@@ -107,6 +107,8 @@ def test_cuda_graph_compatible(model_name):
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_error": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
36 changes: 28 additions & 8 deletions torchmdnet/models/model.py
@@ -32,6 +32,10 @@ def create_model(args, prior_model=None, mean=None, std=None):
dtype = dtype_mapping[args["precision"]]
if "box_vecs" not in args:
args["box_vecs"] = None
if "check_errors" not in args:
args["check_errors"] = True
if "static_shapes" not in args:
args["static_shapes"] = False
shared_args = dict(
hidden_channels=args["embedding_dimension"],
num_layers=args["num_layers"],
@@ -42,8 +46,11 @@ def create_model(args, prior_model=None, mean=None, std=None):
cutoff_lower=args["cutoff_lower"],
cutoff_upper=args["cutoff_upper"],
max_z=args["max_z"],
check_errors=args["check_errors"],
max_num_neighbors=args["max_num_neighbors"],
box_vecs=torch.tensor(args["box_vecs"], dtype=dtype) if args["box_vecs"] is not None else None,
box_vecs=torch.tensor(args["box_vecs"], dtype=dtype)
if args["box_vecs"] is not None
else None,
dtype=dtype,
)

@@ -87,6 +94,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
is_equivariant = False
representation_model = TensorNet(
equivariance_invariance_group=args["equivariance_invariance_group"],
static_shapes=args["static_shapes"],
**shared_args,
)
else:
@@ -153,11 +161,21 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
# The following are for backward compatibility with models created when atomref was
# the only supported prior.
if "prior_model.initial_atomref" in state_dict:
warnings.warn(
"prior_model.initial_atomref is deprecated and will be removed in a future version. Use prior_model.0.initial_atomref instead.",
category=DeprecationWarning,
stacklevel=2,
)
state_dict["prior_model.0.initial_atomref"] = state_dict[
"prior_model.initial_atomref"
]
del state_dict["prior_model.initial_atomref"]
if "prior_model.atomref.weight" in state_dict:
warnings.warn(
"prior_model.atomref.weight is deprecated and will be removed in a future version. Use prior_model.0.atomref.weight instead.",
category=DeprecationWarning,
stacklevel=2,
)
state_dict["prior_model.0.atomref.weight"] = state_dict[
"prior_model.atomref.weight"
]
@@ -201,7 +219,7 @@ def create_prior_models(args, dataset=None):


class TorchMD_Net(nn.Module):
""" The main TorchMD-Net model.
"""The main TorchMD-Net model.

The TorchMD_Net class combines a given representation model (such as the equivariant transformer),
an output model (such as the scalar output module), and a prior model (such as the atomref prior).
@@ -311,15 +329,15 @@ def forward(

Args:
z (Tensor): Atomic numbers of the atoms in the molecule. Shape: (N,).
pos (Tensor): Atomic positions in the molecule. Shape: (N, 3).
batch (Tensor, optional): Batch indices for the atoms in the molecule. Shape: (N,).
box (Tensor, optional): Box vectors. Shape (3, 3).
The vectors defining the periodic box. This must have shape `(3, 3)`,
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.

Returns:
Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise.
@@ -330,7 +348,9 @@ def forward(
if self.derivative:
pos.requires_grad_(True)
# run the potentially wrapped representation model
x, v, z, pos, batch = self.representation_model(z, pos, batch, box=box, q=q, s=s)
x, v, z, pos, batch = self.representation_model(
z, pos, batch, box=box, q=q, s=s
)
# apply the output network
x = self.output_model.pre_reduce(x, v, z, pos, batch)

7 changes: 5 additions & 2 deletions torchmdnet/models/tensornet.py
@@ -116,7 +116,9 @@ class TensorNet(nn.Module):
If this is omitted, periodic boundary conditions are not applied.
(default: :obj:`None`)
static_shapes (bool, optional): Whether to enforce static shapes.
Makes the model CUDA-graph compatible.
Makes the model CUDA-graph compatible if check_errors is set to False.
(default: :obj:`True`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
"""

@@ -134,6 +136,7 @@ def __init__(
max_z=128,
equivariance_invariance_group="O(3)",
static_shapes=True,
check_errors=True,
dtype=torch.float32,
box_vecs=None,
):
@@ -202,7 +205,7 @@ def __init__(
max_num_pairs=-max_num_neighbors,
return_vecs=True,
loop=True,
check_errors=False,
check_errors=check_errors,
resize_to_fit=not self.static_shapes,
box=box_vecs,
long_edge_index=True,
4 changes: 4 additions & 0 deletions torchmdnet/models/torchmd_et.py
@@ -75,6 +75,8 @@ class TorchMD_ET(nn.Module):
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
(default: :obj:`None`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)

"""

@@ -94,6 +96,7 @@ def __init__(
cutoff_upper=5.0,
max_z=100,
max_num_neighbors=32,
check_errors=True,
box_vecs=None,
dtype=torch.float32,
):
@@ -140,6 +143,7 @@ def __init__(
loop=True,
box=box_vecs,
long_edge_index=True,
check_errors=check_errors,
)
self.distance_expansion = rbf_class_mapping[rbf_type](
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
4 changes: 4 additions & 0 deletions torchmdnet/models/torchmd_gn.py
@@ -84,6 +84,8 @@ class TorchMD_GN(nn.Module):
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
(default: :obj:`None`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)

"""

@@ -101,6 +103,7 @@ def __init__(
cutoff_upper=5.0,
max_z=100,
max_num_neighbors=32,
check_errors=True,
aggr="add",
dtype=torch.float32,
box_vecs=None,
@@ -144,6 +147,7 @@ def __init__(
max_num_pairs=-max_num_neighbors,
box=box_vecs,
long_edge_index=True,
check_errors=check_errors,
)

self.distance_expansion = rbf_class_mapping[rbf_type](
6 changes: 5 additions & 1 deletion torchmdnet/models/torchmd_t.py
@@ -74,6 +74,8 @@ class TorchMD_T(nn.Module):
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
(default: :obj:`None`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)

"""

Expand All @@ -91,6 +93,7 @@ def __init__(
distance_influence="both",
cutoff_lower=0.0,
cutoff_upper=5.0,
check_errors=True,
max_z=100,
max_num_neighbors=32,
dtype=torch.float,
@@ -133,7 +136,8 @@ def __init__(
max_num_pairs=-max_num_neighbors,
loop=True,
box=box_vecs,
long_edge_index=True
long_edge_index=True,
check_errors=check_errors,
)

self.distance_expansion = rbf_class_mapping[rbf_type](
3 changes: 3 additions & 0 deletions torchmdnet/scripts/train.py
@@ -98,7 +98,10 @@ def get_argparse():
`a[1] = a[2] = b[2] = 0`;`a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff`;`a[0] >= 2*b[0]`;`a[0] >= 2*c[0]`;`b[1] >= 2*c[1]`;
These requirements correspond to a particular rotation of the system and reduced form of the vectors, as well as the requirement that the cutoff be no larger than half the box width.
Example: [[1,0,0],[0,1,0],[0,0,1]]""")
parser.add_argument('--static_shapes', type=bool, default=False, help='If true, TensorNet will use statically shaped tensors for the network, making it capturable into a CUDA graph. In some situations static shapes can lead to a speedup, but they increase memory usage.')

# other args
parser.add_argument('--check_errors', type=bool, default=True, help='Check whether max_num_neighbors is large enough to contain all neighbors. This check is incompatible with CUDA graphs.')
parser.add_argument('--derivative', default=False, type=bool, help='If true, take the derivative of the prediction w.r.t coordinates')
parser.add_argument('--cutoff-lower', type=float, default=0.0, help='Lower cutoff in model')
parser.add_argument('--cutoff-upper', type=float, default=5.0, help='Upper cutoff in model')