diff --git a/torchmdnet/data.py b/torchmdnet/data.py index 0a8036b8c..b9d263e3f 100644 --- a/torchmdnet/data.py +++ b/torchmdnet/data.py @@ -25,26 +25,21 @@ def __init__(self, dataset, dtype=torch.float64): super(FloatCastDatasetWrapper, self).__init__( dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter ) - self.dataset = dataset - self.dtype = dtype + self._dataset = dataset + self._dtype = dtype def len(self): - return len(self.dataset) + return len(self._dataset) def get(self, idx): - data = self.dataset.get(idx) + data = self._dataset.get(idx) for key, value in data: if torch.is_tensor(value) and torch.is_floating_point(value): - setattr(data, key, value.to(self.dtype)) + setattr(data, key, value.to(self._dtype)) return data - def __getattr__(self, name): - # Check if the attribute exists in the underlying dataset - if hasattr(self.dataset, name): - return getattr(self.dataset, name) - raise AttributeError( - f"'{type(self).__name__}' and its underlying dataset have no attribute '{name}'" - ) + def __getattr__(self, __name): + return getattr(self.__dict__["_dataset"], __name) class DataModule(LightningDataModule):