diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 434973a8..de312893 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -81,13 +81,13 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): else: self.model = create_model(self.hparams, prior_model, mean, std) - if hparams["overwrite_representation"] is not None: - ckpt = torch.load(hparams["overwrite_representation"], map_location="cpu") + if self.hparams.overwrite_representation is not None: + ckpt = torch.load(self.hparams.overwrite_representation, map_location="cpu") self.model.representation_model.load_state_dict( ckpt["representation_model"] ) - if hparams["freeze_representation"]: + if self.hparams.freeze_representation: for p in self.model.representation_model.parameters(): p.requires_grad = False