Skip to content

Commit

Permalink
Blacken
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Feb 16, 2024
1 parent 8f0b8b2 commit f4b6827
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ def create_model(args, prior_model=None, mean=None, std=None):
max_z=args["max_z"],
check_errors=bool(args["check_errors"]),
max_num_neighbors=args["max_num_neighbors"],
box_vecs=torch.tensor(args["box_vecs"], dtype=dtype)
if args["box_vecs"] is not None
else None,
box_vecs=(
torch.tensor(args["box_vecs"], dtype=dtype)
if args["box_vecs"] is not None
else None
),
dtype=dtype,
)

Expand Down Expand Up @@ -164,8 +166,12 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
model = create_model(args)
if delta_learning and "remove_ref_energy" in kwargs:
if not kwargs["remove_ref_energy"]:
assert len(model.prior_model) > 0, "Atomref prior must be added during training (with enable=False) for total energy prediction."
assert isinstance(model.prior_model[-1], priors.Atomref), "I expected the last prior to be Atomref."
assert (
len(model.prior_model) > 0
), "Atomref prior must be added during training (with enable=False) for total energy prediction."
assert isinstance(
model.prior_model[-1], priors.Atomref
), "I expected the last prior to be Atomref."
# Set the Atomref prior to enabled
model.prior_model[-1].enable = True

Expand Down

0 comments on commit f4b6827

Please sign in to comment.