diff --git a/torchmdnet/loss.py b/torchmdnet/loss.py new file mode 100644 index 00000000..2f2d8c57 --- /dev/null +++ b/torchmdnet/loss.py @@ -0,0 +1,7 @@ +from torch.nn.functional import mse_loss, l1_loss, huber_loss + +loss_class_mapping = { + "mse_loss": mse_loss, + "l1_loss": l1_loss, + "huber_loss": huber_loss, +} diff --git a/torchmdnet/module.py b/torchmdnet/module.py index d5ea73cf..d0f9d3bd 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -6,13 +6,13 @@ import torch from torch.optim import AdamW from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.nn.functional import local_response_norm, mse_loss, l1_loss +from torch.nn.functional import local_response_norm from torch import Tensor from typing import Optional, Dict, Tuple - from lightning import LightningModule from torchmdnet.models.model import create_model, load_model from torchmdnet.models.utils import dtype_mapping +from torchmdnet.loss import l1_loss, loss_class_mapping import torch_geometric.transforms as T @@ -48,6 +48,18 @@ def __call__(self, data): return data +# This wrapper is here in order to permit Lightning to serialize the loss function. +class LossFunction: + def __init__(self, loss_fn, extra_args=None): + self.loss_fn = loss_fn + self.extra_args = extra_args + if self.extra_args is None: + self.extra_args = {} + + def __call__(self, x, batch): + return self.loss_fn(x, batch, **self.extra_args) + + class LNNP(LightningModule): """ Lightning wrapper for the Neural Network Potentials in TorchMD-Net. @@ -65,7 +77,10 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): hparams["charge"] = False if "spin" not in hparams: hparams["spin"] = False - + if "train_loss" not in hparams: + hparams["train_loss"] = "mse_loss" + if "train_loss_arg" not in hparams: + hparams["train_loss_arg"] = {} self.save_hyperparameters(hparams) if self.hparams.load_model: @@ -92,6 +107,16 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): ] ) + if self.hparams.train_loss not in loss_class_mapping: + raise ValueError( + f"Training loss {self.hparams.train_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}" + ) + + self.train_loss_fn = LossFunction( + loss_class_mapping[self.hparams.train_loss], + self.hparams.train_loss_arg, + ) + def configure_optimizers(self): optimizer = AdamW( self.model.parameters(), @@ -105,9 +130,12 @@ def configure_optimizers(self): patience=self.hparams.lr_patience, min_lr=self.hparams.lr_min, ) + lr_metric = getattr(self.hparams, "lr_metric", "val") + monitor = f"{lr_metric}_total_{self.hparams.train_loss}" lr_scheduler = { "scheduler": scheduler, - "monitor": getattr(self.hparams, "lr_metric", "val_loss"), + "strict": True, + "monitor": monitor, "interval": "epoch", "frequency": 1, } @@ -126,7 +154,9 @@ def forward( return self.model(z, pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args) def training_step(self, batch, batch_idx): - return self.step(batch, [mse_loss], "train") + return self.step( + batch, [(self.hparams.train_loss, self.train_loss_fn)], "train" + ) def validation_step(self, batch, batch_idx, *args): # If args is not empty the first (and only) element is the dataloader_idx @@ -135,28 +165,34 @@ def validation_step(self, batch, batch_idx, *args): # The dataloader takes care of sending the two sets only when the second one is needed. is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0) if is_val: - step_type = {"loss_fn_list": [l1_loss, mse_loss], "stage": "val"} + step_type = { + "loss_fn_list": [ + ("l1_loss", l1_loss), + (self.hparams.train_loss, self.train_loss_fn), + ], + "stage": "val", + } else: - step_type = {"loss_fn_list": [l1_loss], "stage": "test"} + step_type = {"loss_fn_list": [("l1_loss", l1_loss)], "stage": "test"} return self.step(batch, **step_type) def test_step(self, batch, batch_idx): - return self.step(batch, [l1_loss], "test") + return self.step(batch, [("l1_loss", l1_loss)], "test") - def _compute_losses(self, y, neg_y, batch, loss_fn, stage): + def _compute_losses(self, y, neg_y, batch, loss_fn, loss_name, stage): # Compute the loss for the predicted value and the negative derivative (if available) # Args: # y: predicted value # neg_y: predicted negative derivative # batch: batch of data - # loss_fn: loss function to compute + # loss_fn: The loss function to compute + # loss_name: The name of the loss function # Returns: # loss_y: loss for the predicted value # loss_neg_y: loss for the predicted negative derivative loss_y, loss_neg_y = torch.tensor(0.0, device=self.device), torch.tensor( 0.0, device=self.device ) - loss_name = loss_fn.__name__ if self.hparams.derivative and "neg_dy" in batch: loss_neg_y = loss_fn(neg_y, batch.neg_dy) loss_neg_y = self._update_loss_with_ema( @@ -221,10 +257,10 @@ def step(self, batch, loss_fn_list, stage): neg_dy = neg_dy + y.sum() * 0 if "y" in batch and batch.y.ndim == 1: batch.y = batch.y.unsqueeze(1) - for loss_fn in loss_fn_list: - step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage) - - loss_name = loss_fn.__name__ + for loss_name, loss_fn in loss_fn_list: + step_losses = self._compute_losses( + y, neg_dy, batch, loss_fn, loss_name, stage + ) if self.hparams.neg_dy_weight > 0: self.losses[stage]["neg_dy"][loss_name].append( step_losses["neg_dy"].detach() diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 0951b92d..76f4c63d 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -17,6 +17,7 @@ from torchmdnet.module import LNNP from torchmdnet import datasets, priors, models from torchmdnet.data import DataModule +from torchmdnet.loss import loss_class_mapping from torchmdnet.models import output_modules from torchmdnet.models.model import create_prior_models from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping @@ -34,7 +35,7 @@ def get_argparse(): parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.') parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation') - parser.add_argument('--lr-metric', type=str, default='val_total_mse_loss', choices=['train_total_mse_loss', 'val_total_mse_loss'], help='Metric to monitor when deciding whether to reduce learning rate') + parser.add_argument('--lr-metric', type=str, default='val', choices=['train', 'val'], help='Metric to monitor when deciding whether to reduce learning rate') parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop') parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving') parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up') @@ -69,6 +70,8 @@ def get_argparse(): parser.add_argument('--dataset-preload-limit', default=1024, type=int, help='Custom and HDF5 datasets will preload to RAM datasets that are less than this size in MB') parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function') parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function') + parser.add_argument('--train-loss', default='mse_loss', type=str, choices=loss_class_mapping.keys(), help='Loss function to use during training') + parser.add_argument('--train-loss-arg', default=None, help='Additional arguments for the loss function. Needs to be a dictionary.') # model architecture parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train') @@ -165,17 +168,16 @@ def main(): # initialize lightning module model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std) + val_loss_name = f"val_total_{args.train_loss}" checkpoint_callback = ModelCheckpoint( dirpath=args.log_dir, - monitor="val_total_mse_loss", + monitor=val_loss_name, save_top_k=10, # -1 to save all every_n_epochs=args.save_interval, - filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}", + filename=f"epoch={{epoch}}-val_loss={{{val_loss_name}:.4f}}-test_loss={{test_total_l1_loss:.4f}}", auto_insert_metric_name=False, ) - early_stopping = EarlyStopping( - "val_total_mse_loss", patience=args.early_stopping_patience - ) + early_stopping = EarlyStopping(val_loss_name, patience=args.early_stopping_patience) csv_logger = CSVLogger(args.log_dir, name="", version="") _logger = [csv_logger]