Skip to content

Commit

Permalink
Remove some old TODOs that are now solved
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Jan 26, 2024
1 parent d4778e1 commit 9da5319
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 92 deletions.
30 changes: 0 additions & 30 deletions torchmdnet/datasets/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,6 @@ def get_atomref(self, max_z=100):

return refs.view(-1, 1)

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class ANI1XBase(ANIBase):
@property
Expand Down Expand Up @@ -350,16 +345,6 @@ def sample_iter(self, mol_ids=False):
if data := self.filter_and_pre_transform(data):
yield data

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class ANI1CCX(ANI1XBase):
__doc__ = ANIBase.__doc__
Expand Down Expand Up @@ -392,16 +377,6 @@ def sample_iter(self, mol_ids=False):
if data := self.filter_and_pre_transform(data):
yield data

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class ANI2X(ANIBase):
__doc__ = ANIBase.__doc__
Expand Down Expand Up @@ -477,8 +452,3 @@ def get_atomref(self, max_z=100):
refs[key] = val * self.HARTREE_TO_EV

return refs.view(-1, 1)

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()
63 changes: 1 addition & 62 deletions torchmdnet/datasets/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,6 @@ def raw_url_name(self):
def raw_file_names(self):
return ["ani_md_bench.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class DrugBank(COMP6Base):
"""
DrugBank Benchmark. This benchmark is developed through a subsampling of the
Expand All @@ -249,16 +238,6 @@ class DrugBank(COMP6Base):
def raw_file_names(self):
return ["drugbank_testset.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class GDB07to09(COMP6Base):
"""
Expand All @@ -275,16 +254,6 @@ class GDB07to09(COMP6Base):
def raw_file_names(self):
return ["gdb11_07_test500.h5", "gdb11_08_test500.h5", "gdb11_09_test500.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class GDB10to13(COMP6Base):
"""
Expand All @@ -304,16 +273,6 @@ def raw_file_names(self):
"gdb13_13_test1000.h5",
]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class Tripeptides(COMP6Base):
"""
Expand All @@ -329,16 +288,6 @@ class Tripeptides(COMP6Base):
def raw_file_names(self):
return ["tripeptide_full.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class S66X8(COMP6Base):
"""
Expand All @@ -362,16 +311,6 @@ def raw_url_name(self):
def raw_file_names(self):
return ["s66x8_wb97x6-31gd.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class COMP6v1(Dataset):
"""
Expand All @@ -389,7 +328,7 @@ def __init__(

self.subsets = [
DS(root, transform, pre_transform, pre_filter)
for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8)
for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8)
]

self.num_samples = sum(len(subset) for subset in self.subsets)
Expand Down

0 comments on commit 9da5319

Please sign in to comment.