Add Huber loss
RaulPPelaez committed Jul 15, 2024
1 parent 6c42c8b commit f6fa219
Showing 3 changed files with 21 additions and 3 deletions.
7 changes: 7 additions & 0 deletions torchmdnet/loss.py
@@ -0,0 +1,7 @@
+from torch.nn.functional import mse_loss, l1_loss, huber_loss
+
+loss_map = {
+    "mse_loss": mse_loss,
+    "l1_loss": l1_loss,
+    "huber_loss": huber_loss,
+}
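For reference, loss_map turns the loss choice into a string-keyed dispatch over torch.nn.functional: all three entries share the (input, target) signature and reduce to a scalar mean by default, and huber_loss is quadratic for small residuals but linear beyond its delta threshold (1.0 by default), making it less outlier-sensitive than mse_loss. A minimal sketch of the dispatch (tensor values are made up for illustration):

import torch
from torchmdnet.loss import loss_map

pred = torch.tensor([0.1, 0.2, 5.0])   # last value plays the outlier
target = torch.zeros(3)
for name, fn in loss_map.items():
    # mse_loss, l1_loss and huber_loss all accept (input, target)
    # and return a scalar under the default mean reduction
    print(name, fn(pred, target).item())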
15 changes: 12 additions & 3 deletions torchmdnet/module.py
@@ -6,7 +6,7 @@
 import torch
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch.nn.functional import local_response_norm, mse_loss, l1_loss
+from torch.nn.functional import local_response_norm
 from torch import Tensor
 from typing import Optional, Dict, Tuple
 
@@ -16,6 +16,9 @@
 import torch_geometric.transforms as T
 
 
+from torchmdnet.loss import l1_loss, loss_map
+
+
 class FloatCastDatasetWrapper(T.BaseTransform):
     """A transform that casts all floating point
     tensors to a given dtype.
@@ -92,6 +95,12 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
             ]
         )
 
+        if self.hparams.training_loss not in loss_map:
+            raise ValueError(
+                f"Training loss {self.hparams.training_loss} not supported. Supported losses are {list(loss_map.keys())}"
+            )
+        self.training_loss = loss_map[self.hparams.training_loss]
+
     def configure_optimizers(self):
         optimizer = AdamW(
             self.model.parameters(),
@@ -126,7 +135,7 @@ def forward(
         return self.model(z, pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args)
 
     def training_step(self, batch, batch_idx):
-        return self.step(batch, [mse_loss], "train")
+        return self.step(batch, [self.training_loss], "train")
 
     def validation_step(self, batch, batch_idx, *args):
         # If args is not empty the first (and only) element is the dataloader_idx
@@ -135,7 +144,7 @@ def validation_step(self, batch, batch_idx, *args):
         # The dataloader takes care of sending the two sets only when the second one is needed.
         is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0)
         if is_val:
-            step_type = {"loss_fn_list": [l1_loss, mse_loss], "stage": "val"}
+            step_type = {"loss_fn_list": [l1_loss, self.training_loss], "stage": "val"}
         else:
             step_type = {"loss_fn_list": [l1_loss], "stage": "test"}
         return self.step(batch, **step_type)
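Two details in the module.py changes are worth noting: the constructor fails fast on an unknown loss name, and validation keeps l1_loss in the list next to the configured loss, so runs trained with different objectives still report a comparable MAE. A standalone sketch of the same wiring outside the Lightning module (the variable names are illustrative, not part of the commit):

import torch
from torch.nn.functional import l1_loss
from torchmdnet.loss import loss_map

name = "huber_loss"
if name not in loss_map:  # mirrors the fail-fast check added in __init__
    raise ValueError(
        f"Training loss {name} not supported. Supported losses are {list(loss_map.keys())}"
    )
training_loss = loss_map[name]

pred, target = torch.randn(8), torch.randn(8)
# validation reports the fixed MAE baseline next to the configured objective
val_metrics = {fn.__name__: fn(pred, target).item() for fn in (l1_loss, training_loss)}
print(val_metrics)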
2 changes: 2 additions & 0 deletions torchmdnet/scripts/train.py
@@ -18,6 +18,7 @@
 from torchmdnet.module import LNNP
 from torchmdnet import datasets, priors, models
 from torchmdnet.data import DataModule
+from torchmdnet.loss import loss_map
 from torchmdnet.models import output_modules
 from torchmdnet.models.model import create_prior_models
 from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
@@ -70,6 +71,7 @@ def get_argparse():
     parser.add_argument('--dataset-preload-limit', default=1024, type=int, help='Custom and HDF5 datasets will preload to RAM datasets that are less than this size in MB')
     parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function')
     parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function')
+    parser.add_argument('--training-loss', default='mse_loss', type=str, choices=loss_map.keys(), help='Loss function to use during training')
 
     # model architecture
     parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
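With the flag registered in get_argparse, selecting the loss from the command line reduces to a dictionary lookup at model construction time. A minimal sketch of that parsing path (a trimmed-down parser, for illustration only):

import argparse
from torchmdnet.loss import loss_map

parser = argparse.ArgumentParser()
parser.add_argument('--training-loss', default='mse_loss', type=str, choices=loss_map.keys(),
                    help='Loss function to use during training')
args = parser.parse_args(['--training-loss', 'huber_loss'])
loss_fn = loss_map[args.training_loss]   # later consumed as self.training_loss in LNNP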
