From 9f9d267445396c8d4a7b30a70084a10f278abf2d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 21 May 2024 16:28:43 +0200 Subject: [PATCH] Extract representation model --- torchmdnet/module.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index de312893..e40fafb8 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -48,6 +48,16 @@ def forward(self, data): return data +def extract_representation_model(state_dict): + representation_model = {} + prefix = "model.representation_model." + for key, value in state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix) :] + representation_model[new_key] = value + return representation_model + + class LNNP(LightningModule): """ Lightning wrapper for the Neural Network Potentials in TorchMD-Net. @@ -83,9 +93,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): if self.hparams.overwrite_representation is not None: ckpt = torch.load(self.hparams.overwrite_representation, map_location="cpu") - self.model.representation_model.load_state_dict( - ckpt["representation_model"] - ) + state_dict = extract_representation_model(ckpt["state_dict"]) + self.model.representation_model.load_state_dict(state_dict) if self.hparams.freeze_representation: for p in self.model.representation_model.parameters():