diff --git a/torchmdnet/priors/atomref.py b/torchmdnet/priors/atomref.py index 677bd956..7613be43 100644 --- a/torchmdnet/priors/atomref.py +++ b/torchmdnet/priors/atomref.py @@ -34,6 +34,7 @@ def __init__(self, max_z=None, dataset=None, trainable=False, enable=True): if max_z is None and dataset is None: raise ValueError("Can't instantiate Atomref prior, all arguments are None.") if dataset is None: + assert max_z is not None, "max_z must be provided if dataset is None." atomref = torch.zeros(max_z, 1) else: atomref = dataset.get_atomref()