Skip to content

Commit

Permalink
Fix old model loading (#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez authored Apr 22, 2024
1 parent 0ed2e7c commit 552e0ee
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,22 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()}
# In ET, before we had output_model.output_network.{0,1}.update_net.[0-9].{weight,bias}
# Now we have output_model.output_network.{0,1}.update_net.layers.[0-9].{weight,bias}
# In other models, we had output_model.output_network.{0,1}.{weight,bias},
# which is now output_model.output_network.layers.{0,1}.{weight,bias}
# This change was introduced in https://github.com/torchmd/torchmd-net/pull/314
state_dict = {
re.sub(r"update_net\.(\d+)\.", r"update_net.layers.\1.", k): v
for k, v in state_dict.items()
}
patterns = [
(
r"output_model.output_network.(\d+).update_net.(\d+).",
r"output_model.output_network.\1.update_net.layers.\2.",
),
(
r"output_model.output_network.([02]).(weight|bias)",
r"output_model.output_network.layers.\1.\2",
),
]
for p in patterns:
state_dict = {re.sub(p[0], p[1], k): v for k, v in state_dict.items()}

model.load_state_dict(state_dict)
return model.to(device)

Expand Down

0 comments on commit 552e0ee

Please sign in to comment.