Skip to content

Commit

Permalink
Fixed committor forward
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoTrizio committed Dec 23, 2024
1 parent e59d038 commit e6d90f8
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions mlcolvar/cvs/committor/committor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,12 @@ def __init__(
if (options[o] is not False) and (options[o] is not None):
self.sigmoid = Custom_Sigmoid(**options[o])

def forward(self, x):
def forward_nn(self, x):
if self.preprocessing is not None:
x = self.preprocessing(x)
z = self.nn(x)
q = self.sigmoid(z)

return torch.hstack([z, q])
return z

def training_step(self, train_batch, batch_idx):
"""Compute and return the training loss and record metrics."""
Expand All @@ -139,9 +138,11 @@ def training_step(self, train_batch, batch_idx):
weights = x['weight'].clone()

# =================forward====================
out = self.forward(x)
z = out[:, 0]
q = out[:, 1]
z = self.forward_nn(x)
if self.sigmoid is not None:
q = self.sigmoid(z)
else:
q = z

# ===================loss=====================
if self.training:
Expand Down

0 comments on commit e6d90f8

Please sign in to comment.