Skip to content

Commit

Permalink
Revert "Only overwrite the get function"
Browse files Browse the repository at this point in the history
This reverts commit 296f7c7.
  • Loading branch information
RaulPPelaez committed Jan 31, 2024
1 parent 36660bc commit 36e080e
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions torchmdnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,31 @@
from torchmdnet.models.utils import scatter
from torchmdnet.models.utils import dtype_mapping

class FloatCastDatasetWrapper:
"""
A helper class to modify the `get` method of a dataset for casting floating point tensors.

class FloatCastDatasetWrapper(Dataset):
"""A wrapper around a torch_geometric dataset that casts all floating point
tensors to a given dtype.
"""
def __init__(self, dataset, dtype):

def __init__(self, dataset, dtype=torch.float64):
super(FloatCastDatasetWrapper, self).__init__(
dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
)
self.dataset = dataset
self.dtype = dtype

def get_with_cast(self, idx):
"""
A modified `get` method that casts floating point tensors to the specified dtype.
"""
def len(self):
return len(self.dataset)

def get(self, idx):
data = self.dataset.get(idx)
for key, value in data.items(): # Assuming `data` is a dictionary
for key, value in data:
if torch.is_tensor(value) and torch.is_floating_point(value):
data[key] = value.to(self.dtype)
setattr(data, key, value.to(self.dtype))
return data

def adapt_floats_for_dtype(dataset, dtype=torch.float64):
"""
Modifies the `get` method of the given dataset to cast all floating point tensors to the specified dtype.
"""
adapter = FloatCastDatasetWrapper(dataset, dtype)
# Replace the original get method with the new one
setattr(dataset, 'get', adapter.get_with_cast)
return dataset
def __getattr__(self, name):
return getattr(self.dataset, name)


class DataModule(LightningDataModule):
Expand Down Expand Up @@ -83,7 +82,7 @@ def setup(self, stage):
self.hparams["dataset_root"], **dataset_arg
)

self.dataset = adapt_floats_for_dtype(
self.dataset = FloatCastDatasetWrapper(
self.dataset, dtype_mapping[self.hparams["precision"]]
)

Expand Down

0 comments on commit 36e080e

Please sign in to comment.