diff --git a/torchmdnet/datasets/ani.py b/torchmdnet/datasets/ani.py index 4149f0bf..1958a90e 100644 --- a/torchmdnet/datasets/ani.py +++ b/torchmdnet/datasets/ani.py @@ -269,11 +269,6 @@ def get_atomref(self, max_z=100): return refs.view(-1, 1) - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() - class ANI1XBase(ANIBase): @property @@ -350,16 +345,6 @@ def sample_iter(self, mol_ids=False): if data := self.filter_and_pre_transform(data): yield data - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def download(self): - super().download() - - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() - class ANI1CCX(ANI1XBase): __doc__ = ANIBase.__doc__ @@ -392,16 +377,6 @@ def sample_iter(self, mol_ids=False): if data := self.filter_and_pre_transform(data): yield data - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def download(self): - super().download() - - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() - class ANI2X(ANIBase): __doc__ = ANIBase.__doc__ @@ -477,8 +452,3 @@ def get_atomref(self, max_z=100): refs[key] = val * self.HARTREE_TO_EV return refs.view(-1, 1) - - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() diff --git a/torchmdnet/datasets/comp6.py b/torchmdnet/datasets/comp6.py index e0aa1a10..354fc167 100644 --- a/torchmdnet/datasets/comp6.py +++ b/torchmdnet/datasets/comp6.py @@ -223,17 +223,6 @@ def raw_url_name(self): def raw_file_names(self): return ["ani_md_bench.h5"] - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def download(self): - super().download() - - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() - - class DrugBank(COMP6Base): """ DrugBank Benchmark. This benchmark is developed through a subsampling of the @@ -249,16 +238,6 @@ class DrugBank(COMP6Base): def raw_file_names(self): return ["drugbank_testset.h5"] - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def download(self): - super().download() - - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() - class GDB07to09(COMP6Base): """ @@ -275,16 +254,6 @@ class GDB07to09(COMP6Base): def raw_file_names(self): return ["gdb11_07_test500.h5", "gdb11_08_test500.h5", "gdb11_09_test500.h5"] - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def download(self): - super().download() - - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() - class GDB10to13(COMP6Base): """ @@ -304,16 +273,6 @@ def raw_file_names(self): "gdb13_13_test1000.h5", ] - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def download(self): - super().download() - - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() - class Tripeptides(COMP6Base): """ @@ -329,16 +288,6 @@ class Tripeptides(COMP6Base): def raw_file_names(self): return ["tripeptide_full.h5"] - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def download(self): - super().download() - - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() - class S66X8(COMP6Base): """ @@ -362,16 +311,6 @@ def raw_url_name(self): def raw_file_names(self): return ["s66x8_wb97x6-31gd.h5"] - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def download(self): - super().download() - - # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 - # TODO remove when fixed - def process(self): - super().process() - class COMP6v1(Dataset): """ @@ -389,7 +328,7 @@ def __init__( self.subsets = [ DS(root, transform, pre_transform, pre_filter) - for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8) + for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8) ] self.num_samples = sum(len(subset) for subset in self.subsets)