Revert _MultiHeadAttention to MultiHeadAttention
RaulPPelaez committed Oct 30, 2023
1 parent: 6c98d69 · commit: 237f4b4
Showing 1 changed file with 3 additions and 3 deletions.
torchmdnet/models/torchmd_t.py: 6 changes (3 additions, 3 deletions)
@@ -115,7 +115,7 @@ def __init__(
 
         self.attention_layers = nn.ModuleList()
         for _ in range(num_layers):
-            layer = _MultiHeadAttention(
+            layer = MultiHeadAttention(
                 hidden_channels,
                 num_rbf,
                 distance_influence,
@@ -182,7 +182,7 @@ def __repr__(self):
         )
 
 
-class _MultiHeadAttention(MessagePassing):
+class MultiHeadAttention(MessagePassing):
     def __init__(
         self,
         hidden_channels,
@@ -195,7 +195,7 @@ def __init__(
         cutoff_upper,
         dtype=torch.float,
     ):
-        super(_MultiHeadAttention, self).__init__(aggr="add", node_dim=0)
+        super(MultiHeadAttention, self).__init__(aggr="add", node_dim=0)
         assert hidden_channels % num_heads == 0, (
             f"The number of hidden channels ({hidden_channels}) "
             f"must be evenly divisible by the number of "
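
For downstream code, the practical effect of this revert is that the attention layer is importable under its public, non-underscored name again. A minimal usage sketch, assuming the import path torchmdnet.models.torchmd_t (derived from the file path above) and illustrative values for hidden_channels and num_heads that satisfy the assert visible in the diff:

# Hypothetical sketch: after this revert the class is exposed without the
# leading underscore, so downstream imports use the public name.
from torchmdnet.models.torchmd_t import MultiHeadAttention  # previously _MultiHeadAttention

# Illustrative values; __init__ asserts that hidden_channels is evenly
# divisible by num_heads (here 128 / 8 = 16 channels per head).
hidden_channels, num_heads = 128, 8
assert hidden_channels % num_heads == 0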
