From b1de9a8384e421dea7d3b230ffd990ff161bf5c1 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 13 Feb 2024 16:25:31 +0100 Subject: [PATCH] Allow to pass prior arguments in the yaml file. Allow to configure more than one prior in the yaml file. --- torchmdnet/models/model.py | 2 +- torchmdnet/scripts/train.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index b0f4e2acd..c301fac28 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -208,7 +208,7 @@ def create_prior_models(args, dataset=None): else: prior_names.append(prior) prior_args.append({}) - if "prior_args" in args: + if "prior_args" in args and args["prior_args"] is not None: prior_args = args["prior_args"] if not isinstance(prior_args, list): prior_args = [prior_args] diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 44bbe3b99..e8a516e9b 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -62,7 +62,7 @@ def get_argparse(): # dataset specific parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset') parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")') - parser.add_argument('--dataset-arg', default=None, type=str, help='Additional dataset arguments, e.g. target property for QM9 or molecule for MD17. Need to be specified in JSON format i.e. \'{"molecules": "aspirin,benzene"}\'') + parser.add_argument('--dataset-arg', default=None, help='Additional dataset arguments. Needs to be a dictionary.') parser.add_argument('--coord-files', default=None, type=str, help='Custom coordinate files glob') parser.add_argument('--embed-files', default=None, type=str, help='Custom embedding files glob') parser.add_argument('--energy-files', default=None, type=str, help='Custom energy files glob') @@ -74,8 +74,8 @@ def get_argparse(): # model architecture parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train') parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model') - parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use') - + parser.add_argument('--prior-model', type=str, choices=priors.__all__, default=None, help='Which prior model to use. It can be a string or a list of strings.', action="extend", nargs="*") + parser.add_argument('--prior-args', type=yaml.load, default=None, help='Additional prior arguments. Needs to be a dictionary or a list of dictionaries with the same size as the prior-model argument.', action="extend", nargs="*") # architectural args parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.') parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.')