From ffcacbed20f2c76839663aa9bfcd54ee4428c9dc Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 20 Jun 2024 12:49:30 +0200 Subject: [PATCH 1/5] Add second gradient regularization --- torchmdnet/module.py | 35 +++++++++++++++++++++++++++++++++++ torchmdnet/scripts/train.py | 4 ++++ 2 files changed, 39 insertions(+) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index d5ea73cf..e81a4ef0 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -167,6 +167,32 @@ def _compute_losses(self, y, neg_y, batch, loss_fn, stage): loss_y = self._update_loss_with_ema(stage, "y", loss_name, loss_y) return {"y": loss_y, "neg_dy": loss_neg_y} + def _compute_second_derivative_regularization(self, y, neg_dy, batch): + # Compute force gradient and add it to the loss like: max(0, grad(neg_dy.sum())-eps)^2 + # Args: + # y: predicted value + # neg_dy: predicted negative derivative + # batch: batch of data + # Returns: + # regularization: regularization term + assert "pos" in batch + force_sum = neg_dy.sum() + grad_outputs = [torch.ones_like(force_sum)] + assert batch.pos.requires_grad + ddy = torch.autograd.grad( + [force_sum], + [batch.pos], + grad_outputs=grad_outputs, + create_graph=True, + retain_graph=True, + )[0] + decay = self.hparams.regularization_decay / (self.current_epoch + 1) + regularization = ( + torch.max((ddy.norm() - self.hparams.regularization_coefficient), 0)[0] + * decay + ) + return regularization + def _update_loss_with_ema(self, stage, type, loss_name, loss): # Update the loss using an exponential moving average when applicable # Args: @@ -235,6 +261,15 @@ def step(self, batch, loss_fn_list, stage): step_losses["y"] * self.hparams.y_weight + step_losses["neg_dy"] * self.hparams.neg_dy_weight ) + if ( + self.hparams.regularize_second_gradient + and self.hparams.derivative + and stage == "train" + ): + total_loss = ( + total_loss + + self._compute_second_derivative_regularization(y, neg_dy, batch) + ) self.losses[stage]["total"][loss_name].append(total_loss.detach()) return total_loss diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 7f2d8e07..39f47566 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -59,6 +59,10 @@ def get_argparse(): parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log') parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm') parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.') + + parser.add_argument('--regularize-second-gradient', action="store_true", help='If true, regularize the second derivative of the energy w.r.t. the coordinates') + parser.add_argument('--regularization-coefficient', type=float, default=0.0, help='Coefficient for the regularization term') + parser.add_argument('--regularization-decay', type=float, default=0.0, help='Decay rate for the regularization term') # dataset specific parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset') parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")') From 0ef4ef8ff12473f500dbe7a573c433c0f29cf819 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 21 Jun 2024 09:50:40 +0200 Subject: [PATCH 2/5] Update --- torchmdnet/module.py | 9 +++------ torchmdnet/scripts/train.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index e81a4ef0..8f663134 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -176,7 +176,7 @@ def _compute_second_derivative_regularization(self, y, neg_dy, batch): # Returns: # regularization: regularization term assert "pos" in batch - force_sum = neg_dy.sum() + force_sum = (neg_dy**2).sum() grad_outputs = [torch.ones_like(force_sum)] assert batch.pos.requires_grad ddy = torch.autograd.grad( @@ -186,11 +186,8 @@ def _compute_second_derivative_regularization(self, y, neg_dy, batch): create_graph=True, retain_graph=True, )[0] - decay = self.hparams.regularization_decay / (self.current_epoch + 1) - regularization = ( - torch.max((ddy.norm() - self.hparams.regularization_coefficient), 0)[0] - * decay - ) + regularization = ddy.norm() * self.hparams.regularization_weight + print(f"Regularization: {regularization}") return regularization def _update_loss_with_ema(self, stage, type, loss_name, loss): diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 39f47566..a220ca7d 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -62,7 +62,7 @@ def get_argparse(): parser.add_argument('--regularize-second-gradient', action="store_true", help='If true, regularize the second derivative of the energy w.r.t. the coordinates') parser.add_argument('--regularization-coefficient', type=float, default=0.0, help='Coefficient for the regularization term') - parser.add_argument('--regularization-decay', type=float, default=0.0, help='Decay rate for the regularization term') + parser.add_argument('--regularization-weight', type=float, default=0.0, help='Weight for the force regularization term') # dataset specific parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset') parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")') From 160b2e369e9d564a47c33aa7881b54013569045a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 21 Jun 2024 11:12:12 +0200 Subject: [PATCH 3/5] Remove unused parameter --- torchmdnet/scripts/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index a220ca7d..f4f08058 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -61,7 +61,6 @@ def get_argparse(): parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.') parser.add_argument('--regularize-second-gradient', action="store_true", help='If true, regularize the second derivative of the energy w.r.t. the coordinates') - parser.add_argument('--regularization-coefficient', type=float, default=0.0, help='Coefficient for the regularization term') parser.add_argument('--regularization-weight', type=float, default=0.0, help='Weight for the force regularization term') # dataset specific parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset') From 0c34996874bae90072b3f245f509ce1348da298b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 25 Jun 2024 11:56:34 +0200 Subject: [PATCH 4/5] Add artificial short range dataset --- torchmdnet/datasets/__init__.py | 2 + torchmdnet/datasets/artificial_short_range.py | 72 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 torchmdnet/datasets/artificial_short_range.py diff --git a/torchmdnet/datasets/__init__.py b/torchmdnet/datasets/__init__.py index b57cd95a..dff30016 100644 --- a/torchmdnet/datasets/__init__.py +++ b/torchmdnet/datasets/__init__.py @@ -23,6 +23,7 @@ from .qm9q import QM9q from .spice import SPICE from .genentech import GenentechTorsions +from .artificial_short_range import ShortRange __all__ = [ "Ace", @@ -47,4 +48,5 @@ "SPICE", "Tripeptides", "WaterBox", + "ShortRange", ] diff --git a/torchmdnet/datasets/artificial_short_range.py b/torchmdnet/datasets/artificial_short_range.py new file mode 100644 index 00000000..34b2389c --- /dev/null +++ b/torchmdnet/datasets/artificial_short_range.py @@ -0,0 +1,72 @@ +import glob +import numpy as np +import torch +from torch_geometric.data import Dataset, Data + + +def random_vectors_in_sphere_box_muller(radius, count): + # Generate uniformly distributed random numbers for Box-Muller + u1 = np.random.uniform(low=0.0, high=1.0, size=count) + u2 = np.random.uniform(low=0.0, high=1.0, size=count) + u3 = np.random.uniform(low=0.0, high=1.0, size=count) + + # Box-Muller transform for normal distribution + normal1 = np.sqrt(-2.0 * np.log(u1)) * np.cos(2.0 * np.pi * u2) + normal2 = np.sqrt(-2.0 * np.log(u1)) * np.sin(2.0 * np.pi * u2) + normal3 = np.sqrt(-2.0 * np.log(u3)) * np.cos( + 2.0 * np.pi * u2 + ) # Using u2 again for the third component + + # Stack the normals + vectors = np.column_stack((normal1, normal2, normal3)) + + # Normalize each vector to have magnitude 1 + norms = np.linalg.norm(vectors, axis=1) + vectors_normalized = vectors / norms[:, np.newaxis] + + # Scale vectors by random radii up to 'radius' + scale = np.random.uniform(0, radius**3, count) ** ( + 1 / 3 + ) # Cube root to ensure uniform distribution in volume + vectors_scaled = vectors_normalized * scale[:, np.newaxis] + + return vectors_scaled + + +class ShortRange(Dataset): + def __init__(self, max_dist, size, max_z, transform=None, pre_transform=None): + self.max_dist = max_dist + self.size = size + self.max_z = max_z + # Create some npy files with random data. The dataset consists of pairs of atoms, with their positions, atomic numbers and energy + # Positions inside a sphere of radius max_dist + self.pos = random_vectors_in_sphere_box_muller(max_dist, 2 * size) + self.pos = self.pos.reshape(size, 2, 3) + # Atomic numbers + self.z = np.random.randint(1, max_z, size=2 * size).reshape(size, 2) + # Energy, should be a linear function of the distance, goes from 20 to 100 from max_dist to 0 + dist = np.linalg.norm( + self.pos[:, 0, :] - self.pos[:, 1, :], axis=1 + ) # shape (size,) + self.y = 20 + 80 * (1 - dist / max_dist) + # Negative gradient of the energy with respect to the positions, should have the same shape as pos + self.neg_dy = np.zeros((size, 2, 3)) + self.neg_dy[:, 0, :] = ( + -80 + / max_dist + * (self.pos[:, 0, :] - self.pos[:, 1, :]) + / dist[:, np.newaxis] + ) + self.neg_dy[:, 1, :] = -self.neg_dy[:, 0, :] + + def get(self, idx): + data = Data( + z=torch.tensor(self.z[idx], dtype=torch.long), + pos=torch.tensor(self.pos[idx], dtype=torch.float), + y=torch.tensor(self.y[idx], dtype=torch.float), + neg_dy=torch.tensor(self.neg_dy[idx], dtype=torch.float), + ) + return data + + def __len__(self): + return self.size From fbb55ba17813444a1c7e56f3395ced89661d1c76 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 26 Jun 2024 12:54:24 +0200 Subject: [PATCH 5/5] Update --- torchmdnet/datasets/artificial_short_range.py | 81 ++++++++++++++----- 1 file changed, 63 insertions(+), 18 deletions(-) diff --git a/torchmdnet/datasets/artificial_short_range.py b/torchmdnet/datasets/artificial_short_range.py index 34b2389c..38e964f0 100644 --- a/torchmdnet/datasets/artificial_short_range.py +++ b/torchmdnet/datasets/artificial_short_range.py @@ -33,8 +33,24 @@ def random_vectors_in_sphere_box_muller(radius, count): return vectors_scaled +def compute_energy(pos, max_dist): + dist = torch.linalg.norm(pos[:, 0, :] - pos[:, 1, :], axis=1) # shape (size,) + y = 20 + 80 * (1 - dist / max_dist) # shape (size,) + return y + + +def compute_forces(pos, max_dist): + pos = pos.clone().detach().requires_grad_(True) + y = compute_energy(pos, max_dist) + y_sum = y.sum() + y_sum.backward() + forces = -pos.grad + return forces + + class ShortRange(Dataset): - def __init__(self, max_dist, size, max_z, transform=None, pre_transform=None): + def __init__(self, root, max_dist, size, max_z, transform=None, pre_transform=None): + super(ShortRange, self).__init__(root, transform, pre_transform) self.max_dist = max_dist self.size = size self.max_z = max_z @@ -44,29 +60,58 @@ def __init__(self, max_dist, size, max_z, transform=None, pre_transform=None): self.pos = self.pos.reshape(size, 2, 3) # Atomic numbers self.z = np.random.randint(1, max_z, size=2 * size).reshape(size, 2) - # Energy, should be a linear function of the distance, goes from 20 to 100 from max_dist to 0 - dist = np.linalg.norm( - self.pos[:, 0, :] - self.pos[:, 1, :], axis=1 - ) # shape (size,) - self.y = 20 + 80 * (1 - dist / max_dist) + # Energy + self.y = compute_energy(torch.tensor(self.pos), max_dist).detach().numpy() * 0 + assert self.y.shape == (size,) + assert self.z.shape == (size, 2) # Negative gradient of the energy with respect to the positions, should have the same shape as pos - self.neg_dy = np.zeros((size, 2, 3)) - self.neg_dy[:, 0, :] = ( - -80 - / max_dist - * (self.pos[:, 0, :] - self.pos[:, 1, :]) - / dist[:, np.newaxis] + self.neg_dy = ( + compute_forces(torch.tensor(self.pos, dtype=torch.float), max_dist) + .detach() + .numpy() + * 0 ) - self.neg_dy[:, 1, :] = -self.neg_dy[:, 0, :] def get(self, idx): + y = torch.tensor(self.y[idx], dtype=torch.float).view(1, 1) + z = torch.tensor(self.z[idx], dtype=torch.long).view(2) + pos = torch.tensor(self.pos[idx], dtype=torch.float).view(2, 3) + neg_dy = torch.tensor(self.neg_dy[idx], dtype=torch.float).view(2, 3) data = Data( - z=torch.tensor(self.z[idx], dtype=torch.long), - pos=torch.tensor(self.pos[idx], dtype=torch.float), - y=torch.tensor(self.y[idx], dtype=torch.float), - neg_dy=torch.tensor(self.neg_dy[idx], dtype=torch.float), + z=z, + pos=pos, + y=y, + neg_dy=neg_dy, ) return data - def __len__(self): + def len(self): return self.size + + # Taken from https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat + + _ELEMENT_ENERGIES = { + 1: -0.5978583943827134, # H + 6: -38.08933878049795, # C + 7: -54.711968298621066, # N + 8: -75.19106774742086, # O + 9: -99.80348506781634, # F + 16: -398.1577125334925, # S + 17: -460.1681939421027, # Cl + } + HARTREE_TO_EV = 27.211386246 #::meta private: + + def get_atomref(self, max_z=100): + """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. + + Args: + max_z (int): Maximum atomic number + + Returns: + torch.Tensor: Atomic energy reference values for each element in the dataset. + """ + refs = torch.zeros(max_z) + for key, val in self._ELEMENT_ENERGIES.items(): + refs[key] = val * self.HARTREE_TO_EV * 0 + + return refs.view(-1, 1)