Skip to content

Commit

Permalink
Merge pull request #261 from RaulPPelaez/remove_todos
Browse files Browse the repository at this point in the history
Remove some old TODOs that are now solved
  • Loading branch information
stefdoerr authored Jan 29, 2024
2 parents 7e6dbeb + 9da5319 commit 9108514
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 92 deletions.
30 changes: 0 additions & 30 deletions torchmdnet/datasets/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,6 @@ def get_atomref(self, max_z=100):

return refs.view(-1, 1)

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class ANI1XBase(ANIBase):
@property
Expand Down Expand Up @@ -225,16 +220,6 @@ def sample_iter(self, mol_ids=False):
if data := self.filter_and_pre_transform(data):
yield data

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class ANI1CCX(ANI1XBase):
__doc__ = ANIBase.__doc__
Expand Down Expand Up @@ -267,16 +252,6 @@ def sample_iter(self, mol_ids=False):
if data := self.filter_and_pre_transform(data):
yield data

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class ANI2X(ANIBase):
__doc__ = ANIBase.__doc__
Expand Down Expand Up @@ -352,8 +327,3 @@ def get_atomref(self, max_z=100):
refs[key] = val * self.HARTREE_TO_EV

return refs.view(-1, 1)

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()
63 changes: 1 addition & 62 deletions torchmdnet/datasets/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,6 @@ def raw_url_name(self):
def raw_file_names(self):
return ["ani_md_bench.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class DrugBank(COMP6Base):
"""
DrugBank Benchmark. This benchmark is developed through a subsampling of the
Expand All @@ -161,16 +150,6 @@ class DrugBank(COMP6Base):
def raw_file_names(self):
return ["drugbank_testset.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class GDB07to09(COMP6Base):
"""
Expand All @@ -187,16 +166,6 @@ class GDB07to09(COMP6Base):
def raw_file_names(self):
return ["gdb11_07_test500.h5", "gdb11_08_test500.h5", "gdb11_09_test500.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class GDB10to13(COMP6Base):
"""
Expand All @@ -216,16 +185,6 @@ def raw_file_names(self):
"gdb13_13_test1000.h5",
]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class Tripeptides(COMP6Base):
"""
Expand All @@ -241,16 +200,6 @@ class Tripeptides(COMP6Base):
def raw_file_names(self):
return ["tripeptide_full.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class S66X8(COMP6Base):
"""
Expand All @@ -274,16 +223,6 @@ def raw_url_name(self):
def raw_file_names(self):
return ["s66x8_wb97x6-31gd.h5"]

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def download(self):
super().download()

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()


class COMP6v1(Dataset):
"""
Expand All @@ -301,7 +240,7 @@ def __init__(

self.subsets = [
DS(root, transform, pre_transform, pre_filter)
for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8)
for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8)
]

self.num_samples = sum(len(subset) for subset in self.subsets)
Expand Down

0 comments on commit 9108514

Please sign in to comment.