Skip to content

Commit

Permalink
Better documentation for create_prior_models
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Feb 13, 2024
1 parent 4643985 commit 114e12a
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,44 @@ def load_model(filepath, args=None, device="cpu", **kwargs):


def create_prior_models(args, dataset=None):
"""Parse the prior_model configuration option and create the prior models."""
"""Parse the prior_model configuration option and create the prior models.
The information can be passed in different ways via the args dictionary, which must contain at least the key "prior_model".
1. A single prior model name and its arguments as a dictionary:
```python
args = {
"prior_model": "Atomref",
"prior_args": {"max_z": 100}
}
```
2. A list of prior model names and their arguments as a list of dictionaries:
```python
args = {
"prior_model": ["Atomref", "D2"],
"prior_args": [{"max_z": 100}, {"max_z": 100}]
}
```
3. A list of prior model names and their arguments as a dictionary:
```python
args = {
"prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}]
}
```
Args:
args (dict): Arguments for the model.
dataset (torch_geometric.data.Dataset, optional): A dataset from which to extract the atomref values. Defaults to None.
Returns:
list: A list of prior models.
"""
prior_models = []
if args["prior_model"]:
prior_model = args["prior_model"]
Expand Down

0 comments on commit 114e12a

Please sign in to comment.