Skip to content

Commit

Permalink
Blacken
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Feb 13, 2024
1 parent 2df7fd6 commit 4643985
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torchmdnet/priors/atomref.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,20 @@ def __init__(self, max_z=None, dataset=None, trainable=False, enable=True):
if atomref.ndim == 1:
atomref = atomref.view(-1, 1)
self.register_buffer("initial_atomref", atomref)
self.atomref = nn.Embedding(len(atomref), 1, _freeze=not trainable, _weight=atomref)
self.atomref = nn.Embedding(
len(atomref), 1, _freeze=not trainable, _weight=atomref
)
self.enable = enable

def reset_parameters(self):
self.atomref.weight.data.copy_(self.initial_atomref)

def get_init_args(self):
return dict(max_z=self.initial_atomref.size(0), trainable=self.atomref.weight.requires_grad, enable=self.enable)
return dict(
max_z=self.initial_atomref.size(0),
trainable=self.atomref.weight.requires_grad,
enable=self.enable,
)

def pre_reduce(
self,
Expand Down

0 comments on commit 4643985

Please sign in to comment.