Skip to content

Commit

Permalink
Update Ensemble docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Mar 18, 2024
1 parent 2b99e77 commit d071cd3
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,17 @@ def forward(


class Ensemble(torch.nn.ModuleList):
"""Average predictions over an ensemble of TorchMD-Net models"""
"""Average predictions over an ensemble of TorchMD-Net models.
def __init__(self, modules):
This module behaves like a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with.
Args:
modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
"""

def __init__(self, modules: List[nn.Module]):
for module in modules:
assert isinstance(module, TorchMD_Net)
super().__init__(modules)

def forward(
Expand Down

0 comments on commit d071cd3

Please sign in to comment.