Skip to content

Commit

Permalink
Allow to pass prior arguments in the yaml file.
Browse files Browse the repository at this point in the history
Allow to configure more than one prior in the yaml file.
  • Loading branch information
RaulPPelaez committed Feb 13, 2024
1 parent 6d8e315 commit b1de9a8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def create_prior_models(args, dataset=None):
else:
prior_names.append(prior)
prior_args.append({})
if "prior_args" in args:
if "prior_args" in args and args["prior_args"] is not None:
prior_args = args["prior_args"]
if not isinstance(prior_args, list):
prior_args = [prior_args]
Expand Down
6 changes: 3 additions & 3 deletions torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_argparse():
# dataset specific
parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
parser.add_argument('--dataset-arg', default=None, type=str, help='Additional dataset arguments, e.g. target property for QM9 or molecule for MD17. Need to be specified in JSON format i.e. \'{"molecules": "aspirin,benzene"}\'')
parser.add_argument('--dataset-arg', default=None, help='Additional dataset arguments. Needs to be a dictionary.')
parser.add_argument('--coord-files', default=None, type=str, help='Custom coordinate files glob')
parser.add_argument('--embed-files', default=None, type=str, help='Custom embedding files glob')
parser.add_argument('--energy-files', default=None, type=str, help='Custom energy files glob')
Expand All @@ -74,8 +74,8 @@ def get_argparse():
# model architecture
parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')

parser.add_argument('--prior-model', type=str, choices=priors.__all__, default=None, help='Which prior model to use. It can be a string or a list of strings.', action="extend", nargs="*")
parser.add_argument('--prior-args', type=yaml.load, default=None, help='Additional prior arguments. Needs to be a dictionary or a list of dictionaries with the same size as the prior-model argument.', action="extend", nargs="*")
# architectural args
parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')
parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.')
Expand Down

0 comments on commit b1de9a8

Please sign in to comment.