diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index b3e7a01a..1c8bfe96 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -6,10 +6,10 @@
 import torch
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch.nn.functional import local_response_norm, mse_loss, l1_loss
+from torch.nn.functional import mse_loss, l1_loss
 from torch import Tensor
 from typing import Optional, Dict, Tuple
-
+import time
 from lightning import LightningModule
 from torchmdnet.models.model import create_model, load_model
 
@@ -41,6 +41,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         self.losses = None
         self._reset_losses_dict()
 
+        self.tstart = time.time()
+
     def configure_optimizers(self):
         optimizer = AdamW(
             self.model.parameters(),
@@ -222,6 +224,7 @@ def on_validation_epoch_end(self):
             result_dict = {
                 "epoch": float(self.current_epoch),
                 "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
+                "time": time.time() - self.tstart,
             }
             result_dict.update(self._get_mean_loss_dict_for_type("total"))
             result_dict.update(self._get_mean_loss_dict_for_type("y"))
@@ -234,6 +237,7 @@ def on_test_epoch_end(self):
         # Log all test losses
         if not self.trainer.sanity_checking:
             result_dict = {}
+            result_dict["time"] = time.time() - self.tstart
             result_dict.update(self._get_mean_loss_dict_for_type("total"))
             result_dict.update(self._get_mean_loss_dict_for_type("y"))
             result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))