From 4744c9e9b4f9123a9cad3bdd670306afcf6c0264 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 31 Jan 2024 10:20:12 +0100 Subject: [PATCH] Add some type check for floating point arguments --- torchmdnet/models/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 2bfa8a36..b0f4e2ac 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -46,10 +46,10 @@ def create_model(args, prior_model=None, mean=None, std=None): rbf_type=args["rbf_type"], trainable_rbf=args["trainable_rbf"], activation=args["activation"], - cutoff_lower=args["cutoff_lower"], - cutoff_upper=args["cutoff_upper"], + cutoff_lower=float(args["cutoff_lower"]), + cutoff_upper=float(args["cutoff_upper"]), max_z=args["max_z"], - check_errors=args["check_errors"], + check_errors=bool(args["check_errors"]), max_num_neighbors=args["max_num_neighbors"], box_vecs=torch.tensor(args["box_vecs"], dtype=dtype) if args["box_vecs"] is not None