From 50b0ad11309d75d664ff875e62e9ef40f949a4f8 Mon Sep 17 00:00:00 2001 From: Andreas Boltres Date: Thu, 17 Mar 2022 12:26:05 +0100 Subject: [PATCH] Minor fix in loss scales for PredRNN++, PhyDNet and ST-Phy. --- vp_suite/models/phydnet.py | 4 ++-- vp_suite/models/predrnn_v2.py | 2 +- vp_suite/models/st_phy.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vp_suite/models/phydnet.py b/vp_suite/models/phydnet.py index b8e5771..c161348 100644 --- a/vp_suite/models/phydnet.py +++ b/vp_suite/models/phydnet.py @@ -129,8 +129,8 @@ def forward(self, x, pred_frames=1, **kwargs): for b in range(0, self.phycell.cell_list[0].input_dim): filters = self.phycell.cell_list[0].F.conv1.weight[:, b] moment = k2m(filters.double()).float() - moment_loss += torch.mean(self.moment_loss_scale * (moment - self.constraints) ** 2) - model_losses = {"moment regularization loss": moment_loss} + moment_loss += torch.mean((moment - self.constraints) ** 2) + model_losses = {"moment regularization loss": self.moment_loss_scale * moment_loss} else: model_losses = None diff --git a/vp_suite/models/predrnn_v2.py b/vp_suite/models/predrnn_v2.py index dd5b9bb..fc466f8 100644 --- a/vp_suite/models/predrnn_v2.py +++ b/vp_suite/models/predrnn_v2.py @@ -227,7 +227,7 @@ def forward(self, x, pred_frames: int = 1, **kwargs): predictions_patch = torch.stack(next_frames[-pred_frames:], dim=1) # [b, t_pred, cpp, h_, w_] predictions = self._reshape_patch_back(predictions_patch) # [b, t_pred, c, h, w] decouple_loss = torch.mean(torch.stack(decouple_loss, dim=0)) - return predictions, {"ST-LSTM decouple loss": decouple_loss} + return predictions, {"ST-LSTM decouple loss": self.decoupling_loss_scale * decouple_loss} def _reshape_patch(self, x): b, t, c, h, w = x.shape diff --git a/vp_suite/models/st_phy.py b/vp_suite/models/st_phy.py index 8f012f0..a3404a1 100644 --- a/vp_suite/models/st_phy.py +++ b/vp_suite/models/st_phy.py @@ -171,9 +171,10 @@ def forward(self, x, pred_frames=1, **kwargs): filters = self.phycell_list[0].F.conv1.weight[:, b] moment = k2m(filters.double()).float() moment_loss += torch.mean(self.moment_loss_scale * (moment - self.constraints) ** 2) + decoupling_loss = torch.mean(torch.stack(decouple_loss, dim=0)) model_losses = { - "moment regularization loss": moment_loss, - "memory decoupling loss": torch.mean(torch.stack(decouple_loss, dim=0)), + "moment regularization loss": self.moment_loss_scale * moment_loss, + "memory decoupling loss": self.decoupling_loss_scale * decoupling_loss, } else: model_losses = None