From 114e12a9627edc3e6da6f774cdf2b31483eea56c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 13 Feb 2024 15:19:08 +0100 Subject: [PATCH] Better documentation for create_prior_models --- torchmdnet/models/model.py | 39 +++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index adf30613..c05dddf5 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -167,7 +167,44 @@ def load_model(filepath, args=None, device="cpu", **kwargs): def create_prior_models(args, dataset=None): - """Parse the prior_model configuration option and create the prior models.""" + """Parse the prior_model configuration option and create the prior models. + + The information can be passed in different ways via the args dictionary, which must contain at least the key "prior_model". + + 1. A single prior model name and its arguments as a dictionary: + + ```python + args = { + "prior_model": "Atomref", + "prior_args": {"max_z": 100} + } + ``` + 2. A list of prior model names and their arguments as a list of dictionaries: + + ```python + + args = { + "prior_model": ["Atomref", "D2"], + "prior_args": [{"max_z": 100}, {"max_z": 100}] + } + ``` + + 3. A list of prior model names and their arguments as a dictionary: + + ```python + args = { + "prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}] + } + ``` + + Args: + args (dict): Arguments for the model. + dataset (torch_geometric.data.Dataset, optional): A dataset from which to extract the atomref values. Defaults to None. + + Returns: + list: A list of prior models. + + """ prior_models = [] if args["prior_model"]: prior_model = args["prior_model"]