Skip to content

Commit

Permalink
Remove deprecated atomref weight loading
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Feb 13, 2024
1 parent 6d8e315 commit d6bf65e
Showing 1 changed file with 0 additions and 22 deletions.
22 changes: 0 additions & 22 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,28 +162,6 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
model = create_model(args)

state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()}
# The following are for backward compatibility with models created when atomref was
# the only supported prior.
if "prior_model.initial_atomref" in state_dict:
warnings.warn(
"prior_model.initial_atomref is deprecated and will be removed in a future version. Use prior_model.0.initial_atomref instead.",
category=DeprecationWarning,
stacklevel=2,
)
state_dict["prior_model.0.initial_atomref"] = state_dict[
"prior_model.initial_atomref"
]
del state_dict["prior_model.initial_atomref"]
if "prior_model.atomref.weight" in state_dict:
warnings.warn(
"prior_model.atomref.weight is deprecated and will be removed in a future version. Use prior_model.0.atomref.weight instead.",
category=DeprecationWarning,
stacklevel=2,
)
state_dict["prior_model.0.atomref.weight"] = state_dict[
"prior_model.atomref.weight"
]
del state_dict["prior_model.atomref.weight"]
model.load_state_dict(state_dict)
return model.to(device)

Expand Down

0 comments on commit d6bf65e

Please sign in to comment.