Skip to content

Commit

Permalink
Extract representation model
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed May 21, 2024
1 parent 96c2920 commit 9f9d267
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions torchmdnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ def forward(self, data):
return data


def extract_representation_model(state_dict):
representation_model = {}
prefix = "model.representation_model."
for key, value in state_dict.items():
if key.startswith(prefix):
new_key = key[len(prefix) :]
representation_model[new_key] = value
return representation_model


class LNNP(LightningModule):
"""
Lightning wrapper for the Neural Network Potentials in TorchMD-Net.
Expand Down Expand Up @@ -83,9 +93,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):

if self.hparams.overwrite_representation is not None:
ckpt = torch.load(self.hparams.overwrite_representation, map_location="cpu")
self.model.representation_model.load_state_dict(
ckpt["representation_model"]
)
state_dict = extract_representation_model(ckpt["state_dict"])
self.model.representation_model.load_state_dict(state_dict)

if self.hparams.freeze_representation:
for p in self.model.representation_model.parameters():
Expand Down

0 comments on commit 9f9d267

Please sign in to comment.