From d071cd38324a96a87801098e6a179509efaf2db5 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 18 Mar 2024 13:43:31 +0100 Subject: [PATCH] Update Ensemble docstring --- torchmdnet/models/model.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 6a0ecf7f..618a024a 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -434,9 +434,17 @@ def forward( class Ensemble(torch.nn.ModuleList): - """Average predictions over an ensemble of TorchMD-Net models""" + """Average predictions over an ensemble of TorchMD-Net models. - def __init__(self, modules): + This module behaves like a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with. + + Args: + modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over. + """ + + def __init__(self, modules: List[nn.Module]): + for module in modules: + assert isinstance(module, TorchMD_Net) super().__init__(modules) def forward(