From 4500e60bbceaacf44b37b7eb86d54bb9eeeb1158 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 25 Oct 2023 12:35:08 +0200 Subject: [PATCH 1/7] Update README with conda-forge installation --- README.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e5e38fa4a..0dbe4d28f 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,14 @@ TorchMD-NET provides state-of-the-art neural networks potentials (NNPs) and a me - [TensorNet](https://arxiv.org/abs/2306.06482) -## Installation +## Installation +TorchMD-Net is available in [conda-forge](https://conda-forge.org/) and can be installed with: +```shell +mamba install torchmd-net +``` +We recommend using [Mamba](https://github.com/conda-forge/miniforge/#mambaforge) instead of conda. + +### Install from source 1. Clone the repository: ```shell @@ -21,7 +28,7 @@ TorchMD-NET provides state-of-the-art neural networks potentials (NNPs) and a me cd torchmd-net ``` -2. Install [Mambaforge](https://github.com/conda-forge/miniforge/#mambaforge). We recommend to use `mamba` rather than `conda`. +2. Install the dependencies in environment.yml. You can do it via pip, but we recommend [Mambaforge](https://github.com/conda-forge/miniforge/#mambaforge) instead. 3. Create an environment and activate it: ```shell From 9eada2092c4bfe3642be29292ad4be1e9fafb80f Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 25 Oct 2023 12:49:44 +0200 Subject: [PATCH 2/7] Add note about CUDA --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index 0dbe4d28f..ed9512f30 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,23 @@ We recommend using [Mamba](https://github.com/conda-forge/miniforge/#mambaforge) This will install TorchMD-NET in editable mode, so that changes to the source code are immediately available. Besides making all python utilities available environment-wide, this will also install the `torchmd-train` command line utility. + +#### CUDA enabled installation + +Besides the dependencies listed in the environment file, you will also need the CUDA `nvcc` compiler suite to build TorchMD-Net. +If your system lacks nvcc you may install it via conda-forge: + ```shell + mamba install cudatoolkit-dev + ``` +Or from the nvidia channel: +```shell +mamba install -c nvidia cuda-nvcc cuda-cudart-dev cuda-libraries-dev +``` +Make sure you install a major version compatible with your torch installation, which you can check with: +```shell +python -c "import torch; print(torch.version.cuda)" +``` + ## Usage Specifying training arguments can either be done via a configuration yaml file or through command line arguments directly. Several examples of architectural and training specifications for some models and datasets can be found in [examples/](https://github.com/torchmd/torchmd-net/tree/main/examples). Note that if a parameter is present both in the yaml file and the command line, the command line version takes precedence. GPUs can be selected by setting the `CUDA_VISIBLE_DEVICES` environment variable. Otherwise, the argument `--ngpus` can be used to select the number of GPUs to train on (-1, the default, uses all available GPUs or the ones specified in `CUDA_VISIBLE_DEVICES`). Keep in mind that the [GPU ID reported by nvidia-smi might not be the same as the one `CUDA_VISIBLE_DEVICES` uses](https://stackoverflow.com/questions/26123252/inconsistency-of-ids-between-nvidia-smi-l-and-cudevicegetname). From af51c589ed5f86703b1f1299d38f9999a8ddd569 Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Fri, 3 Nov 2023 12:25:33 +0100 Subject: [PATCH 3/7] Add support for charged molecules in TensorNet (total charge) (#238) * Update tensornet.py for support of total charge q * Update tensornet.py for q support * fix * fix * fix comment * initialize zero charge as tensor, move charge broadcasting to interaction module * add clarification comment * trying fix * try fix --- torchmdnet/models/tensornet.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 64f9b9d36..5c3158e80 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -212,10 +212,16 @@ def forward( edge_vec is not None ), "Distance module did not return directional information" # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom + # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q + if q is None: + q = torch.zeros_like(z, device=z.device, dtype=z.dtype) + else: + q = q[batch] zp = z if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) + q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0) # I trick the model into thinking that the masked edges pertain to the extra atom # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs edge_index = edge_index.masked_fill(mask, z.shape[0]) @@ -228,7 +234,7 @@ def forward( edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1) X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr) for layer in self.layers: - X = layer(X, edge_index, edge_weight, edge_attr) + X = layer(X, edge_index, edge_weight, edge_attr, q) I, A, S = decompose_tensor(X) x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1) x = self.out_norm(x) @@ -379,7 +385,7 @@ def reset_parameters(self): linear.reset_parameters() def forward( - self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor + self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, q: Tensor ) -> Tensor: C = self.cutoff(edge_weight) @@ -401,7 +407,7 @@ def forward( if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) B = torch.matmul(Y, msg) - I, A, S = decompose_tensor(A + B) + I, A, S = decompose_tensor((1 + 0.1*q[...,None,None,None])*(A + B)) if self.equivariance_invariance_group == "SO(3)": B = torch.matmul(Y, msg) I, A, S = decompose_tensor(2 * B) @@ -411,5 +417,5 @@ def forward( A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) dX = I + A + S - X = X + dX + torch.matrix_power(dX, 2) + X = X + dX + (1 + 0.1*q[...,None,None,None]) * torch.matrix_power(dX, 2) return X From a67ac80299b153390737d8bf8e16670aaf234139 Mon Sep 17 00:00:00 2001 From: Raul Date: Fri, 3 Nov 2023 13:05:06 +0100 Subject: [PATCH 4/7] Remove dependency on torch_scatter and torch_sparse (#229) * Remove torch_scatter dependency in tensornet * Remove dependency on torch_scatter * Fix import in test * Remove torch_scatter dependency Manually implement message passing in all architectures, removing dependency on pytorch_geometric * Remove spurious comment * Update environment.yaml * Reintroduce cluster * Fix import * Fix import * Merge remote-tracking branch 'origin/main' into remove_scatter_sparse * Remove import * Remove import * Blacken --- environment.yml | 2 - tests/test_cfconv.py | 2 +- tests/test_model.py | 6 +- tests/test_priors.py | 2 +- torchmdnet/data.py | 2 +- torchmdnet/models/model.py | 1 - torchmdnet/models/output_modules.py | 58 ++++++++----- torchmdnet/models/tensornet.py | 109 ++++++++++++++---------- torchmdnet/models/torchmd_et.py | 60 +++++++++---- torchmdnet/models/torchmd_gn.py | 59 ++++++++----- torchmdnet/models/torchmd_t.py | 48 +++++++---- torchmdnet/models/utils.py | 125 +++++++++++++++++++++++----- torchmdnet/priors/coulomb.py | 3 +- torchmdnet/priors/d2.py | 13 ++- torchmdnet/priors/zbl.py | 60 +++++++++---- 15 files changed, 380 insertions(+), 170 deletions(-) diff --git a/environment.yml b/environment.yml index 5814ed3ba..4dec17b9f 100644 --- a/environment.yml +++ b/environment.yml @@ -8,8 +8,6 @@ dependencies: - pip - pytorch==2.0.* - pytorch_geometric==2.3.1 - - pytorch_scatter==2.1.1 - - pytorch_sparse==0.6.17 - lightning==2.0.8 - pydantic<2 - torchmetrics==0.11.4 diff --git a/tests/test_cfconv.py b/tests/test_cfconv.py index 80c12bc03..0081afa3e 100644 --- a/tests/test_cfconv.py +++ b/tests/test_cfconv.py @@ -56,7 +56,7 @@ def test_cfconv(device, num_atoms, num_filters, num_rbfs, cutoff_upper): # Compute with the non-optimized CFConv edge_index, edge_weight, _ = dist(pos, batch=None) edge_attr = rbf(edge_weight) - ref_output = ref_conv(input, edge_index, edge_weight, edge_attr) + ref_output = ref_conv(input, edge_index, edge_weight, edge_attr, pos.shape[0]) ref_total = pt.sum(ref_output) ref_total.backward() ref_grad = pos.grad.clone() diff --git a/tests/test_model.py b/tests/test_model.py index 209339b65..ed3a409d4 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -181,10 +181,10 @@ def test_forward_output(model_name, output_model, overwrite_reference=False): ), f"Set new reference outputs for {model_name} with output model {output_model}." # compare actual ouput with reference - torch.testing.assert_allclose(pred, expected[model_name][output_model]["pred"]) + torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5) if derivative: - torch.testing.assert_allclose( - deriv, expected[model_name][output_model]["deriv"] + torch.testing.assert_close( + deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5 ) diff --git a/tests/test_priors.py b/tests/test_priors.py index 47b838381..260391c82 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -6,7 +6,7 @@ from torchmdnet.models.model import create_model, create_prior_models from torchmdnet.module import LNNP from torchmdnet.priors import Atomref, D2, ZBL, Coulomb -from torch_scatter import scatter +from torchmdnet.models.utils import scatter from utils import load_example_args, create_example_batch, DummyDataset from os.path import dirname, join import tempfile diff --git a/torchmdnet/data.py b/torchmdnet/data.py index 702df4f46..19c812f89 100644 --- a/torchmdnet/data.py +++ b/torchmdnet/data.py @@ -8,7 +8,7 @@ from torchmdnet import datasets from torch_geometric.data import Dataset from torchmdnet.utils import make_splits, MissingEnergyException -from torch_scatter import scatter +from torchmdnet.models.utils import scatter from torchmdnet.models.utils import dtype_mapping diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 16937dd80..7ee44e361 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -3,7 +3,6 @@ import torch from torch.autograd import grad from torch import nn, Tensor -from torch_scatter import scatter from torchmdnet.models import output_modules from torchmdnet.models.wrappers import AtomFilter from torchmdnet.models.utils import dtype_mapping diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py index f60c9baf4..2283ef279 100644 --- a/torchmdnet/models/output_modules.py +++ b/torchmdnet/models/output_modules.py @@ -1,13 +1,12 @@ from abc import abstractmethod, ABCMeta -from torch_scatter import scatter from typing import Optional -from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock -from torchmdnet.utils import atomic_masses -from torchmdnet.extensions import is_current_stream_capturing -from torch_scatter import scatter import torch from torch import nn +from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock, scatter +from torchmdnet.utils import atomic_masses +from torchmdnet.extensions import is_current_stream_capturing from warnings import warn + __all__ = ["Scalar", "DipoleMoment", "ElectronicSpatialExtent"] @@ -26,10 +25,7 @@ def pre_reduce(self, x, v, z, pos, batch): return def reduce(self, x, batch): - is_capturing = ( - x.is_cuda - and is_current_stream_capturing() - ) + is_capturing = x.is_cuda and is_current_stream_capturing() if not x.is_cuda or not is_capturing: self.dim_size = int(batch.max().item() + 1) if is_capturing: @@ -54,7 +50,7 @@ def __init__( activation="silu", allow_prior_model=True, reduce_op="sum", - dtype=torch.float + dtype=torch.float, ): super(Scalar, self).__init__( allow_prior_model=allow_prior_model, reduce_op=reduce_op @@ -85,7 +81,7 @@ def __init__( activation="silu", allow_prior_model=True, reduce_op="sum", - dtype=torch.float + dtype=torch.float, ): super(EquivariantScalar, self).__init__( allow_prior_model=allow_prior_model, reduce_op=reduce_op @@ -97,9 +93,11 @@ def __init__( hidden_channels // 2, activation=activation, scalar_activation=True, - dtype=dtype + dtype=dtype, + ), + GatedEquivariantBlock( + hidden_channels // 2, 1, activation=activation, dtype=dtype ), - GatedEquivariantBlock(hidden_channels // 2, 1, activation=activation, dtype=dtype), ] ) @@ -117,9 +115,15 @@ def pre_reduce(self, x, v, z, pos, batch): class DipoleMoment(Scalar): - def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float): + def __init__( + self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float + ): super(DipoleMoment, self).__init__( - hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op, dtype=dtype + hidden_channels, + activation, + allow_prior_model=False, + reduce_op=reduce_op, + dtype=dtype, ) atomic_mass = torch.from_numpy(atomic_masses).to(dtype) self.register_buffer("atomic_mass", atomic_mass) @@ -138,9 +142,15 @@ def post_reduce(self, x): class EquivariantDipoleMoment(EquivariantScalar): - def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float): + def __init__( + self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float + ): super(EquivariantDipoleMoment, self).__init__( - hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op, dtype=dtype + hidden_channels, + activation, + allow_prior_model=False, + reduce_op=reduce_op, + dtype=dtype, ) atomic_mass = torch.from_numpy(atomic_masses).to(dtype) self.register_buffer("atomic_mass", atomic_mass) @@ -160,7 +170,9 @@ def post_reduce(self, x): class ElectronicSpatialExtent(OutputModel): - def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float): + def __init__( + self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float + ): super(ElectronicSpatialExtent, self).__init__( allow_prior_model=False, reduce_op=reduce_op ) @@ -197,9 +209,15 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent): class EquivariantVectorOutput(EquivariantScalar): - def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float): + def __init__( + self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float + ): super(EquivariantVectorOutput, self).__init__( - hidden_channels, activation, allow_prior_model=False, reduce_op="sum", dtype=dtype + hidden_channels, + activation, + allow_prior_model=False, + reduce_op="sum", + dtype=dtype, ) def pre_reduce(self, x, v, z, pos, batch): diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 5c3158e80..b277a8186 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -1,8 +1,6 @@ import torch -import numpy as np from typing import Optional, Tuple from torch import Tensor, nn -from torch_scatter import scatter from torchmdnet.models.utils import ( CosineCutoff, OptimizedDistance, @@ -10,7 +8,7 @@ act_class_mapping, ) -torch.set_float32_matmul_precision('high') +torch.set_float32_matmul_precision("high") torch.backends.cuda.matmul.allow_tf32 = True # Creates a skew-symmetric tensor from a vector def vector_to_skewtensor(vector): @@ -54,14 +52,6 @@ def decompose_tensor(tensor): return I, A, S -# Modifies tensor by multiplying invariant features to irreducible components -def new_radial_tensor(I, A, S, f_I, f_A, f_S): - I = f_I[..., None, None] * I - A = f_A[..., None, None] * A - S = f_S[..., None, None] * S - return I, A, S - - # Computes Frobenius norm def tensor_norm(tensor): return (tensor**2).sum((-2, -1)) @@ -295,6 +285,35 @@ def reset_parameters(self): linear.reset_parameters() self.init_norm.reset_parameters() + def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor: + Z = self.emb(z) + Zij = self.emb2( + Z.index_select(0, edge_index.t().reshape(-1)).view( + -1, self.hidden_channels * 2 + ) + )[..., None, None] + return Zij + + def _get_tensor_messages( + self, Zij: Tensor, edge_weight: Tensor, edge_vec_norm: Tensor, edge_attr: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + C = self.cutoff(edge_weight).reshape(-1, 1, 1, 1) * Zij + eye = torch.eye(3, 3, device=edge_vec_norm.device, dtype=edge_vec_norm.dtype)[ + None, None, ... + ] + Iij = self.distance_proj1(edge_attr)[..., None, None] * C * eye + Aij = ( + self.distance_proj2(edge_attr)[..., None, None] + * C + * vector_to_skewtensor(edge_vec_norm)[..., None, :, :] + ) + Sij = ( + self.distance_proj3(edge_attr)[..., None, None] + * C + * vector_to_symtensor(edge_vec_norm)[..., None, :, :] + ) + return Iij, Aij, Sij + def forward( self, z: Tensor, @@ -303,43 +322,43 @@ def forward( edge_vec_norm: Tensor, edge_attr: Tensor, ) -> Tensor: - C = self.cutoff(edge_weight) - W1 = self.distance_proj1(edge_attr) * C.view(-1, 1) - W2 = self.distance_proj2(edge_attr) * C.view(-1, 1) - W3 = self.distance_proj3(edge_attr) * C.view(-1, 1) - Iij, Aij, Sij = new_radial_tensor( - torch.eye(3, 3, device=edge_vec_norm.device, dtype=edge_vec_norm.dtype)[ - None, None, :, : - ], - vector_to_skewtensor(edge_vec_norm)[..., None, :, :], - vector_to_symtensor(edge_vec_norm)[..., None, :, :], - W1, - W2, - W3, + Zij = self._get_atomic_number_message(z, edge_index) + Iij, Aij, Sij = self._get_tensor_messages( + Zij, edge_weight, edge_vec_norm, edge_attr ) - Z = self.emb(z) - Zij = self.emb2( - Z.index_select(0, edge_index.t().reshape(-1)).view(-1, self.hidden_channels * 2) - )[..., None, None] - I = scatter(Zij*Iij, edge_index[0], dim=0, dim_size=z.shape[0]) - A = scatter(Zij*Aij, edge_index[0], dim=0, dim_size=z.shape[0]) - S = scatter(Zij*Sij, edge_index[0], dim=0, dim_size=z.shape[0]) + source = torch.zeros( + z.shape[0], self.hidden_channels, 3, 3, device=z.device, dtype=Iij.dtype + ) + I = source.index_add(dim=0, index=edge_index[0], source=Iij) + A = source.index_add(dim=0, index=edge_index[0], source=Aij) + S = source.index_add(dim=0, index=edge_index[0], source=Sij) norm = self.init_norm(tensor_norm(I + A + S)) - I = self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) for linear_scalar in self.linears_scalar: norm = self.act(linear_scalar(norm)) - norm = norm.reshape(norm.shape[0], self.hidden_channels, 3) - I, A, S = new_radial_tensor(I, A, S, norm[..., 0], norm[..., 1], norm[..., 2]) + norm = norm.reshape(-1, self.hidden_channels, 3) + I = ( + self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + * norm[..., 0, None, None] + ) + A = ( + self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + * norm[..., 1, None, None] + ) + S = ( + self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + * norm[..., 2, None, None] + ) X = I + A + S - return X -def tensor_message_passing(edge_index: Tensor, factor: Tensor, tensor: Tensor, natoms: int) -> Tensor: +def tensor_message_passing( + edge_index: Tensor, factor: Tensor, tensor: Tensor, natoms: int +) -> Tensor: msg = factor * tensor.index_select(0, edge_index[1]) - tensor_m = scatter(msg, edge_index[0], dim=0, dim_size=natoms) + shape = (natoms, tensor.shape[1], tensor.shape[2], tensor.shape[3]) + tensor_m = torch.zeros(*shape, device=tensor.device, dtype=tensor.dtype) + tensor_m = tensor_m.index_add(0, edge_index[0], msg) return tensor_m @@ -400,9 +419,15 @@ def forward( A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) Y = I + A + S - Im = tensor_message_passing(edge_index, edge_attr[..., 0, None, None], I, X.shape[0]) - Am = tensor_message_passing(edge_index, edge_attr[..., 1, None, None], A, X.shape[0]) - Sm = tensor_message_passing(edge_index, edge_attr[..., 2, None, None], S, X.shape[0]) + Im = tensor_message_passing( + edge_index, edge_attr[..., 0, None, None], I, X.shape[0] + ) + Am = tensor_message_passing( + edge_index, edge_attr[..., 1, None, None], A, X.shape[0] + ) + Sm = tensor_message_passing( + edge_index, edge_attr[..., 2, None, None], S, X.shape[0] + ) msg = Im + Am + Sm if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index a295d1302..8abca54ae 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -1,16 +1,16 @@ from typing import Optional, Tuple import torch from torch import Tensor, nn -from torch_geometric.nn import MessagePassing -from torch_scatter import scatter from torchmdnet.models.utils import ( NeighborEmbedding, CosineCutoff, OptimizedDistance, rbf_class_mapping, act_class_mapping, + scatter, ) + class TorchMD_ET(nn.Module): r"""The TorchMD equivariant Transformer architecture. @@ -109,7 +109,7 @@ def __init__( max_num_pairs=-max_num_neighbors, return_vecs=True, loop=True, - long_edge_index=True + long_edge_index=True, ) self.distance_expansion = rbf_class_mapping[rbf_type]( cutoff_lower, cutoff_upper, num_rbf, trainable_rbf @@ -117,7 +117,7 @@ def __init__( self.neighbor_embedding = ( NeighborEmbedding( hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z, dtype - ).jittable() + ) if neighbor_embedding else None ) @@ -134,7 +134,7 @@ def __init__( cutoff_lower, cutoff_upper, dtype, - ).jittable() + ) self.attention_layers.append(layer) self.out_norm = nn.LayerNorm(hidden_channels, dtype=dtype) @@ -150,7 +150,6 @@ def reset_parameters(self): attn.reset_parameters() self.out_norm.reset_parameters() - def forward( self, z: Tensor, @@ -205,7 +204,7 @@ def __repr__(self): ) -class EquivariantMultiHeadAttention(MessagePassing): +class EquivariantMultiHeadAttention(nn.Module): def __init__( self, hidden_channels, @@ -218,7 +217,7 @@ def __init__( cutoff_upper, dtype=torch.float32, ): - super(EquivariantMultiHeadAttention, self).__init__(aggr="add", node_dim=0) + super(EquivariantMultiHeadAttention, self).__init__() assert hidden_channels % num_heads == 0, ( f"The number of hidden channels ({hidden_channels}) " f"must be evenly divisible by the number of " @@ -239,7 +238,9 @@ def __init__( self.v_proj = nn.Linear(hidden_channels, hidden_channels * 3, dtype=dtype) self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3, dtype=dtype) - self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False, dtype=dtype) + self.vec_proj = nn.Linear( + hidden_channels, hidden_channels * 3, bias=False, dtype=dtype + ) self.dk_proj = None if distance_influence in ["keys", "both"]: @@ -289,8 +290,6 @@ def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij): if self.dv_proj is not None else None ) - - # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, d_ij: Tensor) x, vec = self.propagate( edge_index, q=q, @@ -301,7 +300,7 @@ def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij): dv=dv, r_ij=r_ij, d_ij=d_ij, - size=None, + dim_size=None, ) x = x.reshape(-1, self.hidden_channels) vec = vec.reshape(-1, 3, self.hidden_channels) @@ -311,7 +310,37 @@ def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij): dvec = vec3 * o1.unsqueeze(1) + vec return dx, dvec - def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij): + def propagate( + self, + edge_index: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, + vec: Tensor, + dk: Optional[Tensor], + dv: Optional[Tensor], + r_ij: Tensor, + d_ij: Tensor, + dim_size: Optional[int], + ) -> Tuple[Tensor, Tensor]: + q_i = q.index_select(0, edge_index[1]) + k_j = k.index_select(0, edge_index[0]) + v_j = v.index_select(0, edge_index[0]) + vec_j = vec.index_select(0, edge_index[0]) + x, vec = self.message(q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij) + return self.aggregate((x, vec), edge_index[1], dim_size=dim_size) + + def message( + self, + q_i: Tensor, + k_j: Tensor, + v_j: Tensor, + vec_j: Tensor, + dk: Optional[Tensor], + dv: Optional[Tensor], + r_ij: Tensor, + d_ij: Tensor, + ) -> Tuple[Tensor, Tensor]: # attention mechanism if dk is None: attn = (q_i * k_j).sum(dim=-1) @@ -338,12 +367,11 @@ def aggregate( self, features: Tuple[torch.Tensor, torch.Tensor], index: torch.Tensor, - ptr: Optional[torch.Tensor], dim_size: Optional[int], ) -> Tuple[torch.Tensor, torch.Tensor]: x, vec = features - x = scatter(x, index, dim=self.node_dim, dim_size=dim_size) - vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size) + x = scatter(x, index, dim=0, dim_size=dim_size) + vec = scatter(vec, index, dim=0, dim_size=dim_size) return x, vec def update( diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 3b7f7ec9d..5d8c5a75d 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -8,6 +8,7 @@ OptimizedDistance, rbf_class_mapping, act_class_mapping, + scatter, ) @@ -72,7 +73,7 @@ def __init__( max_z=100, max_num_neighbors=32, aggr="add", - dtype=torch.float32 + dtype=torch.float32, ): super(TorchMD_GN, self).__init__() @@ -115,8 +116,13 @@ def __init__( ) self.neighbor_embedding = ( NeighborEmbedding( - hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z, dtype=dtype - ).jittable() + hidden_channels, + num_rbf, + cutoff_lower, + cutoff_upper, + self.max_z, + dtype=dtype, + ) if neighbor_embedding else None ) @@ -131,7 +137,7 @@ def __init__( cutoff_lower, cutoff_upper, aggr=self.aggr, - dtype=dtype + dtype=dtype, ) self.interactions.append(block) @@ -163,7 +169,9 @@ def forward( x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr) for interaction in self.interactions: - x = x + interaction(x, edge_index, edge_weight, edge_attr) + x = x + interaction( + x, edge_index, edge_weight, edge_attr, n_atoms=z.shape[0] + ) return x, None, z, pos, batch @@ -194,7 +202,7 @@ def __init__( cutoff_lower, cutoff_upper, aggr="add", - dtype=torch.float32 + dtype=torch.float32, ): super(InteractionBlock, self).__init__() self.mlp = nn.Sequential( @@ -210,8 +218,8 @@ def __init__( cutoff_lower, cutoff_upper, aggr=aggr, - dtype=dtype - ).jittable() + dtype=dtype, + ) self.act = activation() self.lin = nn.Linear(hidden_channels, hidden_channels, dtype=dtype) @@ -226,14 +234,21 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.lin.weight) self.lin.bias.data.fill_(0) - def forward(self, x, edge_index, edge_weight, edge_attr): - x = self.conv(x, edge_index, edge_weight, edge_attr) + def forward( + self, + x: Tensor, + edge_index: Tensor, + edge_weight: Tensor, + edge_attr: Tensor, + n_atoms: Optional[int] = None, + ) -> Tensor: + x = self.conv(x, edge_index, edge_weight, edge_attr, n_atoms) x = self.act(x) x = self.lin(x) return x -class CFConv(MessagePassing): +class CFConv(nn.Module): def __init__( self, in_channels, @@ -243,14 +258,14 @@ def __init__( cutoff_lower, cutoff_upper, aggr="add", - dtype=torch.float32 + dtype=torch.float32, ): - super(CFConv, self).__init__(aggr=aggr) + super(CFConv, self).__init__() self.lin1 = nn.Linear(in_channels, num_filters, bias=False, dtype=dtype) self.lin2 = nn.Linear(num_filters, out_channels, bias=True, dtype=dtype) self.net = net self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) - + self.aggr = aggr self.reset_parameters() def reset_parameters(self): @@ -258,15 +273,19 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.lin2.weight) self.lin2.bias.data.fill_(0) - def forward(self, x, edge_index, edge_weight, edge_attr): + def forward( + self, + x: Tensor, + edge_index: Tensor, + edge_weight: Tensor, + edge_attr: Tensor, + n_atoms: Optional[int] = None, + ) -> Tensor: C = self.cutoff(edge_weight) W = self.net(edge_attr) * C.view(-1, 1) x = self.lin1(x) - # propagate_type: (x: Tensor, W: Tensor) - x = self.propagate(edge_index, x=x, W=W, size=None) + msg = W * x.index_select(0, edge_index[1]) + x = scatter(msg, edge_index[0], dim=0, dim_size=n_atoms, reduce=self.aggr) x = self.lin2(x) return x - - def message(self, x_j, W): - return x_j * W diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 8c93e1317..71d4d8f18 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -1,15 +1,15 @@ from typing import Optional, Tuple import torch from torch import Tensor, nn -from torch_geometric.nn import MessagePassing from torchmdnet.models.utils import ( NeighborEmbedding, CosineCutoff, OptimizedDistance, rbf_class_mapping, act_class_mapping, + scatter, ) -from torch_scatter import scatter + class TorchMD_T(nn.Module): r"""The TorchMD Transformer architecture. @@ -66,7 +66,7 @@ def __init__( cutoff_upper=5.0, max_z=100, max_num_neighbors=32, - dtype=torch.float + dtype=torch.float, ): super(TorchMD_T, self).__init__() @@ -107,8 +107,13 @@ def __init__( ) self.neighbor_embedding = ( NeighborEmbedding( - hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z, dtype=dtype - ).jittable() + hidden_channels, + num_rbf, + cutoff_lower, + cutoff_upper, + self.max_z, + dtype=dtype, + ) if neighbor_embedding else None ) @@ -125,7 +130,7 @@ def __init__( cutoff_lower, cutoff_upper, dtype=dtype, - ).jittable() + ) self.attention_layers.append(layer) self.out_norm = nn.LayerNorm(hidden_channels, dtype=dtype) @@ -159,7 +164,7 @@ def forward( x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr) for attn in self.attention_layers: - x = x + attn(x, edge_index, edge_weight, edge_attr) + x = x + attn(x, edge_index, edge_weight, edge_attr, z.shape[0]) x = self.out_norm(x) return x, None, z, pos, batch @@ -182,7 +187,7 @@ def __repr__(self): ) -class MultiHeadAttention(MessagePassing): +class MultiHeadAttention(nn.Module): def __init__( self, hidden_channels, @@ -195,7 +200,7 @@ def __init__( cutoff_upper, dtype=torch.float, ): - super(MultiHeadAttention, self).__init__(aggr="add", node_dim=0) + super(MultiHeadAttention, self).__init__() assert hidden_channels % num_heads == 0, ( f"The number of hidden channels ({hidden_channels}) " f"must be evenly divisible by the number of " @@ -243,7 +248,9 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.dv_proj.weight) self.dv_proj.bias.data.fill_(0) - def forward(self, x, edge_index, r_ij, f_ij): + def forward( + self, x: Tensor, edge_index: Tensor, r_ij: Tensor, f_ij: Tensor, n_atoms: int + ) -> Tensor: head_shape = (-1, self.num_heads, self.head_dim) x = self.layernorm(x) @@ -261,15 +268,24 @@ def forward(self, x, edge_index, r_ij, f_ij): if self.dv_proj is not None else None ) - - # propagate_type: (q: Tensor, k: Tensor, v: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor) - out = self.propagate( - edge_index, q=q, k=k, v=v, dk=dk, dv=dv, r_ij=r_ij, size=None - ) + msg = self.message(edge_index, q, k, v, dk, dv, r_ij) + out = scatter(msg, edge_index[0], dim=0, dim_size=n_atoms) out = self.o_proj(out.reshape(-1, self.num_heads * self.head_dim)) return out - def message(self, q_i, k_j, v_j, dk, dv, r_ij): + def message( + self, + edge_index: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, + dk: Optional[Tensor], + dv: Optional[Tensor], + r_ij: Tensor, + ) -> Tensor: + q_i = q.index_select(0, edge_index[0]) + k_j = k.index_select(0, edge_index[1]) + v_j = v.index_select(0, edge_index[1]) # compute attention matrix if dk is None: attn = (q_i * k_j).sum(dim=-1) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index c94313103..bbbea1bf2 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -1,10 +1,8 @@ import math from typing import Optional, Tuple import torch -from torch import Tensor -from torch import nn +from torch import nn, Tensor import torch.nn.functional as F -from torch_geometric.nn import MessagePassing from torchmdnet.extensions import get_neighbor_pairs_kernel import warnings @@ -40,8 +38,16 @@ def visualize_basis(basis_type, num_rbf=50, cutoff_lower=0, cutoff_upper=5): plt.show() -class NeighborEmbedding(MessagePassing): - def __init__(self, hidden_channels, num_rbf, cutoff_lower, cutoff_upper, max_z=100, dtype=torch.float32): +class NeighborEmbedding(nn.Module): + def __init__( + self, + hidden_channels, + num_rbf, + cutoff_lower, + cutoff_upper, + max_z=100, + dtype=torch.float32, + ): """ The ET architecture assigns two learned vectors to each atom type zi. One is used to encode information specific to an atom, the @@ -55,7 +61,7 @@ def __init__(self, hidden_channels, num_rbf, cutoff_lower, cutoff_upper, max_z=1 See eq. 3 in https://arxiv.org/pdf/2202.02541.pdf for more details. """ - super(NeighborEmbedding, self).__init__(aggr="add") + super(NeighborEmbedding, self).__init__() self.embedding = nn.Embedding(max_z, hidden_channels, dtype=dtype) self.distance_proj = nn.Linear(num_rbf, hidden_channels, dtype=dtype) self.combine = nn.Linear(hidden_channels * 2, hidden_channels, dtype=dtype) @@ -77,7 +83,7 @@ def forward( edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, - ): + ) -> Tensor: """ Args: z (Tensor): Atomic numbers of shape :obj:`[num_nodes]` @@ -99,13 +105,13 @@ def forward( W = self.distance_proj(edge_attr) * C.view(-1, 1) x_neighbors = self.embedding(z) - # propagate_type: (x: Tensor, W: Tensor) - x_neighbors = self.propagate(edge_index, x=x_neighbors, W=W, size=None) + msg = W * x_neighbors.index_select(0, edge_index[1]) + x_neighbors = torch.zeros( + z.shape[0], x.shape[1], dtype=x.dtype, device=x.device + ).index_add(0, edge_index[0], msg) x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1)) return x_neighbors - def message(self, x_j, W): - return x_j * W class OptimizedDistance(torch.nn.Module): def __init__( @@ -120,7 +126,7 @@ def __init__( resize_to_fit=True, check_errors=True, box=None, - long_edge_index=True + long_edge_index=True, ): super(OptimizedDistance, self).__init__() """ Compute the neighbor list for a given cutoff. @@ -229,7 +235,7 @@ def forward( """ self.box = self.box.to(pos.dtype) - max_pairs : int = self.max_num_pairs + max_pairs: int = self.max_num_pairs if self.max_num_pairs < 0: max_pairs = -self.max_num_pairs * pos.shape[0] if batch is None: @@ -268,7 +274,14 @@ def forward( class GaussianSmearing(nn.Module): - def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True, dtype=torch.float32): + def __init__( + self, + cutoff_lower=0.0, + cutoff_upper=5.0, + num_rbf=50, + trainable=True, + dtype=torch.float32, + ): super(GaussianSmearing, self).__init__() self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper @@ -284,7 +297,9 @@ def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=Tru self.register_buffer("offset", offset) def _initial_params(self): - offset = torch.linspace(self.cutoff_lower, self.cutoff_upper, self.num_rbf, dtype=self.dtype) + offset = torch.linspace( + self.cutoff_lower, self.cutoff_upper, self.num_rbf, dtype=self.dtype + ) coeff = -0.5 / (offset[1] - offset[0]) ** 2 return offset, coeff @@ -293,13 +308,20 @@ def reset_parameters(self): self.offset.data.copy_(offset) self.coeff.data.copy_(coeff) - def forward(self, dist): + def forward(self, dist: Tensor) -> Tensor: dist = dist.unsqueeze(-1) - self.offset return torch.exp(self.coeff * torch.pow(dist, 2)) class ExpNormalSmearing(nn.Module): - def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True, dtype=torch.float32): + def __init__( + self, + cutoff_lower=0.0, + cutoff_upper=5.0, + num_rbf=50, + trainable=True, + dtype=torch.float32, + ): super(ExpNormalSmearing, self).__init__() self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper @@ -321,11 +343,14 @@ def _initial_params(self): # initialize means and betas according to the default values in PhysNet # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 start_value = torch.exp( - torch.scalar_tensor(-self.cutoff_upper + self.cutoff_lower, dtype=self.dtype) + torch.scalar_tensor( + -self.cutoff_upper + self.cutoff_lower, dtype=self.dtype + ) ) means = torch.linspace(start_value, 1, self.num_rbf, dtype=self.dtype) betas = torch.tensor( - [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf, dtype=self.dtype + [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf, + dtype=self.dtype, ) return means, betas @@ -349,6 +374,7 @@ class ShiftedSoftplus(nn.Module): SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive. """ + def __init__(self): super(ShiftedSoftplus, self).__init__() self.shift = torch.log(torch.tensor(2.0)).item() @@ -387,6 +413,7 @@ def forward(self, distances: Tensor) -> Tensor: cutoffs = cutoffs * (distances < self.cutoff_upper) return cutoffs + class GatedEquivariantBlock(nn.Module): """Gated Equivariant Block as defined in Schütt et al. (2021): Equivariant message passing for the prediction of tensorial properties and molecular spectra @@ -407,8 +434,12 @@ def __init__( if intermediate_channels is None: intermediate_channels = hidden_channels - self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) - self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False, dtype=dtype) + self.vec1_proj = nn.Linear( + hidden_channels, hidden_channels, bias=False, dtype=dtype + ) + self.vec2_proj = nn.Linear( + hidden_channels, out_channels, bias=False, dtype=dtype + ) act_class = act_class_mapping[activation] self.update_net = nn.Sequential( @@ -432,7 +463,10 @@ def forward(self, x, v): # detach zero-entries to avoid NaN gradients during force loss backpropagation vec1 = torch.zeros( - vec1_buffer.size(0), vec1_buffer.size(2), device=vec1_buffer.device, dtype=vec1_buffer.dtype + vec1_buffer.size(0), + vec1_buffer.size(2), + device=vec1_buffer.device, + dtype=vec1_buffer.dtype, ) mask = (vec1_buffer != 0).view(vec1_buffer.size(0), -1).any(dim=1) if not mask.all(): @@ -455,6 +489,53 @@ def forward(self, x, v): x = self.act(x) return x, v + +def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + """Broadcasts src to the shape of other along the given dimension.""" + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand(other.size()) + return src + + +def scatter( + src: Tensor, + index: Tensor, + dim: int = 0, + dim_size: Optional[int] = None, + reduce: str = "sum", +) -> Tensor: + """Has the signature of torch_scatter.scatter, but uses torch.scatter_reduce instead.""" + if dim_size is None: + dim_size = index.max().item() + 1 + operation_dict = { + "add": "sum", + "sum": "sum", + "mul": "prod", + "mean": "mean", + "min": "amin", + "max": "amax", + } + reduce_op = operation_dict[reduce] + # take into account the dimensionality of src + index = _broadcast(index, src, dim) + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + res = out.scatter_reduce(dim, index, src, reduce_op) + return res + + rbf_class_mapping = {"gauss": GaussianSmearing, "expnorm": ExpNormalSmearing} act_class_mapping = { diff --git a/torchmdnet/priors/coulomb.py b/torchmdnet/priors/coulomb.py index 45972f0e5..faa47717f 100644 --- a/torchmdnet/priors/coulomb.py +++ b/torchmdnet/priors/coulomb.py @@ -1,7 +1,6 @@ import torch from torchmdnet.priors.base import BasePrior -from torch_scatter import scatter -from torchmdnet.models.utils import OptimizedDistance +from torchmdnet.models.utils import OptimizedDistance, scatter from typing import Optional, Dict class Coulomb(BasePrior): diff --git a/torchmdnet/priors/d2.py b/torchmdnet/priors/d2.py index 953e5a54b..e1aa563fe 100644 --- a/torchmdnet/priors/d2.py +++ b/torchmdnet/priors/d2.py @@ -1,7 +1,6 @@ from torchmdnet.priors.base import BasePrior -from torchmdnet.models.utils import OptimizedDistance +from torchmdnet.models.utils import OptimizedDistance, scatter import torch as pt -from torch_scatter import scatter class D2(BasePrior): @@ -103,8 +102,8 @@ class D2(BasePrior): [31.74, 1.892], # 52 Te [31.50, 1.892], # 53 I [29.99, 1.881], # 54 Xe - ], dtype=pt.float64 - + ], + dtype=pt.float64, ) C_6_R_r[:, 1] *= 0.1 # Å --> nm @@ -157,15 +156,15 @@ def get_init_args(self): "max_num_neighbors": self.max_num_neighbors, "atomic_number": self.atomic_number, "distance_scale": self.distance_scale, - "energy_scale": self.energy_scale + "energy_scale": self.energy_scale, } def post_reduce(self, y, z, pos, batch, extra_args): # Convert to interal units: nm and J/mol # NOTE: float32 is overflowed, if m and J are used - distance_scale = self.distance_scale*1e9 # m --> nm - energy_scale = self.energy_scale*6.02214076e23 # J --> J/mol + distance_scale = self.distance_scale * 1e9 # m --> nm + energy_scale = self.energy_scale * 6.02214076e23 # J --> J/mol # Get atom pairs and their distancence ij, R_ij, _ = self.distances(pos, batch) diff --git a/torchmdnet/priors/zbl.py b/torchmdnet/priors/zbl.py index 2e896ea4e..a5eeb45f1 100644 --- a/torchmdnet/priors/zbl.py +++ b/torchmdnet/priors/zbl.py @@ -1,9 +1,9 @@ import torch from torchmdnet.priors.base import BasePrior -from torch_scatter import scatter -from torchmdnet.models.utils import OptimizedDistance, CosineCutoff +from torchmdnet.models.utils import OptimizedDistance, CosineCutoff, scatter from typing import Optional, Dict + class ZBL(BasePrior): """This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion. Is is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147). It @@ -16,7 +16,16 @@ class ZBL(BasePrior): distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters energy_scale: multiply by this factor to convert energies stored in the dataset to Joules (*not* J/mol) """ - def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None): + + def __init__( + self, + cutoff_distance, + max_num_neighbors, + atomic_number=None, + distance_scale=None, + energy_scale=None, + dataset=None, + ): super(ZBL, self).__init__() if atomic_number is None: atomic_number = dataset.atomic_number @@ -26,7 +35,9 @@ def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, dista energy_scale = dataset.energy_scale atomic_number = torch.as_tensor(atomic_number, dtype=torch.long) self.register_buffer("atomic_number", atomic_number) - self.distance = OptimizedDistance(0, cutoff_distance, max_num_pairs=-max_num_neighbors) + self.distance = OptimizedDistance( + 0, cutoff_distance, max_num_pairs=-max_num_neighbors + ) self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance) self.cutoff_distance = cutoff_distance self.max_num_neighbors = max_num_neighbors @@ -34,30 +45,47 @@ def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, dista self.energy_scale = float(energy_scale) def get_init_args(self): - return {'cutoff_distance': self.cutoff_distance, - 'max_num_neighbors': self.max_num_neighbors, - 'atomic_number': self.atomic_number.tolist(), - 'distance_scale': self.distance_scale, - 'energy_scale': self.energy_scale} + return { + "cutoff_distance": self.cutoff_distance, + "max_num_neighbors": self.max_num_neighbors, + "atomic_number": self.atomic_number.tolist(), + "distance_scale": self.distance_scale, + "energy_scale": self.energy_scale, + } def reset_parameters(self): pass - def post_reduce(self, y, z, pos, batch, extra_args: Optional[Dict[str, torch.Tensor]]): + def post_reduce( + self, y, z, pos, batch, extra_args: Optional[Dict[str, torch.Tensor]] + ): edge_index, distance, _ = self.distance(pos, batch) if edge_index.shape[1] == 0: return y atomic_number = self.atomic_number[z[edge_index]] # 5.29e-11 is the Bohr radius in meters. All other numbers are magic constants from the ZBL potential. - a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23) - d = distance*self.distance_scale/a - f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d) + a = ( + 0.8854 + * 5.29177210903e-11 + / (atomic_number[0] ** 0.23 + atomic_number[1] ** 0.23) + ) + d = distance * self.distance_scale / a + f = ( + 0.1818 * torch.exp(-3.2 * d) + + 0.5099 * torch.exp(-0.9423 * d) + + 0.2802 * torch.exp(-0.4029 * d) + + 0.02817 * torch.exp(-0.2016 * d) + ) f *= self.cutoff(distance) # Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair # appears twice. - energy = f*atomic_number[0]*atomic_number[1]/distance - energy = 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*scatter(energy, batch[edge_index[0]], dim=0, reduce="sum") + energy = f * atomic_number[0] * atomic_number[1] / distance + energy = ( + 0.5 + * (2.30707755e-28 / self.energy_scale / self.distance_scale) + * scatter(energy, batch[edge_index[0]], dim=0, reduce="sum") + ) if energy.shape[0] < y.shape[0]: - energy = torch.nn.functional.pad(energy, (0, y.shape[0]-energy.shape[0])) + energy = torch.nn.functional.pad(energy, (0, y.shape[0] - energy.shape[0])) energy = energy.reshape(y.shape) return y + energy From c3b46f63df91213565bc2f7a782f10cc64b96cf6 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 14 Nov 2023 16:28:50 +0100 Subject: [PATCH 5/7] change trainer strategy to "auto" (#234) --- torchmdnet/scripts/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 11cfed0f6..dfdc4ac37 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -173,7 +173,7 @@ def main(): ) trainer = pl.Trainer( - strategy=DDPStrategy(find_unused_parameters=False), + strategy="auto", max_epochs=args.num_epochs, accelerator="auto", devices=args.ngpus, From be6a2142d4bbcc4be98f180f3df6f44651682fbb Mon Sep 17 00:00:00 2001 From: Raul Date: Thu, 16 Nov 2023 12:35:51 +0100 Subject: [PATCH 6/7] Improve documentation (#224) * Improve TorchMD_NET documentation Add some notes about charge/spin * Add charge and spin to the example yamls, to highlight their existence * Add some docstrings * Add first ReadTheDocs version * Remove unused lines in conf * Move some stuff to Usage * Update docs * Move options to another file so they can be autodocumented * Update docs * Update docs * Hide utility classes in transformer * Update some docstrings * Update conf.py * Add style.css * Document ACE Dataset * Document ANI * Revert _MultiHeadAttention to MultiHeadAttention * Update ace documentation * Update ani and atomref docs * Update ace docs * Add extension mock import * Update comment * Add docstring to extension * Add docstring for extensions * Revert "Move options to another file so they can be autodocumented" This reverts commit c61c92246142f3deba28a2c59a86c97a7782be85. * Store model names in model.__all_models__ * Set __all__ for some files * Fix External docstring * Make docs respect __all__ * Fix docstring * Update docs * Update doc build * Move readthedocs to root * Update readthedocs * Add environemnt.yml * Install local master version for docs * Update requirements * Remove env * Fix typo * Hide classes with :meta private: * Add intersphinx * Add version with git * Blacken * Change title * Small changes * Small changes * Default tag to master * Add priors * Add workflow to test documentation build * Update ci * Update ci * Update ci * Update ci * Update ci * Add documentation badge to README * Update README --- .github/workflows/docs_build.yaml | 36 +++++ .gitignore | 8 ++ .readthedocs.yaml | 17 +++ README.md | 7 + docs/Makefile | 20 +++ docs/_static/style.css | 59 ++++++++ docs/make.bat | 35 +++++ docs/requirements.txt | 5 + docs/source/api.rst | 10 ++ docs/source/conf.py | 88 ++++++++++++ docs/source/datasets.rst | 81 +++++++++++ docs/source/index.rst | 72 ++++++++++ docs/source/installation.rst | 60 ++++++++ docs/source/models.rst | 134 ++++++++++++++++++ docs/source/priors.rst | 82 +++++++++++ docs/source/torchmd-train.rst | 95 +++++++++++++ docs/source/usage.rst | 67 +++++++++ examples/ET-ANI1.yaml | 2 + examples/ET-MD17.yaml | 2 + examples/ET-QM9.yaml | 3 + examples/ET-SPICE.yaml | 2 + examples/TensorNet-ANI1X.yaml | 2 + examples/TensorNet-QM9.yaml | 2 + examples/TensorNet-SPICE.yaml | 2 + examples/TensorNet-rMD17.yaml | 2 + tests/test_model.py | 14 +- tests/test_module.py | 4 +- tests/test_priors.py | 2 +- tests/test_wrappers.py | 2 +- torchmdnet/calculators.py | 4 +- torchmdnet/data.py | 16 +++ torchmdnet/datasets/ace.py | 128 +++++++++++++++++ torchmdnet/datasets/ani.py | 56 +++++++- torchmdnet/datasets/qm9q.py | 11 +- torchmdnet/extensions/__init__.py | 96 +++++++++++-- .../extensions/neighbors/neighbors_cuda.cu | 1 - torchmdnet/models/__init__.py | 2 +- torchmdnet/models/model.py | 120 +++++++++++----- torchmdnet/models/output_modules.py | 5 + torchmdnet/models/tensornet.py | 31 ++-- torchmdnet/models/torchmd_et.py | 4 + torchmdnet/models/torchmd_gn.py | 10 +- torchmdnet/models/torchmd_t.py | 6 + torchmdnet/models/utils.py | 87 +++++++----- torchmdnet/priors/atomref.py | 7 + torchmdnet/priors/coulomb.py | 22 ++- torchmdnet/priors/d2.py | 56 ++++---- torchmdnet/priors/zbl.py | 27 ++-- torchmdnet/scripts/train.py | 14 +- 49 files changed, 1442 insertions(+), 176 deletions(-) create mode 100644 .github/workflows/docs_build.yaml create mode 100644 .readthedocs.yaml create mode 100644 docs/Makefile create mode 100644 docs/_static/style.css create mode 100644 docs/make.bat create mode 100644 docs/requirements.txt create mode 100644 docs/source/api.rst create mode 100644 docs/source/conf.py create mode 100644 docs/source/datasets.rst create mode 100644 docs/source/index.rst create mode 100644 docs/source/installation.rst create mode 100644 docs/source/models.rst create mode 100644 docs/source/priors.rst create mode 100644 docs/source/torchmd-train.rst create mode 100644 docs/source/usage.rst diff --git a/.github/workflows/docs_build.yaml b/.github/workflows/docs_build.yaml new file mode 100644 index 000000000..234ef1851 --- /dev/null +++ b/.github/workflows/docs_build.yaml @@ -0,0 +1,36 @@ +name: Build Documentation + +on: + push: + branches: + - "main" + pull_request: + branches: + - "main" + + +jobs: + build-docs: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Env + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: environment.yml + init-shell: bash + generate-run-shell: true + + - name: Install docs dependencies + run: | + pip install -vv . + pip install -r docs/requirements.txt + shell: bash -el {0} + + - name: Build Sphinx Documentation + run: | + cd docs + make html + shell: bash -el {0} diff --git a/.gitignore b/.gitignore index 2a05ba9d9..4caa48eef 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,11 @@ dmypy.json # temp directories logs/ + +# Docs +docs/build/ +docs/source/generated + +# Extra + +*~ \ No newline at end of file diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 000000000..06d6a8f3e --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +version: "2" + +build: + os: "ubuntu-22.04" + tools: + python: "mambaforge-22.9" + jobs: + post_create_environment: + - pip install -r docs/requirements.txt + - pip install . + +conda: + environment: environment.yml + + +sphinx: + configuration: docs/source/conf.py diff --git a/README.md b/README.md index ed9512f30..fe1758483 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,16 @@ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![CI](https://github.com/torchmd/torchmd-net/actions/workflows/CI.yml/badge.svg)](https://github.com/torchmd/torchmd-net/actions/workflows/CI.yml) +[![Documentation Status](https://readthedocs.org/projects/torchmd-net/badge/?version=latest)](https://torchmd-net.readthedocs.io/en/latest/?badge=latest) # TorchMD-NET TorchMD-NET provides state-of-the-art neural networks potentials (NNPs) and a mechanism to train them. It offers efficient and fast implementations if several NNPs and it is integrated in GPU-accelerated molecular dynamics code like [ACEMD](https://www.acellera.com/products/molecular-dynamics-software-gpu-acemd/), [OpenMM](https://www.openmm.org) and [TorchMD](https://github.com/torchmd/torchmd). TorchMD-NET exposes its NNPs as [PyTorch](https://pytorch.org) modules. + +## Documentation + +Documentation is available at https://torchmd-net.readthedocs.io + ## Available architectures - [Equivariant Transformer (ET)](https://arxiv.org/abs/2202.02541) @@ -13,6 +19,7 @@ TorchMD-NET provides state-of-the-art neural networks potentials (NNPs) and a me - [TensorNet](https://arxiv.org/abs/2306.06482) + ## Installation TorchMD-Net is available in [conda-forge](https://conda-forge.org/) and can be installed with: ```shell diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 000000000..d0c3cbf10 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/style.css b/docs/_static/style.css new file mode 100644 index 000000000..2580b0066 --- /dev/null +++ b/docs/_static/style.css @@ -0,0 +1,59 @@ +@import 'theme.css'; + +.rst-content dl:not(.docutils) dt:first-child { + margin-top: 0; +} + +.rst-content dl:not(.docutils) dl dt { + margin-bottom: 4px; + border: none; + border-left: solid 3px #ccc; + background: #f0f0f0; + color: #555; +} + +.rst-content dl table, +.rst-content dl ul, +.rst-content dl ol, +.rst-content dl p { + margin-bottom: 8px !important; +} + +.rst-content dl:not(.docutils) dt { + display: table; + margin: 6px 0; + font-size: 90%; + line-height: normal; + background: #e7f2fa; + color: #2980B9; + border-top: solid 3px #6ab0de; + padding: 6px; + position: relative; +} + +html.writer-html5 .rst-content dl.field-list { + display: initial; +} + +html.writer-html5 .rst-content dl.field-list>dd, +html.writer-html5 .rst-content dl.field-list>dt { + margin-bottom: 4px; + padding-left: 6px; +} + +p { + line-height: 20px; + font-size: 14px; +} + +html.writer-html5 .rst-content dl.field-list>dt:after { + content: initial; +} + +dt.field-even { + text-transform: uppercase; +} + +dt.field-odd { + text-transform: uppercase; +} diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 000000000..6247f7e23 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000..8b776d3fd --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,5 @@ +sphinx==7.2.6 +sphinx-rtd-theme==1.3.0 +sphinxcontrib-autoprogram==0.1.8 +sphinxcontrib-napoleon==0.7 +gitpython diff --git a/docs/source/api.rst b/docs/source/api.rst new file mode 100644 index 000000000..08e05ba60 --- /dev/null +++ b/docs/source/api.rst @@ -0,0 +1,10 @@ +API Reference +============= + +.. autosummary:: + :toctree: generated + :recursive: + + torchmdnet + + diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 000000000..a95087e39 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,88 @@ +# Configuration file for the Sphinx documentation builder. + +# -- Project information +project = "TorchMD-Net" +author = "RaulPPelaez" + +import git + + +def get_latest_git_tag(repo_path="."): + repo = git.Repo(repo_path) + tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime) + return tags[-1].name if tags else None + + +current_tag = get_latest_git_tag("../../") +if current_tag is None: + current_tag = "master" +release = current_tag +version = current_tag + +# -- General configuration +extensions = [ + "sphinx.ext.duration", + "sphinx.ext.doctest", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinxcontrib.autoprogram", +] +napoleon_google_docstring = True +napoleon_numpy_docstring = True +napoleon_include_init_with_doc = False +napoleon_include_private_with_doc = False +napoleon_include_special_with_doc = True +napoleon_use_admonition_for_examples = True +napoleon_use_admonition_for_notes = True +napoleon_use_admonition_for_references = True +napoleon_use_ivar = False +napoleon_use_param = True +napoleon_use_rtype = True +napoleon_preprocess_types = True +napoleon_type_aliases = None +napoleon_attr_annotations = True +autosummary_ignore_module_all = False + +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "sphinx": ("https://www.sphinx-doc.org/en/master/", None), +} +intersphinx_disabled_domains = ["std"] + +templates_path = ["_templates"] + +# -- Options for HTML output + +html_theme = "sphinx_rtd_theme" + +# -- Options for EPUB output +epub_show_urls = "footnote" + +autoclass_content = "both" +autodoc_typehints = "none" +autodoc_inherit_docstrings = False +sphinx_autodoc_typehints = True +html_show_sourcelink = True +autodoc_default_options = { + "members": True, + "member-order": "bysource", + "exclude-members": "__weakref__", + "undoc-members": False, + "show-inheritance": True, + "inherited-members": False, +} +# Exclude all torchmdnet.datasets.*.rst files in source/generated/ +exclude_patterns = [ + "generated/torchmdnet.datasets.*.rst", + "generated/torchmdnet.scripts.*rst", +] +html_static_path = ["../_static"] +html_css_files = [ + "style.css", +] + +autodoc_mock_imports = ["torchmdnet.extensions.torchmdnet_extensions"] diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst new file mode 100644 index 000000000..86fa22eee --- /dev/null +++ b/docs/source/datasets.rst @@ -0,0 +1,81 @@ +Datasets +======== +.. toctree:: + :maxdepth: 3 + + + +Using a Dataset +--------------- + +There are two ways of using a dataset in TorchMD-Net, depending on whether you are using the Python API (for instance to run inferences) or the command line interface (for training). + +Via a configuration file +~~~~~~~~~~~~~~~~~~~~~~~~ + +You can make use of any of the available datasets via the :ref:`configuration file ` if you are using the :ref:`torchmd-train utility `. +Take a look at one of the example configuration files. As an example lets set up the :py:mod:`QM9` dataset, which allows for an additional argument `labels` specifying the subset to be provided. The part of the yaml configuration file for the dataset would look like this: + +.. code:: yaml + + dataset: QM9 + dataset_arg: + label: energy_U0 + + + +Via the Python API +~~~~~~~~~~~~~~~~~~ + +TorchMD-Net datasets are inherited from `torch Geometric datasets `_, you may use them whenever the PyTorch Geometric datasets are used. For instance, to load the QM9 dataset, you may use the following code: + +.. code:: python + + from torchmdnet.datasets import QM9 + from torchmdnet.data import DataModule + dataset = QM9(root='data', labels='energy_U0') + + print(dataset[0]) + print(len(dataset)) + # Some arbitrary parameters for the DataModule + params = {'batch_size': 32, + 'inference_batch_size': 32, + 'num_workers': 4, + 'train_size': 0.8, + 'val_size': 0.1, + 'test_size': 0.1, + 'seed': 42, + 'log_dir': 'logs', + 'splits': None, + 'standardize': False,} + + dataloader = DataModule(params, dataset) + dataloader.prepare_data() + dataloader.setup("fit") + + # You can use this directly with PyTorch Lightning + # trainer.fit(model, dataloader) + + +Adding a new Dataset +-------------------- + +If you want to train on custom data, first have a look at :py:mod:`torchmdnet.datasets.Custom`, which provides functionalities for loading a NumPy dataset consisting of atom types and coordinates, as well as energies, forces or both as the labels. + +To add a new dataset, you need to: + +1. Write a new class inheriting from :py:mod:`torch_geometric.data.Dataset`, see `here `_ for a tutorial on that. +2. Add the new class to the :py:mod:`torchmdnet.datasets` module (by listing it in the `__all__` variable in the `datasets/__init__.py` file) so that the :ref:`configuration file ` recognizes it. + +.. note:: + + The dataset must return torch-geometric `Data` objects, containing at least the keys `z` (atom types) and `pos` (atomic coordinates), as well as `y` (label), `neg_dy` (negative derivative of the label w.r.t atom coordinates) or both. + +Available Datasets +------------------ + +.. automodule:: torchmdnet.datasets + :noindex: + + .. include:: generated/torchmdnet.datasets.rst + :start-line: 5 diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 000000000..d32f07f20 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,72 @@ +Welcome to the TorchMD-NET Documentation! +========================================= + +TorchMD-NET provides state-of-the-art neural networks potentials (NNPs) and a mechanism to train them. It offers efficient and fast implementations of several NNPs and is integrated with GPU-accelerated molecular dynamics code like `ACEMD `_, `OpenMM `_, and `TorchMD `_. TorchMD-NET exposes its NNPs as `PyTorch `_ modules. + + +Cite +==== + +If you use TorchMD-NET in your research, please cite the following papers: + +Main reference +~~~~~~~~~~~~~~ + +.. code-block:: bibtex + + @inproceedings{ + tholke2021equivariant, + title={Equivariant Transformers for Neural Network based Molecular Potentials}, + author={Philipp Th{\"o}lke and Gianni De Fabritiis}, + booktitle={International Conference on Learning Representations}, + year={2022}, + url={https://openreview.net/forum?id=zNHzqZ9wrRB} + } + +Graph Network +~~~~~~~~~~~~~ + +.. code-block:: bibtex + + @misc{majewski2022machine, + title={Machine Learning Coarse-Grained Potentials of Protein Thermodynamics}, + author={Maciej Majewski and Adrià Pérez and Philipp Thölke and Stefan Doerr and Nicholas E. Charron and Toni Giorgino and Brooke E. Husic and Cecilia Clementi and Frank Noé and Gianni De Fabritiis}, + year={2022}, + eprint={2212.07492}, + archivePrefix={arXiv}, + primaryClass={q-bio.BM} + } + +TensorNet +~~~~~~~~~ + +.. code-block:: bibtex + + @misc{simeon2023tensornet, + title={TensorNet: Cartesian Tensor Representations for Efficient Learning of Molecular Potentials}, + author={Guillem Simeon and Gianni de Fabritiis}, + year={2023}, + eprint={2306.06482}, + archivePrefix={arXiv}, + primaryClass={cs.LG} + } + + +.. toctree:: + :hidden: + + installation + usage + torchmd-train + datasets + models + priors + api + +.. + Indices and tables + ================== + + * :ref:`genindex` + * :ref:`modindex` + * :ref:`search` diff --git a/docs/source/installation.rst b/docs/source/installation.rst new file mode 100644 index 000000000..bc0c3f914 --- /dev/null +++ b/docs/source/installation.rst @@ -0,0 +1,60 @@ +Installation +============ + +TorchMD-Net is available in `conda-forge `_ and can be installed with: + +.. code-block:: shell + + mamba install torchmd-net + +We recommend using `Mamba `_ instead of conda. + +Install from source +------------------- + +1. Clone the repository: + + .. code-block:: shell + + git clone https://github.com/torchmd/torchmd-net.git + cd torchmd-net + +2. Install the dependencies in environment.yml. You can do it via pip, but we recommend `Mambaforge `_ instead. + +3. Create an environment and activate it: + + .. code-block:: shell + + mamba env create -f environment.yml + mamba activate torchmd-net + +4. Install TorchMD-NET into the environment: + + .. code-block:: shell + + pip install -e . + +This will install TorchMD-NET in editable mode, so that changes to the source code are immediately available. +Besides making all python utilities available environment-wide, this will also install the ``torchmd-train`` command line utility. + +CUDA enabled installation +------------------------- + +Besides the dependencies listed in the environment file, you will also need the CUDA ``nvcc`` compiler suite to build TorchMD-Net. +If your system lacks nvcc you may install it via conda-forge: + +.. code-block:: shell + + mamba install cudatoolkit-dev + +Or from the nvidia channel: + +.. code-block:: shell + + mamba install -c nvidia cuda-nvcc cuda-cudart-dev cuda-libraries-dev + +Make sure you install a major version compatible with your torch installation, which you can check with: + +.. code-block:: shell + + python -c "import torch; print(torch.version.cuda)" diff --git a/docs/source/models.rst b/docs/source/models.rst new file mode 100644 index 000000000..5761d92c1 --- /dev/null +++ b/docs/source/models.rst @@ -0,0 +1,134 @@ +Neural Network Potentials +========================= + + +Training a model +---------------- + +The typical workflow to obtain a neural network potential in TorchMD-Net starts with :ref:`training ` one of the `Available Models`_. During this process you will get a checkpoint file that can be used to load the model for inference. + + + +Loading a model for inference +----------------------------- + +Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example. + +.. code:: python + + import torch + from torchmdnet.models.model import load_model + checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt" + model = load_model(checkpoint, derivative=True) + # An arbitrary set of inputs for the model + n_atoms = 10 + zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long) + z = zs[torch.randint(0, len(zs), (n_atoms,))] + pos = torch.randn(len(z), 3) + batch = torch.zeros(len(z), dtype=torch.long) + + y, neg_dy = model(z, pos, batch) + +.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference. + + + +Available Models +---------------- + +TorchMD-Net offers representation models that output a series of per-atom features. Typically one wants to couple this with an :py:mod:`output model ` and perhaps a :py:mod:`prior ` to get a single per-batch label (i.e. total energy) and optionally its derivative with respect to the positions (i.e. forces). + +The :py:mod:`TorchMD_Net ` model takes care of putting the pieces together. + + +TensorNet +~~~~~~~~~ + +TensorNet is an equivariant model based on rank-2 Cartesian tensor representations. Euclidean neural network potentials have been shown to achieve state-of-the-art performance and better data efficiency than previous models, relying on higher-rank equivariant features which are irreducible representations of the rotation group, in the form of spherical tensors. However, the computation of tensor products in these models can be computationally demanding. In contrast, TensorNet exploits the use of Cartesian rank-2 tensors (3x3 matrices) which can be very efficiently decomposed into scalar, vector and rank-2 tensor features. Furthermore, Clebsch-Gordan tensor products are substituted by straightforward 3x3 matrix products. Overall, these properties allow TensorNet to achieve state-of-the-art accuracy on common benchmark datasets such as rMD17 and QM9 with a reduced number of message passing steps, learnable parameters and computational cost. The prediction of up to rank-2 molecular properties that behave appropriately under geometric transformations such as reflections and rotations is also possible. + +.. automodule:: torchmdnet.models.tensornet + :noindex: + +.. note:: TensorNet is referred to as "tensornet" in the :ref:`configuration-file`. + +Equivariant Transformer +~~~~~~~~~~~~~~~~~~~~~~~ + +The Equivariant Transformer (ET) is an equivariant neural network which uses both scalar and Cartesian vector representations. The distinctive feature of the ET in comparison to other Cartesian vector models such as PaiNN or EGNN is the use of a distance-dependent dot product attention mechanism, which achieved state-of-the-art performance on benchmark datasets at the time of publication. Furthermore, the analysis of attention weights allowed to extract insights on the interaction of different atomic species for the prediction of molecular energies and forces. The model also exhibits a low computational cost for inference and training in comparison to some of the most used NNPs in the literature. + +.. automodule:: torchmdnet.models.torchmd_et + :noindex: + + +.. note:: Equivariant Transformer is referred to as "equivariant-transformer" in the :ref:`configuration-file`. + +Graph Network +~~~~~~~~~~~~~ + +The graph network is an invariant model inspired on the SchNet and PhysNet architectures. The network was optimized to have satisfactory performance on coarse-grained proteins, allowing to build NNPs that correctly reproduce protein free energy landscapes. In contrast to the ET and TensorNet, the graph network only uses relative distances between atoms as geometrical information, which are invariant to translations, rotations and reflections. The distances are used by the model to learn a set of continuous filters that are applied to feature graph convolutions as in SchNet, progressively updating the intial atomic embeddings by means of residual connections. + +.. automodule:: torchmdnet.models.torchmd_gn + :noindex: + + +.. note:: Graph Network is referred to as "graph-network" in the :ref:`configuration-file`. + +Implementing a new Architecture +------------------------------- + +To implement a new architecture, you need to follow these steps: + +1. Create a new class in ``torchmdnet.models`` that inherits from ``torch.nn.Model``. Follow TorchMD_ET as a template. This is a minimum implementation of a model: + + .. code-block:: python + + class MyModule(nn.Module): + def __init__(self, parameter1, parameter2): + super(MyModule, self).__init__() + # Define your model here + self.layer1 = nn.Linear(10, 10) + # Initialize your model parameters here + self.reset_parameters() + + def reset_parameters(self): + # Initialize your model parameters here + nn.init.xavier_uniform_(self.layer1.weight) + + def forward(self, z: Tensor, pos: Tensor, batch: Tensor, q: Optional[Tensor] = None, s: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + # Define your forward pass here + scalar_features = ... + vector_features = ... + return scalar_features, vector_features, z, pos, batch + +2. Add the model to the ``__all__`` list in ``torchmdnet.models.__init__.py``. This will make the tests pick your model up. + +3. Tell models.model.create_model how to initialize your module by adding a new entry: + + .. code-block:: python + + elif args["model"] == "mymodule": + from torchmdnet.models.torchmd_mymodule import MyModule + is_equivariant = False + representation_model = MyModule( + parameter1=args["parameter1"], + parameter2=args["parameter2"], + **shared_args, + ) + +4. Add any new parameters required to initialize your module to scripts.train.get_args: + + .. code-block:: python + + parser.add_argument('--parameter1', type=int, default=32, help='Parameter1 required by MyModule') + +5. Add an example configuration file to ``torchmd-net/examples`` that uses your model. + +6. Make tests use your configuration file by adding a case to tests.utils.load_example_args: + + .. code-block:: python + + if model_name == "mymodule": + config_file = join(dirname(dirname(__file__)), "examples", "MyModule-QM9.yaml") + +At this point, if your module is missing some feature the tests will let you know, and you can add it. If you add a new feature to the package, please add a test for it. + diff --git a/docs/source/priors.rst b/docs/source/priors.rst new file mode 100644 index 000000000..e12fe9b01 --- /dev/null +++ b/docs/source/priors.rst @@ -0,0 +1,82 @@ +Priors +====== + +.. toctree:: + :maxdepth: 4 + + +Priors in the context of TorchMD-Net are pre-defined models that embed domain-specific knowledge into the neural network. They are used to enforce known physical laws or empirical observations that are not automatically learned by the neural network. This inclusion enhances the predictive accuracy of the network, particularly in scenarios where training data may be limited or where the network needs to generalize beyond the scope of its training set. + +The primary role of priors is to guide the learning process of the network by imposing constraints based on physical principles. For example, a prior might reflect known chemical properties or sum a Coulomb interaction to the energy predicted by the network. + + +Using Priors in TorchMD-Net +--------------------------- + +There are two ways of using a dataset in TorchMD-Net, depending on whether you are using the Python API (for instance to run inferences) or the command line interface (for training). + +Via a configuration file +~~~~~~~~~~~~~~~~~~~~~~~~ + +You can make use of any of the available priors via the :ref:`configuration file ` if you are using the :ref:`torchmd-train utility `. +In the YAML configuration file, you can specify the type of prior model to use along with any additional arguments it requires (see the documentation for each particular prior). + +For example, to use the Atomref prior via YAML, your configuration might look like this: + +.. code:: yaml + + prior_model: Atomref + prior_args: + max_z: 100 # Optional argument for Atomref + + +Via the Python API +~~~~~~~~~~~~~~~~~~ + +If you are using the Python API, you can use any of the available priors by importing them from the :mod:`torchmdnet.priors` module and passing them to the :class:`torchmdnet.models.TorchMDNet` class. + +Writing a new Prior +-------------------- + +All Priors inherit from the :py:mod:`BasePrior` class. + +As an example, lets write a prior that adds an offset to the energy of each atom and molecule. We will call it :py:class:`EnergyOffset`. + +.. code:: python + + from torchmdnet.priors.base import BasePrior + class EnergyOffset(BasePrior): + + def __init__(self, atom_offset=0, molecule_offset=0, dataset=None): + super().__init__() + self.atom_offset = atom_offset + self.molecule_offset = molecule_offset + + def get_init_args(self): + r"""A function that returns all required arguments to construct a prior object. + The values should be returned inside a dict with the keys being the arguments' names. + All values should also be saveable in a .yaml file as this is used to reconstruct the + prior model from a checkpoint file. + """ + return {"atom_offset": self.atom_offset, "molecule_offset": self.molecule_offset} + + def pre_reduce(self, x: Tensor, z: Tensor, pos: Tensor, batch: Tensor, extra_args: Optional[Dict[str, Tensor]]): + """Adds the offset to the energy of each atom. + """ + return x + self.atom_offset + + def post_reduce(self, x: Tensor, z: Tensor, pos: Tensor, batch: Tensor, extra_args: Optional[Dict[str, Tensor]]): + """Adds the offset to the energy of each molecule. + """ + return x + self.molecule_offset + + +Available Priors +---------------- + +.. automodule:: torchmdnet.priors + :noindex: + + .. include:: generated/torchmdnet.priors.rst + :start-line: 5 + diff --git a/docs/source/torchmd-train.rst b/docs/source/torchmd-train.rst new file mode 100644 index 000000000..b12e76443 --- /dev/null +++ b/docs/source/torchmd-train.rst @@ -0,0 +1,95 @@ +.. _torchmd-train: + +TorchMD-Train Utility +--------------------- + +.. _configuration-file: + +Configuration file +~~~~~~~~~~~~~~~~~~ + +The torchmd-train utility can be configured via a `yaml `_ file, see below for a list of available options. You can include any valid option in the yaml file by replacing "-" by "_", for instance: + +.. code:: yaml + + activation: silu + aggr: add + atom_filter: -1 + batch_size: 16 + coord_files: null + cutoff_lower: 0.0 + cutoff_upper: 5.0 + dataset: QM9 + dataset_arg: + label: energy_U0 + dataset_root: ~/data + derivative: false + early_stopping_patience: 150 + ema_alpha_neg_dy: 1.0 + ema_alpha_y: 1.0 + embed_files: null + embedding_dimension: 256 + energy_files: null + equivariance_invariance_group: O(3) + y_weight: 1.0 + force_files: null + neg_dy_weight: 0.0 + gradient_clipping: 40 + inference_batch_size: 128 + load_model: null + log_dir: logs/ + lr: 0.0001 + lr_factor: 0.8 + lr_min: 1.0e-07 + lr_patience: 15 + lr_warmup_steps: 1000 + max_num_neighbors: 64 + max_z: 128 + model: tensornet + ngpus: -1 + num_epochs: 3000 + num_layers: 3 + num_nodes: 1 + num_rbf: 64 + num_workers: 6 + output_model: Scalar + precision: 32 + prior_model: Atomref + rbf_type: expnorm + redirect: false + reduce_op: add + save_interval: 10 + seed: 1 + splits: null + standardize: false + test_interval: 20 + test_size: null + train_size: 110000 + trainable_rbf: false + val_size: 10000 + weight_decay: 0.0 + charge: false + spin: false + +.. note:: There are several example files in the `examples/` folder. + +You can use a yaml configuration file with the `torchmd-train` utility with: + +.. code:: bash + + torchmd-train --conf my_conf.yaml + +.. note:: Flags provided after `--conf` will override the ones in the yaml file. + +.. note:: + + The utility will save all the provided parameters as a file called `input.yaml` along with the generated checkpoints. + +Command line interface +~~~~~~~~~~~~~~~~~~~~~~ + + +.. autoprogram:: scripts.train:get_argparse() + :prog: torchmd-train + + diff --git a/docs/source/usage.rst b/docs/source/usage.rst new file mode 100644 index 000000000..ffb2b34e9 --- /dev/null +++ b/docs/source/usage.rst @@ -0,0 +1,67 @@ +Usage +----- + +.. _training: + +Training an existing model +========================== + +Specifying training arguments can either be done via a :ref:`configuration YAML file ` or through command line arguments directly, see the :ref:`torchmd-train ` utility for more info. Several examples of architectural and training specifications for some models and datasets can be found in `examples `_. + +GPUs can be selected by setting the `CUDA_VISIBLE_DEVICES` environment variable. Otherwise, the argument `--ngpus` can be used to select the number of GPUs to train on (-1, the default, uses all available GPUs or the ones specified in `CUDA_VISIBLE_DEVICES`). + + +.. note:: + + Keep in mind that the `GPU ID reported by nvidia-smi might not be the same as the one CUDA_VISIBLE_DEVICES uses `_. + +For example, to train the Equivariant Transformer on the QM9 dataset with the architectural and training hyperparameters described in the paper, one can run + +.. code:: bash + + mkdir output + CUDA_VISIBLE_DEVICES=0 torchmd-train --conf torchmd-net/examples/ET-QM9.yaml --log-dir output/ + +Run `torchmd-train --help` to see all available options and their descriptions. + +Pretrained Models +================= + +See `here `_ for instructions on how to load pretrained models. + +Custom Prior Models +=================== + +In addition to implementing a custom dataset class, it is also possible to add a custom prior model to the model. This can be done by implementing a new prior model class in :py:mod:`torchmdnet.priors` and adding the argument ``--prior-model ``. As an example, have a look at :py:mod:`torchmdnet.priors.Atomref`. + + +Multi-Node Training +=================== + +In order to train models on multiple nodes some environment variables have to be set, which provide all necessary information to PyTorch Lightning. In the following, we provide an example bash script to start training on two machines with two GPUs each. The script has to be started once on each node. Once ``torchmd-train`` is started on all nodes, a network connection between the nodes will be established using NCCL. + +.. code-block:: shell + + export NODE_RANK=0 + export MASTER_ADDR=hostname1 + export MASTER_PORT=12910 + + mkdir -p output + CUDA_VISIBLE_DEVICES=0,1 torchmd-train --conf torchmd-net/examples/ET-QM9.yaml.yaml --num-nodes 2 --log-dir output/ + +- ``NODE_RANK`` : Integer indicating the node index. Must be `0` for the main node and incremented by one for each additional node. +- ``MASTER_ADDR`` : Hostname or IP address of the main node. The same for all involved nodes. +- ``MASTER_PORT`` : A free network port for communication between nodes. PyTorch Lightning suggests port `12910` as a default. + +.. admonition:: Known Limitations + + - Due to the way PyTorch Lightning calculates the number of required DDP processes, all nodes must use the same number of GPUs. Otherwise training will not start or crash. + - We observe a 50x decrease in performance when mixing nodes with different GPU architectures (tested with RTX 2080 Ti and RTX 3090). + +Developer Guide +--------------- + +Code Style +========== + +We use `black `_. Please run ``black`` on your modified diff --git a/examples/ET-ANI1.yaml b/examples/ET-ANI1.yaml index 0270c48f5..dd6d82737 100644 --- a/examples/ET-ANI1.yaml +++ b/examples/ET-ANI1.yaml @@ -54,3 +54,5 @@ train_size: 0.8 trainable_rbf: false val_size: 0.05 weight_decay: 0.0 +charge: false +spin: false diff --git a/examples/ET-MD17.yaml b/examples/ET-MD17.yaml index 5c38b9eb2..1d39f568d 100644 --- a/examples/ET-MD17.yaml +++ b/examples/ET-MD17.yaml @@ -55,3 +55,5 @@ train_size: 950 trainable_rbf: false val_size: 50 weight_decay: 0.0 +charge: false +spin: false diff --git a/examples/ET-QM9.yaml b/examples/ET-QM9.yaml index 08b27a410..8edffdcbd 100644 --- a/examples/ET-QM9.yaml +++ b/examples/ET-QM9.yaml @@ -56,3 +56,6 @@ trainable_rbf: false val_size: 10000 weight_decay: 0.0 precision: 32 +charge: false +spin: false + diff --git a/examples/ET-SPICE.yaml b/examples/ET-SPICE.yaml index f2e5b189f..612823a37 100644 --- a/examples/ET-SPICE.yaml +++ b/examples/ET-SPICE.yaml @@ -55,3 +55,5 @@ train_size: 0.8 trainable_rbf: false val_size: 0.1 weight_decay: 0.0 +charge: false +spin: false diff --git a/examples/TensorNet-ANI1X.yaml b/examples/TensorNet-ANI1X.yaml index d39c32860..ef9de3421 100644 --- a/examples/TensorNet-ANI1X.yaml +++ b/examples/TensorNet-ANI1X.yaml @@ -53,3 +53,5 @@ train_size: 0.8 trainable_rbf: false val_size: 0.1 weight_decay: 0.0 +charge: false +spin: false diff --git a/examples/TensorNet-QM9.yaml b/examples/TensorNet-QM9.yaml index e6128dd24..67536925b 100644 --- a/examples/TensorNet-QM9.yaml +++ b/examples/TensorNet-QM9.yaml @@ -54,3 +54,5 @@ train_size: 110000 trainable_rbf: false val_size: 10000 weight_decay: 0.0 +charge: false +spin: false diff --git a/examples/TensorNet-SPICE.yaml b/examples/TensorNet-SPICE.yaml index 9aa825819..94d27fc4b 100644 --- a/examples/TensorNet-SPICE.yaml +++ b/examples/TensorNet-SPICE.yaml @@ -55,3 +55,5 @@ train_size: 0.8 trainable_rbf: false val_size: 0.1 weight_decay: 0.0 +charge: false +spin: false diff --git a/examples/TensorNet-rMD17.yaml b/examples/TensorNet-rMD17.yaml index 44afdd9b6..e2bf482fa 100644 --- a/examples/TensorNet-rMD17.yaml +++ b/examples/TensorNet-rMD17.yaml @@ -54,3 +54,5 @@ train_size: 950 trainable_rbf: false val_size: 50 weight_decay: 0.0 +charge: false +spin: false diff --git a/tests/test_model.py b/tests/test_model.py index ed3a409d4..1e6241b66 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -12,7 +12,7 @@ from utils import load_example_args, create_example_batch -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) @mark.parametrize("use_batch", [True, False]) @mark.parametrize("explicit_q_s", [True, False]) @mark.parametrize("precision", [32, 64]) @@ -27,7 +27,7 @@ def test_forward(model_name, use_batch, explicit_q_s, precision): model(z, pos, batch=batch) -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) @mark.parametrize("output_model", output_modules.__all__) @mark.parametrize("precision", [32,64]) def test_forward_output_modules(model_name, output_model, precision): @@ -37,7 +37,7 @@ def test_forward_output_modules(model_name, output_model, precision): model(z, pos, batch=batch) -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) @mark.parametrize("device", ["cpu", "cuda"]) def test_torchscript(model_name, device): if device == "cuda" and not torch.cuda.is_available(): @@ -57,7 +57,7 @@ def test_torchscript(model_name, device): grad_outputs=grad_outputs, )[0] -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) @mark.parametrize("device", ["cpu", "cuda"]) def test_torchscript_dynamic_shapes(model_name, device): if device == "cuda" and not torch.cuda.is_available(): @@ -124,7 +124,7 @@ def test_cuda_graph_compatible(model_name): assert torch.allclose(y, y2) assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5) -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) def test_seed(model_name): args = load_example_args(model_name, remove_prior=True) pl.seed_everything(1234) @@ -135,7 +135,7 @@ def test_seed(model_name): for p1, p2 in zip(m1.parameters(), m2.parameters()): assert (p1 == p2).all(), "Parameters don't match although using the same seed." -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) @mark.parametrize( "output_model", output_modules.__all__, @@ -188,7 +188,7 @@ def test_forward_output(model_name, output_model, overwrite_reference=False): ) -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) def test_gradients(model_name): pl.seed_everything(1234) precision = 64 diff --git a/tests/test_module.py b/tests/test_module.py index eba599cfb..4e2f4c762 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -12,7 +12,7 @@ from utils import load_example_args, DummyDataset -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) def test_create_model(model_name): LNNP(load_example_args(model_name), prior_model=Atomref(100)) @@ -22,7 +22,7 @@ def test_load_model(): load_model(checkpoint) -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) @mark.parametrize("use_atomref", [True, False]) @mark.parametrize("precision", [32, 64]) def test_train(model_name, use_atomref, precision, tmpdir): diff --git a/tests/test_priors.py b/tests/test_priors.py index 260391c82..beb52558b 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -12,7 +12,7 @@ import tempfile -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) def test_atomref(model_name): dataset = DummyDataset(has_atomref=True) atomref = Atomref(max_z=100, dataset=dataset) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 11fb5432b..e14a215c1 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -7,7 +7,7 @@ @mark.parametrize("remove_threshold", [-1, 2, 5]) -@mark.parametrize("model_name", models.__all__) +@mark.parametrize("model_name", models.__all_models__) def test_atom_filter(remove_threshold, model_name): # wrap a representation model using the AtomFilter wrapper model = create_model(load_example_args(model_name, remove_prior=True)) diff --git a/torchmdnet/calculators.py b/torchmdnet/calculators.py index f936dbe38..735eefa4f 100644 --- a/torchmdnet/calculators.py +++ b/torchmdnet/calculators.py @@ -19,8 +19,8 @@ class External: - """ - This is an adapter to use TorchMD-Net models in TorchMD. + """This is an adapter to use TorchMD-Net models in TorchMD. + Parameters ---------- netfile : str or torch.nn.Module diff --git a/torchmdnet/data.py b/torchmdnet/data.py index 19c812f89..8a47962f3 100644 --- a/torchmdnet/data.py +++ b/torchmdnet/data.py @@ -13,6 +13,9 @@ class FloatCastDatasetWrapper(Dataset): + """A wrapper around a torch_geometric dataset that casts all floating point + tensors to a given dtype. + """ def __init__(self, dataset, dtype=torch.float64): super(FloatCastDatasetWrapper, self).__init__( dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter @@ -40,6 +43,16 @@ def __getattr__(self, name): class DataModule(LightningDataModule): + """A LightningDataModule for loading datasets from the torchmdnet.datasets module. + + Args: + hparams (dict): A dictionary containing the hyperparameters of the + dataset. See the documentation of the torchmdnet.datasets module + for details. + dataset (torch_geometric.data.Dataset): A dataset to use instead of + loading a new one from the torchmdnet.datasets module. + """ + def __init__(self, hparams, dataset=None): super(DataModule, self).__init__() self.save_hyperparameters(hparams) @@ -104,16 +117,19 @@ def test_dataloader(self): @property def atomref(self): + """Returns the atomref of the dataset if it has one, otherwise None.""" if hasattr(self.dataset, "get_atomref"): return self.dataset.get_atomref() return None @property def mean(self): + """Returns the mean of the dataset if it has one, otherwise None.""" return self._mean @property def std(self): + """Returns the standard deviation of the dataset if it has one, otherwise None.""" return self._std def _is_test_during_training_epoch(self): diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index e2470dcb2..c2a2d1e84 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -8,6 +8,117 @@ class Ace(Dataset): + """The ACE dataset. + + This dataset is sourced from HDF5 files. + + Mandatory HDF5 file attributes: + + - `layout`: Must be set to `Ace`. + - `layout_version`: Can be `1.0` or `2.0`. + - `name`: Name of the dataset. + + For `layout_version` 1.0: + + - Files can contain multiple molecule groups directly under the root. + - Each molecule group contains: + + - `atomic_numbers`: Atomic numbers of the atoms. + - `formal_charges`: Formal charges of the atoms. The sum is the molecule's total charge. Units: electron charges. + - `conformations` subgroup: This subgroup has individual conformation groups, each with datasets for different properties of the conformation. + + For `layout_version` 2.0: + + - Files contain a single root group (e.g., a 'master molecule group'). + - Within this root group, there can be multiple molecule groups. + - Each molecule group contains: + + - `atomic_numbers`: Atomic numbers of the atoms. + - `formal_charges`: Formal charges of the atoms. + - Datasets for multiple conformations directly, without individual conformation groups. + + + Each conformation group (version 1.0) or molecule group (version 2.0) should have the following datasets: + + - `positions`: Atomic positions. Units: Angstrom. + - `forces`: Forces on the atoms. Units: eV/Å. + - `partial_charges`: Atomic partial charges. Units: electron charges. + - `dipole_moment` (version 1.0) or `dipole_moments` (version 2.0): Dipole moment (a vector of three components). Units: e*Å. + - `formation_energy` (version 1.0) or `formation_energies` (version 2.0): Formation energy. Units: eV. + + Each dataset should also have an `units` attribute specifying its units (i.e., `Å`, `eV`, `e*Å`). + + Note that version 2.0 is more efficient than 1.0. + + Args: + root (string, optional): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. + pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. + pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. + paths (string or list): Path to the HDF5 files or directory containing the HDF5 files. + max_gradient (float, optional): Maximum gradient norm. Samples with larger gradients are discarded. + subsample_molecules (int, optional): Subsample molecules. Only every `subsample_molecules` molecule is used. + + Examples:: + >>> import numpy as np + >>> from torchmdnet.datasets import Ace + >>> import h5py + >>> # Version 1.0 example + >>> with h5py.File("molecule.h5", 'w') as f: + ... f.attrs["layout"] = "Ace" + ... f.attrs["layout_version"] = "1.0" + ... f.attrs["name"] = "sample_molecule_data" + ... for m in range(3): # Three molecules + ... mol = f.create_group(f"mol_{m+1}") + ... mol["atomic_numbers"] = [1, 6, 8] # H, C, O + ... mol["formal_charges"] = [0, 0, 0] # Neutral charges + ... confs = mol.create_group("conformations") + ... for i in range(2): # Two conformations + ... conf = confs.create_group(f"conf_{i+1}") + ... conf["positions"] = np.random.random((3, 3)) + ... conf["positions"].attrs["units"] = "Å" + ... conf["formation_energy"] = np.random.random() + ... conf["formation_energy"].attrs["units"] = "eV" + ... conf["forces"] = np.random.random((3, 3)) + ... conf["forces"].attrs["units"] = "eV/Å" + ... conf["partial_charges"] = np.random.random(3) + ... conf["partial_charges"].attrs["units"] = "e" + ... conf["dipole_moment"] = np.random.random(3) + ... conf["dipole_moment"].attrs["units"] = "e*Å" + >>> dataset = Ace(root=".", paths="molecule.h5") + >>> len(dataset) + 6 + >>> dataset = Ace(root=".", paths=["molecule.h5", "molecule.h5"]) + >>> len(dataset) + 12 + >>> # Version 2.0 example + >>> with h5py.File("molecule_v2.h5", 'w') as f: + ... f.attrs["layout"] = "Ace" + ... f.attrs["layout_version"] = "2.0" + ... f.attrs["name"] = "sample_molecule_data_v2" + ... master_mol_group = f.create_group("master_molecule_group") + ... for m in range(3): # Three molecules + ... mol = master_mol_group.create_group(f"mol_{m+1}") + ... mol["atomic_numbers"] = [1, 6, 8] # H, C, O + ... mol["formal_charges"] = [0, 0, 0] # Neutral charges + ... mol["positions"] = np.random.random((2, 3, 3)) # Two conformations + ... mol["positions"].attrs["units"] = "Å" + ... mol["formation_energies"] = np.random.random(2) + ... mol["formation_energies"].attrs["units"] = "eV" + ... mol["forces"] = np.random.random((2, 3, 3)) + ... mol["forces"].attrs["units"] = "eV/Å" + ... mol["partial_charges"] = np.random.random((2, 3)) + ... mol["partial_charges"].attrs["units"] = "e" + ... mol["dipole_moment"] = np.random.random((2, 3)) + ... mol["dipole_moment"].attrs["units"] = "e*Å" + >>> dataset_v2 = Ace(root=".", paths="molecule_v2.h5") + >>> len(dataset_v2) + 6 + >>> dataset_v2 = Ace(root=".", paths=["molecule_v2.h5", "molecule_v2.h5"]) + >>> len(dataset_v2) + 12 + """ + def __init__( self, root=None, @@ -305,7 +416,24 @@ def len(self): return len(self.y_mm) def get(self, idx): + """Gets the data object at index :obj:`idx`. + + The data object contains the following attributes: + + - :obj:`z`: Atomic numbers of the atoms. + - :obj:`pos`: Positions of the atoms. + - :obj:`y`: Formation energy of the molecule. + - :obj:`neg_dy`: Forces on the atoms. + - :obj:`q`: Total charge of the molecule. + - :obj:`pq`: Partial charges of the atoms. + - :obj:`dp`: Dipole moment of the molecule. + + Args: + idx (int): Index of the data object. + Returns: + :obj:`torch_geometric.data.Data`: The data object. + """ atoms = slice(self.idx_mm[idx], self.idx_mm[idx + 1]) z = pt.tensor(self.z_mm[atoms], dtype=pt.long) pos = pt.tensor(self.pos_mm[atoms], dtype=pt.float32) diff --git a/torchmdnet/datasets/ani.py b/torchmdnet/datasets/ani.py index febfa7168..ed62dd327 100644 --- a/torchmdnet/datasets/ani.py +++ b/torchmdnet/datasets/ani.py @@ -8,8 +8,29 @@ class ANIBase(Dataset): + """ANI Dataset Classes - HARTREE_TO_EV = 27.211386246 + A foundational dataset class for handling the ANI datasets. ANI (ANAKIN-ME or Accurate NeurAl networK engINe for Molecular Energies) + is a deep learning method trained on quantum mechanical DFT calculations to predict accurate and transferable potentials for organic molecules. + + Key features of ANI: + + - Utilizes a modified version of the Behler and Parrinello symmetry functions to construct single-atom atomic environment vectors (AEV) for molecular representation. + - AEVs enable the training of neural networks over both configurational and conformational space. + - The ANI-1 potential was trained on a subset of the GDB databases with up to 8 heavy atoms. + - ANI-1x and ANI-1ccx datasets provide diverse quantum mechanical properties for organic molecules: + - ANI-1x contains multiple QM properties from 5M density functional theory calculations. + - ANI-1ccx contains 500k data points obtained with an accurate CCSD(T)/CBS extrapolation. + - Properties include energies, atomic forces, multipole moments, atomic charges, and more for the chemical elements C, H, N, and O. + - Developed through active learning, an automated data diversification process. + + References: + + - Smith, J. S., Isayev, O., & Roitberg, A. E. (2017). ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost. Chemical Science, 8(4), 3192-3203. + - Smith, J. S., Zubatyuk, R., Nebgen, B., Lubbers, N., Barros, K., Roitberg, A. E., Isayev, O., & Tretiak, S. (2020). The ANI-1ccx and ANI-1x data sets, coupled-cluster and density functional theory properties for molecules. Scientific Data, 7, Article 134. + """ + + HARTREE_TO_EV = 27.211386246 #::meta private: @property def raw_url(self): @@ -21,7 +42,7 @@ def raw_file_names(self): def compute_reference_energy(self, atomic_numbers): atomic_numbers = np.array(atomic_numbers) - energy = sum(self.ELEMENT_ENERGIES[z] for z in atomic_numbers) + energy = sum(self._ELEMENT_ENERGIES[z] for z in atomic_numbers) return energy * ANIBase.HARTREE_TO_EV def sample_iter(self, mol_ids=False): @@ -148,7 +169,21 @@ def len(self): return len(self.y_mm) def get(self, idx): + """Get a single sample from the dataset. + + Data object contains the following attributes by default: + + - :obj:`z` (:class:`torch.LongTensor`): Atomic numbers of shape :obj:`[num_nodes]`. + - :obj:`pos` (:class:`torch.FloatTensor`): Atomic positions of shape :obj:`[num_nodes, 3]`. + - :obj:`y` (:class:`torch.FloatTensor`): Energies of shape :obj:`[1, 1]`. + - :obj:`neg_dy` (:class:`torch.FloatTensor`, *optional*): Negative gradients of shape :obj:`[num_nodes, 3]`. + Args: + idx (int): Index of the sample. + + Returns: + :class:`torch_geometric.data.Data`: The data object. + """ atoms = slice(self.idx_mm[idx], self.idx_mm[idx + 1]) z = pt.tensor(self.z_mm[atoms], dtype=pt.long) pos = pt.tensor(self.pos_mm[atoms], dtype=pt.float32) @@ -165,12 +200,14 @@ def get(self, idx): class ANI1(ANIBase): + __doc__ = ANIBase.__doc__ + # Avoid sphinx from documenting this ELEMENT_ENERGIES = { 1: -0.500607632585, 6: -37.8302333826, 7: -54.5680045287, 8: -75.0362229210, - } + } #::meta private: @property def raw_url(self): @@ -219,7 +256,8 @@ def sample_iter(self, mol_ids=False): yield data def get_atomref(self, max_z=100): - + """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. + """ refs = pt.zeros(max_z) refs[1] = -0.500607632585 * self.HARTREE_TO_EV # H refs[6] = -37.8302333826 * self.HARTREE_TO_EV # C @@ -235,6 +273,7 @@ def process(self): class ANI1XBase(ANIBase): + @property def raw_url(self): return "https://figshare.com/ndownloader/files/18112775" @@ -249,7 +288,8 @@ def download(self): os.rename(file, self.raw_paths[0]) def get_atomref(self, max_z=100): - + """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. + """ warnings.warn("Atomic references from the ANI-1 dataset are used!") refs = pt.zeros(max_z) @@ -262,13 +302,17 @@ def get_atomref(self, max_z=100): class ANI1X(ANI1XBase): + __doc__ = ANIBase.__doc__ ELEMENT_ENERGIES = { 1: -0.500607632585, 6: -37.8302333826, 7: -54.5680045287, 8: -75.0362229210, } + """ + :meta private: + """ def sample_iter(self, mol_ids=False): assert len(self.raw_paths) == 1 @@ -319,7 +363,7 @@ def process(self): class ANI1CCX(ANI1XBase): - + __doc__ = ANIBase.__doc__ def sample_iter(self, mol_ids=False): assert len(self.raw_paths) == 1 diff --git a/torchmdnet/datasets/qm9q.py b/torchmdnet/datasets/qm9q.py index 431964b69..b4d67c20b 100644 --- a/torchmdnet/datasets/qm9q.py +++ b/torchmdnet/datasets/qm9q.py @@ -8,9 +8,10 @@ class QM9q(Dataset): - HARTREE_TO_EV = 27.211386246 - BORH_TO_ANGSTROM = 0.529177 - DEBYE_TO_EANG = 0.2081943 # Debey -> e*A + + HARTREE_TO_EV = 27.211386246#::meta private: + BORH_TO_ANGSTROM = 0.529177 #::meta private: + DEBYE_TO_EANG = 0.2081943 #::meta private: Debey -> e*A # Ion energies of elements ELEMENT_ENERGIES = { @@ -19,13 +20,13 @@ class QM9q(Dataset): 7: {-1: -54.4626446440, 0: -54.5269367415, 1: -53.9895574739}, 8: {-1: -74.9699154500, 0: -74.9812632126, 1: -74.4776884006}, 9: {-1: -99.6695561536, 0: -99.6185158728}, - } + } #::meta private: # Select an ion with the lowest energy for each element INITIAL_CHARGES = { element: sorted(zip(charges.values(), charges.keys()))[0][1] for element, charges in ELEMENT_ENERGIES.items() - } + } #::meta private: def __init__( self, diff --git a/torchmdnet/extensions/__init__.py b/torchmdnet/extensions/__init__.py index 8b3c8fa1e..2ba936c1b 100644 --- a/torchmdnet/extensions/__init__.py +++ b/torchmdnet/extensions/__init__.py @@ -4,38 +4,112 @@ import os.path as osp import torch import importlib.machinery +from typing import Tuple + def _load_library(library): - """ Load a dynamic library containing torch extensions from the given path. + """Load a dynamic library containing torch extensions from the given path. Args: library (str): The name of the library to load. """ # Find the specification for the library - spec = importlib.machinery.PathFinder().find_spec( - library, [osp.dirname(__file__)] - ) + spec = importlib.machinery.PathFinder().find_spec(library, [osp.dirname(__file__)]) # Check if the specification is found and load the library if spec is not None: torch.ops.load_library(spec.origin) else: - raise ImportError(f"Could not find module '{library}' in {osp.dirname(__file__)}") + raise ImportError( + f"Could not find module '{library}' in {osp.dirname(__file__)}" + ) + _load_library("torchmdnet_extensions") -@torch.jit.script + def is_current_stream_capturing(): - """ - Returns True if the current CUDA stream is capturing. + """Returns True if the current CUDA stream is capturing. + Returns False if CUDA is not available or the current stream is not capturing. This utility is required because the builtin torch function that does this is not scriptable. """ - _is_current_stream_capturing = torch.ops.torchmdnet_extensions.is_current_stream_capturing + _is_current_stream_capturing = ( + torch.ops.torchmdnet_extensions.is_current_stream_capturing + ) return _is_current_stream_capturing() -get_neighbor_pairs_kernel = torch.ops.torchmdnet_extensions.get_neighbor_pairs + +def get_neighbor_pairs_kernel( + strategy: str, + positions: torch.Tensor, + batch: torch.Tensor, + box_vectors: torch.Tensor, + use_periodic: bool, + cutoff_lower: float, + cutoff_upper: float, + max_num_pairs: int, + loop: bool, + include_transpose: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes the neighbor pairs for a given set of atomic positions. + + The list is generated as a list of pairs (i,j) without any enforced ordering. + The list is padded with -1 to the maximum number of pairs. + + Parameters + ---------- + strategy : str + Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`. + positions : torch.Tensor + A tensor with shape (N, 3) representing the atomic positions. + batch : torch.Tensor + A tensor with shape (N,). Specifies the batch for each atom. + box_vectors : torch.Tensor + The vectors defining the periodic box with shape `(3, 3)`. + use_periodic : bool + Whether to apply periodic boundary conditions. + cutoff_lower : float + Lower cutoff for the neighbor list. + cutoff_upper : float + Upper cutoff for the neighbor list. + max_num_pairs : int + Maximum number of pairs to store. + loop : bool + Whether to include self-interactions. + include_transpose : bool + Whether to include the transpose of the neighbor list (pair i,j and pair j,i). + + Returns + ------- + neighbors : torch.Tensor + List of neighbors for each atom. Shape (2, max_num_pairs). + distances : torch.Tensor + List of distances for each atom. Shape (max_num_pairs,). + distance_vecs : torch.Tensor + List of distance vectors for each atom. Shape (max_num_pairs, 3). + num_pairs : torch.Tensor + The number of pairs found. + + Notes + ----- + This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`. + """ + return torch.ops.torchmdnet_extensions.get_neighbor_pairs( + strategy, + positions, + batch, + box_vectors, + use_periodic, + cutoff_lower, + cutoff_upper, + max_num_pairs, + loop, + include_transpose, + ) + # For some unknown reason torch.compile is not able to compile this function -if int(torch.__version__.split('.')[0]) >= 2: +if int(torch.__version__.split(".")[0]) >= 2: import torch._dynamo as dynamo + dynamo.disallow_in_graph(get_neighbor_pairs_kernel) diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda.cu b/torchmdnet/extensions/neighbors/neighbors_cuda.cu index 5d892e4b0..3391ade72 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda.cu +++ b/torchmdnet/extensions/neighbors/neighbors_cuda.cu @@ -1,6 +1,5 @@ /* Raul P. Pelaez 2023 Connection between the neighbor CUDA implementations and the torch extension. - See neighbors.cpp for the definition of the torch extension functions. */ #include "neighbors_cuda_brute.cuh" #include "neighbors_cuda_cell.cuh" diff --git a/torchmdnet/models/__init__.py b/torchmdnet/models/__init__.py index 82923d1d8..5f3e44053 100644 --- a/torchmdnet/models/__init__.py +++ b/torchmdnet/models/__init__.py @@ -1 +1 @@ -__all__ = ["graph-network", "transformer", "equivariant-transformer", "tensornet"] +__all_models__ = ["graph-network", "transformer", "equivariant-transformer", "tensornet"] diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 7ee44e361..f0d127b18 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -13,15 +13,16 @@ def create_model(args, prior_model=None, mean=None, std=None): """Create a model from the given arguments. - See :func:`get_args` in scripts/train.py for a description of the arguments. - Parameters - ---------- + + Run `torchmd-train --help` for a description of the arguments. + + Args: args (dict): Arguments for the model. prior_model (nn.Module, optional): Prior model to use. Defaults to None. mean (torch.Tensor, optional): Mean of the training data. Defaults to None. std (torch.Tensor, optional): Standard deviation of the training data. Defaults to None. - Returns - ------- + + Returns: nn.Module: An instance of the TorchMD_Net model. """ dtype = dtype_mapping[args["precision"]] @@ -36,7 +37,7 @@ def create_model(args, prior_model=None, mean=None, std=None): cutoff_upper=args["cutoff_upper"], max_z=args["max_z"], max_num_neighbors=args["max_num_neighbors"], - dtype=dtype + dtype=dtype, ) # representation network @@ -48,7 +49,7 @@ def create_model(args, prior_model=None, mean=None, std=None): num_filters=args["embedding_dimension"], aggr=args["aggr"], neighbor_embedding=args["neighbor_embedding"], - **shared_args + **shared_args, ) elif args["model"] == "transformer": from torchmdnet.models.torchmd_t import TorchMD_T @@ -75,10 +76,10 @@ def create_model(args, prior_model=None, mean=None, std=None): elif args["model"] == "tensornet": from torchmdnet.models.tensornet import TensorNet - # Setting is_equivariant to False to enforce the use of Scalar output module instead of EquivariantScalar + # Setting is_equivariant to False to enforce the use of Scalar output module instead of EquivariantScalar is_equivariant = False representation_model = TensorNet( - equivariance_invariance_group=args["equivariance_invariance_group"], + equivariance_invariance_group=args["equivariance_invariance_group"], **shared_args, ) else: @@ -118,6 +119,18 @@ def create_model(args, prior_model=None, mean=None, std=None): def load_model(filepath, args=None, device="cpu", **kwargs): + """Load a model from a checkpoint file. + + Args: + filepath (str): Path to the checkpoint file. + args (dict, optional): Arguments for the model. Defaults to None. + device (str, optional): Device on which the model should be loaded. Defaults to "cpu". + **kwargs: Extra keyword arguments for the model. + + Returns: + nn.Module: An instance of the TorchMD_Net model. + """ + ckpt = torch.load(filepath, map_location="cpu") if args is None: args = ckpt["hyper_parameters"] @@ -132,12 +145,16 @@ def load_model(filepath, args=None, device="cpu", **kwargs): state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()} # The following are for backward compatibility with models created when atomref was # the only supported prior. - if 'prior_model.initial_atomref' in state_dict: - state_dict['prior_model.0.initial_atomref'] = state_dict['prior_model.initial_atomref'] - del state_dict['prior_model.initial_atomref'] - if 'prior_model.atomref.weight' in state_dict: - state_dict['prior_model.0.atomref.weight'] = state_dict['prior_model.atomref.weight'] - del state_dict['prior_model.atomref.weight'] + if "prior_model.initial_atomref" in state_dict: + state_dict["prior_model.0.initial_atomref"] = state_dict[ + "prior_model.initial_atomref" + ] + del state_dict["prior_model.initial_atomref"] + if "prior_model.atomref.weight" in state_dict: + state_dict["prior_model.0.atomref.weight"] = state_dict[ + "prior_model.atomref.weight" + ] + del state_dict["prior_model.atomref.weight"] model.load_state_dict(state_dict) return model.to(device) @@ -145,8 +162,8 @@ def load_model(filepath, args=None, device="cpu", **kwargs): def create_prior_models(args, dataset=None): """Parse the prior_model configuration option and create the prior models.""" prior_models = [] - if args['prior_model']: - prior_model = args['prior_model'] + if args["prior_model"]: + prior_model = args["prior_model"] prior_names = [] prior_args = [] if not isinstance(prior_model, list): @@ -162,8 +179,8 @@ def create_prior_models(args, dataset=None): else: prior_names.append(prior) prior_args.append({}) - if 'prior_args' in args: - prior_args = args['prior_args'] + if "prior_args" in args: + prior_args = args["prior_args"] if not isinstance(prior_args, list): prior_args = [prior_args] for name, arg in zip(prior_names, prior_args): @@ -177,15 +194,38 @@ def create_prior_models(args, dataset=None): class TorchMD_Net(nn.Module): - """The TorchMD_Net class combines a given representation model - (such as the equivariant transformer), an output model (such as - the scalar output module) and a prior model (such as the atomref - prior), producing a Module that takes as input a series of atoms - features and outputs a scalar value (i.e energy for each - batch/molecule) and, derivative is True, the negative of its derivative - with respect to the positions (i.e forces for each atom). + """ The main TorchMD-Net model. + + The TorchMD_Net class combines a given representation model (such as the equivariant transformer), + an output model (such as the scalar output module), and a prior model (such as the atomref prior). + It produces a Module that takes as input a series of atom features and outputs a scalar value + (i.e., energy for each batch/molecule). If `derivative` is True, it also outputs the negative of + its derivative with respect to the positions (i.e., forces for each atom). + + Parameters + ---------- + representation_model : nn.Module + A model that takes as input the atomic numbers, positions, batch indices, and optionally + charges and spins. It must return a tuple of the form (x, v, z, pos, batch), where x + are the atom features, v are the vector features (if any), z are the atomic numbers, + pos are the positions, and batch are the batch indices. See TorchMD_ET for more details. + output_model : nn.Module + A model that takes as input the atom features, vector features (if any), atomic numbers, + positions, and batch indices. See OutputModel for more details. + prior_model : nn.Module, optional + A model that takes as input the atom features, atomic numbers, positions, and batch + indices. See BasePrior for more details. Defaults to None. + mean : torch.Tensor, optional + Mean of the training data. Defaults to None. + std : torch.Tensor, optional + Standard deviation of the training data. Defaults to None. + derivative : bool, optional + Whether to compute the derivative of the outputs via backpropagation. Defaults to False. + dtype : torch.dtype, optional + Data type of the model. Defaults to torch.float32. """ + def __init__( self, representation_model, @@ -210,7 +250,11 @@ def __init__( ) if isinstance(prior_model, priors.base.BasePrior): prior_model = [prior_model] - self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model).to(dtype=dtype) + self.prior_model = ( + None + if prior_model is None + else torch.nn.ModuleList(prior_model).to(dtype=dtype) + ) self.derivative = derivative @@ -235,18 +279,22 @@ def forward( batch: Optional[Tensor] = None, q: Optional[Tensor] = None, s: Optional[Tensor] = None, - extra_args: Optional[Dict[str, Tensor]] = None + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor]]: - """Compute the output of the model. - Args: - z (Tensor): Atomic numbers of the atoms in the molecule. Shape (N,). - pos (Tensor): Atomic positions in the molecule. Shape (N, 3). - batch (Tensor, optional): Batch indices for the atoms in the molecule. Shape (N,). - q (Tensor, optional): Atomic charges in the molecule. Shape (N,). - s (Tensor, optional): Atomic spins in the molecule. Shape (N,). - extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model. """ + Compute the output of the model. + Args: + z (Tensor): Atomic numbers of the atoms in the molecule. Shape: (N,). + pos (Tensor): Atomic positions in the molecule. Shape: (N, 3). + batch (Tensor, optional): Batch indices for the atoms in the molecule. Shape: (N,). + q (Tensor, optional): Atomic charges in the molecule. Shape: (N,). + s (Tensor, optional): Atomic spins in the molecule. Shape: (N,). + extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model. + + Returns: + Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise. + """ assert z.dim() == 1 and z.dtype == torch.long batch = torch.zeros_like(z) if batch is None else batch diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py index 2283ef279..b906fca1b 100644 --- a/torchmdnet/models/output_modules.py +++ b/torchmdnet/models/output_modules.py @@ -11,6 +11,11 @@ class OutputModel(nn.Module, metaclass=ABCMeta): + """Base class for output models. + + Derive this class to make custom output models. + As an example, have a look at the :py:mod:`torchmdnet.output_modules.Scalar` output model. + """ def __init__(self, allow_prior_model, reduce_op): super(OutputModel, self).__init__() self.allow_prior_model = allow_prior_model diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index b277a8186..2e98bcc56 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -8,10 +8,12 @@ act_class_mapping, ) +__all__ = ["TensorNet"] torch.set_float32_matmul_precision("high") torch.backends.cuda.matmul.allow_tf32 = True -# Creates a skew-symmetric tensor from a vector + def vector_to_skewtensor(vector): + """Creates a skew-symmetric tensor from a vector.""" batch_size = vector.size(0) zero = torch.zeros(batch_size, device=vector.device, dtype=vector.dtype) tensor = torch.stack( @@ -31,9 +33,8 @@ def vector_to_skewtensor(vector): tensor = tensor.view(-1, 3, 3) return tensor.squeeze(0) - -# Creates a symmetric traceless tensor from the outer product of a vector with itself def vector_to_symtensor(vector): + """Creates a symmetric traceless tensor from the outer product of a vector with itself.""" tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2)) I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[ ..., None, None @@ -41,9 +42,8 @@ def vector_to_symtensor(vector): S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I return S - -# Full tensor decomposition into irreducible components def decompose_tensor(tensor): + """Full tensor decomposition into irreducible components.""" I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[ ..., None, None ] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype) @@ -51,15 +51,13 @@ def decompose_tensor(tensor): S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I return I, A, S - -# Computes Frobenius norm def tensor_norm(tensor): + """Computes Frobenius norm.""" return (tensor**2).sum((-2, -1)) - class TensorNet(nn.Module): - r"""TensorNet's architecture, from TensorNet: Cartesian Tensor Representations - for Efficient Learning of Molecular Potentials; G. Simeon and G. de Fabritiis. + r"""TensorNet's architecture. + From TensorNet: Cartesian Tensor Representations for Efficient Learning of Molecular Potentials; G. Simeon and G. de Fabritiis. Args: hidden_channels (int, optional): Hidden embedding size. @@ -236,6 +234,10 @@ def forward( class TensorEmbedding(nn.Module): + """Tensor embedding layer. + + :meta private: + """ def __init__( self, hidden_channels, @@ -352,9 +354,8 @@ def forward( return X -def tensor_message_passing( - edge_index: Tensor, factor: Tensor, tensor: Tensor, natoms: int -) -> Tensor: +def tensor_message_passing(edge_index: Tensor, factor: Tensor, tensor: Tensor, natoms: int) -> Tensor: + """Message passing for tensors.""" msg = factor * tensor.index_select(0, edge_index[1]) shape = (natoms, tensor.shape[1], tensor.shape[2], tensor.shape[3]) tensor_m = torch.zeros(*shape, device=tensor.device, dtype=tensor.dtype) @@ -363,6 +364,10 @@ def tensor_message_passing( class Interaction(nn.Module): + """Interaction layer. + + :meta private: + """ def __init__( self, num_rbf, diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 8abca54ae..7dc2a5b7c 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -205,6 +205,10 @@ def __repr__(self): class EquivariantMultiHeadAttention(nn.Module): + """Equivariant multi-head attention layer. + + :meta private: + """ def __init__( self, hidden_channels, diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 5d8c5a75d..216ff563b 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -1,7 +1,6 @@ from typing import Optional, Tuple import torch from torch import Tensor, nn -from torch_geometric.nn import MessagePassing from torchmdnet.models.utils import ( NeighborEmbedding, CosineCutoff, @@ -11,7 +10,6 @@ scatter, ) - class TorchMD_GN(nn.Module): r"""The TorchMD Graph Network architecture. Code adapted from https://github.com/rusty1s/pytorch_geometric/blob/d7d8e5e2edada182d820bbb1eec5f016f50db1e0/torch_geometric/nn/models/schnet.py#L38 @@ -193,6 +191,10 @@ def __repr__(self): class InteractionBlock(nn.Module): + """Interaction block for the TorchMD Graph Network architecture. + + :meta private: + """ def __init__( self, hidden_channels, @@ -249,6 +251,10 @@ def forward( class CFConv(nn.Module): + """Continuous-filter convolution layer for the TorchMD Graph Network architecture. + + :meta private: + """ def __init__( self, in_channels, diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 71d4d8f18..567482b66 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -11,6 +11,7 @@ ) + class TorchMD_T(nn.Module): r"""The TorchMD Transformer architecture. @@ -188,6 +189,11 @@ def __repr__(self): class MultiHeadAttention(nn.Module): + """Multi-head attention layer. + + :meta private: + """ + def __init__( self, hidden_channels, diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index bbbea1bf2..35700467a 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -114,34 +114,24 @@ def forward( class OptimizedDistance(torch.nn.Module): - def __init__( - self, - cutoff_lower=0.0, - cutoff_upper=5.0, - max_num_pairs=-32, - return_vecs=False, - loop=False, - strategy="brute", - include_transpose=True, - resize_to_fit=True, - check_errors=True, - box=None, - long_edge_index=True, - ): - super(OptimizedDistance, self).__init__() - """ Compute the neighbor list for a given cutoff. + """ Compute the neighbor list for a given cutoff. + This operation can be placed inside a CUDA graph in some cases. In particular, resize_to_fit and check_errors must be False. - Note that this module returns neighbors such that distance(i,j) >= cutoff_lower and distance(i,j) < cutoff_upper. + + Note that this module returns neighbors such that :math:`r_{ij} \\ge \\text{cutoff_lower}\\quad\\text{and}\\quad r_{ij} < \\text{cutoff_upper}`. + This function optionally supports periodic boundary conditions with arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy certain requirements: - `a[1] = a[2] = b[2] = 0` - `a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff` - `a[0] >= 2*b[0]` - `a[0] >= 2*c[0]` - `b[1] >= 2*c[1]` + .. code:: python + + a[1] = a[2] = b[2] = 0 + a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff + a[0] >= 2*b[0] + a[0] >= 2*c[0] + b[1] >= 2*c[1] These requirements correspond to a particular rotation of the system and reduced form of the vectors, as well as the requirement that the cutoff be @@ -158,11 +148,11 @@ def __init__( If the number of pairs found is larger than this, the pairs are randomly sampled. When check_errors is True, an exception is raised in this case. If negative, it is interpreted as (minus) the maximum number of neighbors per atom. strategy : str - Strategy to use for computing the neighbor list. Can be one of - ["shared", "brute", "cell"]. - Shared: An O(N^2) algorithm that leverages CUDA shared memory, best for large number of particles. - Brute: A brute force O(N^2) algorithm, best for small number of particles. - Cell: A cell list algorithm, best for large number of particles, low cutoffs and low batch size. + Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`. + + 1. *Shared*: An O(N^2) algorithm that leverages CUDA shared memory, best for large number of particles. + 2. *Brute*: A brute force O(N^2) algorithm, best for small number of particles. + 3. *Cell*: A cell list algorithm, best for large number of particles, low cutoffs and low batch size. box : torch.Tensor, optional The vectors defining the periodic box. This must have shape `(3, 3)`, where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. @@ -187,6 +177,21 @@ def __init__( Whether to return edge_index as int64, otherwise int32. Default: True """ + def __init__( + self, + cutoff_lower=0.0, + cutoff_upper=5.0, + max_num_pairs=-32, + return_vecs=False, + loop=False, + strategy="brute", + include_transpose=True, + resize_to_fit=True, + check_errors=True, + box=None, + long_edge_index=True + ): + super(OptimizedDistance, self).__init__() self.cutoff_upper = cutoff_upper self.cutoff_lower = cutoff_lower self.max_num_pairs = max_num_pairs @@ -211,28 +216,32 @@ def __init__( def forward( self, pos: Tensor, batch: Optional[Tensor] = None ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: - """Compute the neighbor list for a given cutoff. + """ + Compute the neighbor list for a given cutoff. + Parameters ---------- pos : torch.Tensor - shape (N, 3) - batch : torch.Tensor or None - shape (N,) + A tensor with shape (N, 3) representing the positions. + batch : torch.Tensor, optional + A tensor with shape (N,). Defaults to None. + Returns ------- edge_index : torch.Tensor - List of neighbors for each atom in the batch. - shape (2, num_found_pairs or max_num_pairs) + List of neighbors for each atom in the batch. + Shape is (2, num_found_pairs) or (2, max_num_pairs). edge_weight : torch.Tensor List of distances for each atom in the batch. - shape (num_found_pairs or max_num_pairs,) - edge_vec : torch.Tensor + Shape is (num_found_pairs,) or (max_num_pairs,). + edge_vec : torch.Tensor, optional List of distance vectors for each atom in the batch. - shape (num_found_pairs or max_num_pairs, 3) - - If resize_to_fit is True, the tensors will be trimmed to the actual number of pairs found. - otherwise the tensors will have size max_num_pairs, with neighbor pairs (-1, -1) at the end. + Shape is (num_found_pairs, 3) or (max_num_pairs, 3). + Notes + ----- + If `resize_to_fit` is True, the tensors will be trimmed to the actual number of pairs found. + Otherwise, the tensors will have size `max_num_pairs`, with neighbor pairs (-1, -1) at the end. """ self.box = self.box.to(pos.dtype) max_pairs: int = self.max_num_pairs diff --git a/torchmdnet/priors/atomref.py b/torchmdnet/priors/atomref.py index 22e31f67c..c714d102f 100644 --- a/torchmdnet/priors/atomref.py +++ b/torchmdnet/priors/atomref.py @@ -39,4 +39,11 @@ def get_init_args(self): return dict(max_z=self.initial_atomref.size(0)) def pre_reduce(self, x: Tensor, z: Tensor, pos: Tensor, batch: Tensor, extra_args: Optional[Dict[str, Tensor]]): + """Adds the stored atomref to the input as: + + .. math:: + + x' = x + \\textrm{atomref}(z) + + """ return x + self.atomref(z) diff --git a/torchmdnet/priors/coulomb.py b/torchmdnet/priors/coulomb.py index faa47717f..ab2f618ad 100644 --- a/torchmdnet/priors/coulomb.py +++ b/torchmdnet/priors/coulomb.py @@ -4,14 +4,26 @@ from typing import Optional, Dict class Coulomb(BasePrior): - """This class implements a Coulomb potential, scaled by erf(alpha*r) to reduce its + """This class implements a Coulomb potential, scaled by :math:`\\textrm{erf}(\\textrm{alpha}*r)` to reduce its effect at short distances. - To use this prior, the Dataset must include a field called `partial_charges` with each sample, - containing the partial charge for each atom. It also must provide the following attributes. + Parameters + ---------- + alpha : float + Scaling factor for the error function. + max_num_neighbors : int + Maximum number of neighbors per atom allowed. + distance_scale : float, optional + Factor to multiply with coordinates in the dataset to convert them to meters. + energy_scale : float, optional + Factor to multiply with energies in the dataset to convert them to Joules (*not* J/mol). + dataset : Dataset + Dataset object. - distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters - energy_scale: multiply by this factor to convert energies stored in the dataset to Joules (*not* J/mol) + Notes + ----- + The Dataset used with this class must include a `partial_charges` field for each sample, and provide + `distance_scale` and `energy_scale` attributes if they are not explicitly passed as arguments. """ def __init__(self, alpha, max_num_neighbors, distance_scale=None, energy_scale=None, dataset=None): super(Coulomb, self).__init__() diff --git a/torchmdnet/priors/d2.py b/torchmdnet/priors/d2.py index e1aa563fe..7585e7523 100644 --- a/torchmdnet/priors/d2.py +++ b/torchmdnet/priors/d2.py @@ -4,35 +4,31 @@ class D2(BasePrior): - """Dispersive correction term as used in DFT-D2 - - Reference - --------- - Grimme, Stefan. "Semiempirical GGA-type density functional constructed with a long‐range dispersion correction." Journal of computational chemistry 27.15 (2006): 1787-1799. - https://onlinelibrary.wiley.com/doi/10.1002/jcc.20495 - - Arguments - --------- - cutoff_distance: float - Distance cutoff for the correction term - max_num_neighbors: int - Maximum number of neighbors - atomic_number: list of ints or None - Map of atom types to atomic numbers. - If `atomic_numbers=None`, use `dataset.atomic_numbers` - position_scale: float or None - Multiply by this factor to convert positions stored in the dataset to meters (m). - If `position_scale=None` (default), use `dataset.position_scale` - energy_scale: float or None - Multiply by this factor to convert energies stored in the dataset to Joules (J). - Note: *not* J/mol. - If `energy_scale=None` (default), use `dataset.energy_scale` - dataset: Dataset or None - Dataset object. - If `dataset=None`; `atomic_number`, `position_scale`, and `energy_scale` have to be set. - - Example - ------- + """ + Dispersive correction term as used in DFT-D2. + + Reference: + Grimme, Stefan. "Semiempirical GGA-type density functional constructed with a long‐range dispersion correction." + Journal of computational chemistry 27.15 (2006): 1787-1799. + Available at: https://onlinelibrary.wiley.com/doi/10.1002/jcc.20495 + + Parameters + ---------- + cutoff_distance : float + Distance cutoff for the correction term. + max_num_neighbors : int + Maximum number of neighbors to consider. + atomic_number : list of int, optional + Map of atom types to atomic numbers. If None, use `dataset.atomic_numbers`. + position_scale : float, optional + Factor to convert positions stored in the dataset to meters (m). If None (default), use `dataset.position_scale`. + energy_scale : float, optional + Factor to convert energies stored in the dataset to Joules (J). Note: not J/mol. If None (default), use `dataset.energy_scale`. + dataset : Dataset, optional + Dataset object. If None, `atomic_number`, `position_scale`, and `energy_scale` must be explicitly set. + + Examples + -------- >>> from torchmdnet.priors import D2 >>> prior = D2( cutoff_distance=10.0, # Å @@ -104,7 +100,7 @@ class D2(BasePrior): [29.99, 1.881], # 54 Xe ], dtype=pt.float64, - ) + ) #::meta private: C_6_R_r[:, 1] *= 0.1 # Å --> nm def __init__( diff --git a/torchmdnet/priors/zbl.py b/torchmdnet/priors/zbl.py index a5eeb45f1..038379887 100644 --- a/torchmdnet/priors/zbl.py +++ b/torchmdnet/priors/zbl.py @@ -5,16 +5,27 @@ class ZBL(BasePrior): - """This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion. - Is is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147). It - is an empirical potential that does a good job of describing the repulsion between atoms at very short - distances. + """ + Implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion. + + This potential is described in Ziegler, J.F., Biersack, J.P., Littmark, U. "The Stopping and Range of Ions in Solids." + (1985), specifically in equations 9 and 10 on page 147. It is an empirical potential effectively describing the + repulsion between atoms at very short distances. + + Reference: + Available at: https://doi.org/10.1007/978-3-642-68779-2_5 - To use this prior, the Dataset must provide the following attributes. + Parameters + ---------- + atomic_number : torch.Tensor, optional + A 1D tensor of length max_z. `atomic_number[z]` is the atomic number of atoms with atom type z. If None, use `dataset.atomic_number`. + distance_scale : float, optional + Factor to multiply with coordinates stored in the dataset to convert them to meters. If None, use `dataset.distance_scale`. + energy_scale : float, optional + Factor to multiply with energies stored in the dataset to convert them to Joules (not J/mol). If None, use `dataset.energy_scale`. + dataset : Dataset, optional + Dataset object. If None, `atomic_number`, `distance_scale`, and `energy_scale` must be explicitly set. - atomic_number: 1D tensor of length max_z. atomic_number[z] is the atomic number of atoms with atom type z. - distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters - energy_scale: multiply by this factor to convert energies stored in the dataset to Joules (*not* J/mol) """ def __init__( diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index dfdc4ac37..45f8e40bb 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -18,8 +18,7 @@ from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number from lightning_utilities.core.rank_zero import rank_zero_warn - -def get_args(): +def get_argparse(): # fmt: off parser = argparse.ArgumentParser(description='Training') parser.add_argument('--load-model', action=LoadFromCheckpoint, help='Restart training using a model checkpoint') # keep first @@ -65,13 +64,13 @@ def get_args(): parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function') # model architecture - parser.add_argument('--model', type=str, default='graph-network', choices=models.__all__, help='Which model to train') + parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train') parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model') parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use') # architectural args - parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge') - parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state') + parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.') + parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.') parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension') parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model') parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model') @@ -105,8 +104,11 @@ def get_args(): parser.add_argument('--tensorboard-use', default=False, type=bool, help='Defines if tensor board is used or not') # fmt: on + return parser + +def get_args(): - args = parser.parse_args() + args = get_argparse() if args.redirect: sys.stdout = open(os.path.join(args.log_dir, "log"), "w") From 0275e5497ff234b55c3781c270bd0e45df79c0d0 Mon Sep 17 00:00:00 2001 From: Raul Date: Mon, 20 Nov 2023 11:24:01 +0100 Subject: [PATCH 7/7] Add missing parse_args from previous bad merge (#241) * Add missing parse_args from previous bad merge * Test torchmd-train in CI --- .github/workflows/CI.yml | 4 ++++ torchmdnet/scripts/train.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1f9067d1f..bbd8d35a0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -44,3 +44,7 @@ jobs: - name: Run tests run: pytest -v -s + + - name: Test torchmd-train utility + run: torchmd-train --help + diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 45f8e40bb..ca958fd73 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -108,8 +108,8 @@ def get_argparse(): def get_args(): - args = get_argparse() - + parser = get_argparse() + args = parser.parse_args() if args.redirect: sys.stdout = open(os.path.join(args.log_dir, "log"), "w") sys.stderr = sys.stdout