Skip to content

Commit

Permalink
Add some type check for floating point arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Jan 31, 2024
1 parent af64cdb commit 4744c9e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def create_model(args, prior_model=None, mean=None, std=None):
rbf_type=args["rbf_type"],
trainable_rbf=args["trainable_rbf"],
activation=args["activation"],
cutoff_lower=args["cutoff_lower"],
cutoff_upper=args["cutoff_upper"],
cutoff_lower=float(args["cutoff_lower"]),
cutoff_upper=float(args["cutoff_upper"]),
max_z=args["max_z"],
check_errors=args["check_errors"],
check_errors=bool(args["check_errors"]),
max_num_neighbors=args["max_num_neighbors"],
box_vecs=torch.tensor(args["box_vecs"], dtype=dtype)
if args["box_vecs"] is not None
Expand Down

0 comments on commit 4744c9e

Please sign in to comment.