diff --git a/torchmdnet/priors/atomref.py b/torchmdnet/priors/atomref.py index 82c4f087..7e8d3174 100644 --- a/torchmdnet/priors/atomref.py +++ b/torchmdnet/priors/atomref.py @@ -47,7 +47,7 @@ def __init__(self, max_z=None, dataset=None, trainable=False, enable=True): if atomref.ndim == 1: atomref = atomref.view(-1, 1) self.register_buffer("initial_atomref", atomref) - self.atomref = nn.Embedding(len(atomref), 1, freeze=trainable and not enable) + self.atomref = nn.Embedding(len(atomref), 1, _freeze=trainable and not enable) self.atomref.weight.data.copy_(atomref) self.enable = enable