diff --git a/environment.yml b/environment.yml
index 9392d147..5116787f 100644
--- a/environment.yml
+++ b/environment.yml
@@ -12,7 +12,6 @@ dependencies:
   - pydantic
   - torchmetrics
   - tqdm
-  - torch-ema
   # Dev tools
   - flake8
   - pytest
diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 390e5908..ea26051d 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -6,6 +6,7 @@
 import torch
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
+import torch.optim.swa_utils
 from torch.nn.functional import local_response_norm, mse_loss, l1_loss
 from torch import Tensor
 from typing import Optional, Dict, Tuple
@@ -14,7 +15,6 @@
 from torchmdnet.models.model import create_model, load_model
 from torchmdnet.models.utils import dtype_mapping
 import torch_geometric.transforms as T
-from torch_ema import ExponentialMovingAverage
 
 
 class FloatCastDatasetWrapper(T.BaseTransform):
@@ -74,19 +74,26 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         else:
             self.model = create_model(self.hparams, prior_model, mean, std)
 
-        self.ema_parameters = None
+        self.ema_model = None
         if (
             "ema_parameters_decay" in self.hparams
             and self.hparams.ema_parameters_decay is not None
         ):
-            # initialize EMA for the model paremeters
-            self.ema_parameters = ExponentialMovingAverage(
-                self.model.parameters(), decay=self.hparams.ema_parameters_decay
+            self.ema_model = torch.optim.swa_utils.AveragedModel(
+                self.model,
+                multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(
+                    self.hparams.ema_parameters_decay
+                ),
             )
+        self.ema_parameters_start = (
+            self.hparams.ema_parameters_start
+            if "ema_parameters_start" in self.hparams
+            else 0
+        )
 
-        # initialize exponential smoothing
-        self.ema = None
-        self._reset_ema_dict()
+        # initialize exponential smoothing for the losses
+        self.ema_loss = None
+        self._reset_ema_loss_dict()
 
         # initialize loss collection
         self.losses = None
@@ -188,12 +195,12 @@ def _update_loss_with_ema(self, stage, type, loss_name, loss):
         alpha = getattr(self.hparams, f"ema_alpha_{type}")
         if stage in ["train", "val"] and alpha < 1 and alpha > 0:
             ema = (
-                self.ema[stage][type][loss_name]
-                if loss_name in self.ema[stage][type]
+                self.ema_loss[stage][type][loss_name]
+                if loss_name in self.ema_loss[stage][type]
                 else loss.detach()
             )
             loss = alpha * loss + (1 - alpha) * ema
-            self.ema[stage][type][loss_name] = loss.detach()
+            self.ema_loss[stage][type][loss_name] = loss.detach()
         return loss
 
     def step(self, batch, loss_fn_list, stage):
@@ -261,13 +268,13 @@ def optimizer_step(self, *args, **kwargs):
             for pg in optimizer.param_groups:
                 pg["lr"] = lr_scale * self.hparams.lr
         super().optimizer_step(*args, **kwargs)
+        if (
+            self.trainer.global_step >= self.ema_parameters_start
+            and self.ema_model is not None
+        ):
+            self.ema_model.update_parameters(self.model)
         optimizer.zero_grad()
 
-    def on_before_zero_grad(self, *args, **kwargs):
-        if self.ema_parameters is not None:
-            self.ema_parameters.to(self.device)
-            self.ema_parameters.update(self.model.parameters())
-
     def _get_mean_loss_dict_for_type(self, type):
         # Returns a list with the mean loss for each loss_fn for each stage (train, val, test)
         # Parameters:
@@ -320,9 +327,9 @@ def _reset_losses_dict(self):
             for loss_type in ["total", "y", "neg_dy"]:
                 self.losses[stage][loss_type] = defaultdict(list)
 
-    def _reset_ema_dict(self):
-        self.ema = {}
+    def _reset_ema_loss_dict(self):
+        self.ema_loss = {}
         for stage in ["train", "val"]:
-            self.ema[stage] = {}
+            self.ema_loss[stage] = {}
             for loss_type in ["y", "neg_dy"]:
-                self.ema[stage][loss_type] = {}
+                self.ema_loss[stage][loss_type] = {}
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 44a2e797..aef1b0ec 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -60,6 +60,7 @@ def get_argparse():
     parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
     parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')
     parser.add_argument('--ema-parameters-decay', type=float, default=None, help='Exponential moving average decay for model parameters (defaults to None, meaning disable). The decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed.')
+    parser.add_argument('--ema-parameters-start', type=int, default=0, help='Global training step at which to start averaging the parameters.')
     # dataset specific
     parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
     parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
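
For context, below is a minimal sketch of the parameter-EMA pattern this diff adopts via `torch.optim.swa_utils` (assumed PyTorch >= 2.1, where `get_ema_multi_avg_fn` is available). The toy linear model, the 0.999 decay, and the start step of 10 are illustrative stand-ins for the values passed through `--ema-parameters-decay` and `--ema-parameters-start`; the snippet is not part of the diff above.

```python
# Hedged sketch: model, decay, and start step below are illustrative, not the PR's values.
import torch
import torch.optim.swa_utils

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# AveragedModel keeps a separate copy of the weights; get_ema_multi_avg_fn makes
# that copy an exponential moving average with the given decay.
ema_model = torch.optim.swa_utils.AveragedModel(
    model,
    multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999),
)

ema_start = 10  # analogous to --ema-parameters-start, compared against the global step
for step in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    # Mirror of the optimizer_step hook above: only update the average after ema_start.
    if step >= ema_start:
        ema_model.update_parameters(model)

# ema_model.module holds the averaged weights for evaluation or export.
```

Unlike `torch_ema.ExponentialMovingAverage`, which shadows the live parameters and requires `copy_to`/`restore` around evaluation, `AveragedModel` keeps the averaged weights in a separate module, so the training weights are never overwritten in place.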