Skip to content

Commit

Permalink
Minor fix in loss scales for PredRNN++, PhyDNet and ST-Phy.
Browse files Browse the repository at this point in the history
  • Loading branch information
Flunzmas committed Mar 17, 2022
1 parent 321cc52 commit 50b0ad1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions vp_suite/models/phydnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def forward(self, x, pred_frames=1, **kwargs):
for b in range(0, self.phycell.cell_list[0].input_dim):
filters = self.phycell.cell_list[0].F.conv1.weight[:, b]
moment = k2m(filters.double()).float()
moment_loss += torch.mean(self.moment_loss_scale * (moment - self.constraints) ** 2)
model_losses = {"moment regularization loss": moment_loss}
moment_loss += torch.mean((moment - self.constraints) ** 2)
model_losses = {"moment regularization loss": self.moment_loss_scale * moment_loss}
else:
model_losses = None

Expand Down
2 changes: 1 addition & 1 deletion vp_suite/models/predrnn_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def forward(self, x, pred_frames: int = 1, **kwargs):
predictions_patch = torch.stack(next_frames[-pred_frames:], dim=1) # [b, t_pred, cpp, h_, w_]
predictions = self._reshape_patch_back(predictions_patch) # [b, t_pred, c, h, w]
decouple_loss = torch.mean(torch.stack(decouple_loss, dim=0))
return predictions, {"ST-LSTM decouple loss": decouple_loss}
return predictions, {"ST-LSTM decouple loss": self.decoupling_loss_scale * decouple_loss}

def _reshape_patch(self, x):
b, t, c, h, w = x.shape
Expand Down
5 changes: 3 additions & 2 deletions vp_suite/models/st_phy.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,10 @@ def forward(self, x, pred_frames=1, **kwargs):
filters = self.phycell_list[0].F.conv1.weight[:, b]
moment = k2m(filters.double()).float()
moment_loss += torch.mean(self.moment_loss_scale * (moment - self.constraints) ** 2)
decoupling_loss = torch.mean(torch.stack(decouple_loss, dim=0))
model_losses = {
"moment regularization loss": moment_loss,
"memory decoupling loss": torch.mean(torch.stack(decouple_loss, dim=0)),
"moment regularization loss": self.moment_loss_scale * moment_loss,
"memory decoupling loss": self.decoupling_loss_scale * decoupling_loss,
}
else:
model_losses = None
Expand Down

0 comments on commit 50b0ad1

Please sign in to comment.