
Commit de8ffa5
Switch to torch.optim. Add ema_parameters_start to train.py. Rename some variables
RaulPPelaez committed Jul 15, 2024
1 parent dfba2b0 commit de8ffa5
Showing 3 changed files with 28 additions and 21 deletions.
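For context, this commit drops the third-party torch_ema dependency in favor of PyTorch's built-in torch.optim.swa_utils. Below is a minimal, self-contained sketch of the pattern the new code adopts; it is not taken from the commit, and the toy model and decay value are placeholders:

import torch
import torch.optim.swa_utils

model = torch.nn.Linear(4, 1)  # stand-in for the TorchMD-Net model
decay = 0.999                  # plays the role of --ema-parameters-decay

ema_model = torch.optim.swa_utils.AveragedModel(
    model,
    multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(decay),
)

# After each optimizer step, fold the current weights into the running average:
ema_model.update_parameters(model)

# The smoothed copy lives in ema_model.module and can be used for evaluation.

Unlike torch_ema's ExponentialMovingAverage, which tracks shadow copies of individual parameters, AveragedModel wraps the whole module in a separate averaged copy, which is why the attribute is renamed from ema_parameters to ema_model in module.py.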
1 change: 0 additions & 1 deletion environment.yml
@@ -12,7 +12,6 @@ dependencies:
- pydantic
- torchmetrics
- tqdm
- torch-ema
# Dev tools
- flake8
- pytest
47 changes: 27 additions & 20 deletions torchmdnet/module.py
@@ -6,6 +6,7 @@
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.optim.swa_utils
from torch.nn.functional import local_response_norm, mse_loss, l1_loss
from torch import Tensor
from typing import Optional, Dict, Tuple
@@ -14,7 +15,6 @@
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models.utils import dtype_mapping
import torch_geometric.transforms as T
from torch_ema import ExponentialMovingAverage


class FloatCastDatasetWrapper(T.BaseTransform):
@@ -74,19 +74,26 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
else:
self.model = create_model(self.hparams, prior_model, mean, std)

self.ema_parameters = None
self.ema_model = None
if (
"ema_parameters_decay" in self.hparams
and self.hparams.ema_parameters_decay is not None
):
# initialize EMA for the model parameters
self.ema_parameters = ExponentialMovingAverage(
self.model.parameters(), decay=self.hparams.ema_parameters_decay
self.ema_model = torch.optim.swa_utils.AveragedModel(
self.model,
multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(
self.hparams.ema_parameters_decay
),
)
self.ema_parameters_start = (
self.hparams.ema_parameters_start
if "ema_parameters_start" in self.hparams
else 0
)

# initialize exponential smoothing
self.ema = None
self._reset_ema_dict()
# initialize exponential smoothing for the losses
self.ema_loss = None
self._reset_ema_loss_dict()

# initialize loss collection
self.losses = None
@@ -188,12 +195,12 @@ def _update_loss_with_ema(self, stage, type, loss_name, loss):
alpha = getattr(self.hparams, f"ema_alpha_{type}")
if stage in ["train", "val"] and alpha < 1 and alpha > 0:
ema = (
self.ema[stage][type][loss_name]
if loss_name in self.ema[stage][type]
self.ema_loss[stage][type][loss_name]
if loss_name in self.ema_loss[stage][type]
else loss.detach()
)
loss = alpha * loss + (1 - alpha) * ema
self.ema[stage][type][loss_name] = loss.detach()
self.ema_loss[stage][type][loss_name] = loss.detach()
return loss

def step(self, batch, loss_fn_list, stage):
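(Aside, not part of the diff: the smoothing applied in _update_loss_with_ema is a plain exponential moving average of the loss; a one-line sketch, with alpha standing in for the ema_alpha_y / ema_alpha_neg_dy hyperparameters:)

def smoothed_loss(loss, previous_ema, alpha):
    # new EMA = alpha * current loss + (1 - alpha) * previous EMA
    return alpha * loss + (1 - alpha) * previous_ema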
@@ -261,13 +268,13 @@ def optimizer_step(self, *args, **kwargs):
for pg in optimizer.param_groups:
pg["lr"] = lr_scale * self.hparams.lr
super().optimizer_step(*args, **kwargs)
if (
self.trainer.global_step >= self.ema_parameters_start
and self.ema_model is not None
):
self.ema_model.update_parameters(self.model)
optimizer.zero_grad()

def on_before_zero_grad(self, *args, **kwargs):
if self.ema_parameters is not None:
self.ema_parameters.to(self.device)
self.ema_parameters.update(self.model.parameters())

def _get_mean_loss_dict_for_type(self, type):
# Returns a list with the mean loss for each loss_fn for each stage (train, val, test)
# Parameters:
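A standalone sketch (again not from the commit; the toy model, optimizer, and step counts are placeholders) of how the gate in optimizer_step delays parameter averaging until a given step, mirroring the ema_parameters_start logic:

import torch
import torch.optim.swa_utils

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema_model = torch.optim.swa_utils.AveragedModel(
    model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999)
)
ema_start = 100  # analogous to ema_parameters_start

for global_step in range(200):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    if global_step >= ema_start:
        ema_model.update_parameters(model)  # only start averaging after the warm-up
    optimizer.zero_grad()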
@@ -320,9 +327,9 @@ def _reset_losses_dict(self):
for loss_type in ["total", "y", "neg_dy"]:
self.losses[stage][loss_type] = defaultdict(list)

def _reset_ema_dict(self):
self.ema = {}
def _reset_ema_loss_dict(self):
self.ema_loss = {}
for stage in ["train", "val"]:
self.ema[stage] = {}
self.ema_loss[stage] = {}
for loss_type in ["y", "neg_dy"]:
self.ema[stage][loss_type] = {}
self.ema_loss[stage][loss_type] = {}
1 change: 1 addition & 0 deletions torchmdnet/scripts/train.py
@@ -60,6 +60,7 @@ def get_argparse():
parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')
parser.add_argument('--ema-parameters-decay', type=float, default=None, help='Exponential moving average decay for model parameters (defaults to None, meaning disable). The decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed.')
parser.add_argument('--ema-parameters-start', type=int, default=0, help='Epoch to start averaging the parameters.')
# dataset specific
parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
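For reference, a hypothetical invocation using the two EMA flags might look like the line below (the config file and any other required options are assumptions, not part of this commit):

python torchmdnet/scripts/train.py --conf my_config.yaml --ema-parameters-decay 0.999 --ema-parameters-start 1000

Note that module.py compares ema_parameters_start against trainer.global_step, so the value appears to be interpreted in optimizer steps even though the help string says epoch.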
