Skip to content

Commit

Permalink
Mark Transformer as deprecated
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Nov 20, 2023
1 parent be6a214 commit ddf0687
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchmdnet/models/torchmd_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
act_class_mapping,
scatter,
)
from torchmdnet.utils import deprecated_class



@deprecated_class
class TorchMD_T(nn.Module):
r"""The TorchMD Transformer architecture.
Expand Down Expand Up @@ -155,7 +156,6 @@ def forward(
s: Optional[Tensor] = None,
q: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:

x = self.embedding(z)

edge_index, edge_weight, _ = self.distance(pos, batch)
Expand Down
22 changes: 22 additions & 0 deletions torchmdnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch
from os.path import dirname, join, exists
from lightning_utilities.core.rank_zero import rank_zero_warn
import functools
import warnings

# fmt: off
# Atomic masses are based on:
Expand Down Expand Up @@ -217,3 +219,23 @@ def number(text):

class MissingEnergyException(Exception):
pass


def deprecated_class(cls):
"""Decorator to mark classes as deprecated."""
orig_init = cls.__init__

@functools.wraps(orig_init)
def wrapped_init(self, *args, **kwargs):
warnings.simplefilter(
"always", DeprecationWarning
) # ensure all deprecation warnings are shown
warnings.warn(
f"{cls.__name__} is deprecated and will be removed in a future version.",
category=DeprecationWarning,
stacklevel=2,
)
orig_init(self, *args, **kwargs)

cls.__init__ = wrapped_init
return cls

0 comments on commit ddf0687

Please sign in to comment.