diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 02ff741b..a2a80f90 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -51,9 +51,11 @@ def create_model(args, prior_model=None, mean=None, std=None): max_z=args["max_z"], check_errors=bool(args["check_errors"]), max_num_neighbors=args["max_num_neighbors"], - box_vecs=torch.tensor(args["box_vecs"], dtype=dtype) - if args["box_vecs"] is not None - else None, + box_vecs=( + torch.tensor(args["box_vecs"], dtype=dtype) + if args["box_vecs"] is not None + else None + ), dtype=dtype, ) @@ -164,8 +166,12 @@ def load_model(filepath, args=None, device="cpu", **kwargs): model = create_model(args) if delta_learning and "remove_ref_energy" in kwargs: if not kwargs["remove_ref_energy"]: - assert len(model.prior_model) > 0, "Atomref prior must be added during training (with enable=False) for total energy prediction." - assert isinstance(model.prior_model[-1], priors.Atomref), "I expected the last prior to be Atomref." + assert ( + len(model.prior_model) > 0 + ), "Atomref prior must be added during training (with enable=False) for total energy prediction." + assert isinstance( + model.prior_model[-1], priors.Atomref + ), "I expected the last prior to be Atomref." # Set the Atomref prior to enabled model.prior_model[-1].enable = True