From ddf0687b5ef5370ec9743fb43c83d661e6b08788 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 20 Nov 2023 10:16:19 +0100 Subject: [PATCH] Mark Transformer as deprecated --- torchmdnet/models/torchmd_t.py | 4 ++-- torchmdnet/utils.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 567482b66..e7478576f 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -9,9 +9,10 @@ act_class_mapping, scatter, ) +from torchmdnet.utils import deprecated_class - +@deprecated_class class TorchMD_T(nn.Module): r"""The TorchMD Transformer architecture. @@ -155,7 +156,6 @@ def forward( s: Optional[Tensor] = None, q: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: - x = self.embedding(z) edge_index, edge_weight, _ = self.distance(pos, batch) diff --git a/torchmdnet/utils.py b/torchmdnet/utils.py index d7810897b..eacdb57f0 100644 --- a/torchmdnet/utils.py +++ b/torchmdnet/utils.py @@ -4,6 +4,8 @@ import torch from os.path import dirname, join, exists from lightning_utilities.core.rank_zero import rank_zero_warn +import functools +import warnings # fmt: off # Atomic masses are based on: @@ -217,3 +219,23 @@ def number(text): class MissingEnergyException(Exception): pass + + +def deprecated_class(cls): + """Decorator to mark classes as deprecated.""" + orig_init = cls.__init__ + + @functools.wraps(orig_init) + def wrapped_init(self, *args, **kwargs): + warnings.simplefilter( + "always", DeprecationWarning + ) # ensure all deprecation warnings are shown + warnings.warn( + f"{cls.__name__} is deprecated and will be removed in a future version.", + category=DeprecationWarning, + stacklevel=2, + ) + orig_init(self, *args, **kwargs) + + cls.__init__ = wrapped_init + return cls