From 792572fcd677432b2cef78ce541dee7b9573c766 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Tue, 10 Oct 2023 10:59:19 +0200
Subject: [PATCH 1/4] Log epoch real time in LNNP

---
 torchmdnet/module.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 9c66c35b2..4eec18b6b 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -2,10 +2,10 @@ import torch
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch.nn.functional import local_response_norm, mse_loss, l1_loss
+from torch.nn.functional import mse_loss, l1_loss
 from torch import Tensor
 from typing import Optional, Dict, Tuple
-
+import time
 from lightning import LightningModule
 from torchmdnet.models.model import create_model, load_model
@@ -37,6 +37,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         self.losses = None
         self._reset_losses_dict()
 
+        self.tstart = time.time()
+
     def configure_optimizers(self):
         optimizer = AdamW(
             self.model.parameters(),
@@ -215,6 +217,7 @@ def on_validation_epoch_end(self):
                 "epoch": float(self.current_epoch),
                 "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
             }
+            result_dict["time"] = time.time() - self.tstart
             result_dict.update(self._get_mean_loss_dict_for_type("total"))
             result_dict.update(self._get_mean_loss_dict_for_type("y"))
             result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))
@@ -226,6 +229,7 @@ def on_test_epoch_end(self):
         # Log all test losses
         if not self.trainer.sanity_checking:
             result_dict = {}
+            result_dict["time"] = time.time() - self.tstart
             result_dict.update(self._get_mean_loss_dict_for_type("total"))
             result_dict.update(self._get_mean_loss_dict_for_type("y"))
             result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))

From bd6bcf66f50da4ecbeaeba1e29b19d3dcbdd293f Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Tue, 10 Oct 2023 18:33:00 +0200
Subject: [PATCH 2/4] Small changes

---
 torchmdnet/module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 4eec18b6b..2ce990f20 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -216,8 +216,8 @@ def on_validation_epoch_end(self):
             result_dict = {
                 "epoch": float(self.current_epoch),
                 "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
+                "time": time.time() - self.tstart,
             }
-            result_dict["time"] = time.time() - self.tstart
             result_dict.update(self._get_mean_loss_dict_for_type("total"))
             result_dict.update(self._get_mean_loss_dict_for_type("y"))
             result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))

From 6cb018ce4ebbbf7b10648768c10550a3f0350fd4 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Fri, 19 Jan 2024 11:16:51 +0100
Subject: [PATCH 3/4] Log epoch as an integer

---
 torchmdnet/module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 1c8bfe960..9a9ac22f2 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -222,7 +222,7 @@ def on_validation_epoch_end(self):
         if not self.trainer.sanity_checking:
             # construct dict of logged metrics
             result_dict = {
-                "epoch": float(self.current_epoch),
+                "epoch": self.current_epoch,
                 "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
                 "time": time.time() - self.tstart,
             }

From c4da37a7cc76a9ad7dcd2571eaee2ee42237b517 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Fri, 19 Jan 2024 12:57:26 +0100
Subject: [PATCH 4/4] Revert "Log epoch as an integer"

This reverts commit 6cb018ce4ebbbf7b10648768c10550a3f0350fd4.
---
 torchmdnet/module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 9a9ac22f2..1c8bfe960 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -222,7 +222,7 @@ def on_validation_epoch_end(self):
         if not self.trainer.sanity_checking:
             # construct dict of logged metrics
             result_dict = {
-                "epoch": self.current_epoch,
+                "epoch": float(self.current_epoch),
                 "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
                 "time": time.time() - self.tstart,
             }
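
The net effect of the series is small: `LNNP` stores `time.time()` once in `__init__` and logs the elapsed wall-clock seconds as a `"time"` metric at the end of every validation and test epoch; patch 2 folds the entry into the existing dict literal, and patches 3 and 4 try and then revert logging `epoch` as an integer, so the final state still logs `float(self.current_epoch)`. Below is a minimal, self-contained sketch of the same pattern in a generic LightningModule. It is not the torchmd-net code; the `TimedRegressionModule` name, the toy linear model, and the `val_loss` metric are made up for illustration.

```python
import time

import torch
from lightning import LightningModule


class TimedRegressionModule(LightningModule):
    """Toy module illustrating the epoch wall-clock logging added to LNNP."""

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(16, 1)
        # Reference timestamp, mirroring `self.tstart` in the patch.
        self.tstart = time.time()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.model(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.l1_loss(self.model(x), y)
        self.log("val_loss", loss)
        return loss

    def on_validation_epoch_end(self):
        # Skip the sanity-check pass, as the patched LNNP does.
        if not self.trainer.sanity_checking:
            result_dict = {
                "epoch": float(self.current_epoch),
                "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
                # Seconds of wall-clock time since the module was created.
                "time": time.time() - self.tstart,
            }
            self.log_dict(result_dict, sync_dist=True)
```

Because the timestamp is taken in `__init__`, the logged value is the cumulative time since construction rather than a per-epoch duration; the per-epoch cost can be recovered by differencing consecutive logged values.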