Skip to content

Commit

Permalink
Fix hparams
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed May 21, 2024
1 parent 97b556b commit 96c2920
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchmdnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
else:
self.model = create_model(self.hparams, prior_model, mean, std)

if hparams["overwrite_representation"] is not None:
ckpt = torch.load(hparams["overwrite_representation"], map_location="cpu")
if self.hparams.overwrite_representation is not None:
ckpt = torch.load(self.hparams.overwrite_representation, map_location="cpu")
self.model.representation_model.load_state_dict(
ckpt["representation_model"]
)

if hparams["freeze_representation"]:
if self.hparams.freeze_representation:
for p in self.model.representation_model.parameters():
p.requires_grad = False

Expand Down

0 comments on commit 96c2920

Please sign in to comment.