From 00c55903d81ce2b2c0500253d41e07b01c21a7df Mon Sep 17 00:00:00 2001 From: namsaraeva Date: Wed, 29 May 2024 17:21:29 +0200 Subject: [PATCH] change syntaxis --- src/sparcscore/ml/plmodels.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/sparcscore/ml/plmodels.py b/src/sparcscore/ml/plmodels.py index 5dd563e..56bc625 100644 --- a/src/sparcscore/ml/plmodels.py +++ b/src/sparcscore/ml/plmodels.py @@ -219,12 +219,12 @@ def training_step(self, batch): target = target.unsqueeze(1) output = self.network(data) # Forward pass, only one output - loss = self.configure_loss() + loss_func = self.configure_loss() if self.hparams["loss"] == "huber": # Huber loss - loss = loss(output, target, delta=self.hparams["huber_delta"], reduction='mean') + loss = loss_func(output, target, delta=self.hparams["huber_delta"], reduction='mean') else: # MSE - loss = loss(output, target) + loss = loss_func(output, target) self.log('loss/train', loss, on_step=False, on_epoch=True, prog_bar=True) self.log('mse/train', self.mse(output, target), on_epoch=True, prog_bar=True) @@ -237,12 +237,12 @@ def validation_step(self, batch): target = target.unsqueeze(1) output = self.network(data) - loss = self.configure_loss() + loss_func = self.configure_loss() if self.hparams["loss"] == "huber": # Huber loss - loss = loss(output, target, delta=self.hparams["huber_delta"], reduction='mean') + loss = loss_func(output, target, delta=self.hparams["huber_delta"], reduction='mean') else: # MSE - loss = loss(output, target) + loss = loss_func(output, target) self.log('loss/val', loss, on_step=False, on_epoch=True, prog_bar=True) self.log('mse/val', self.mse(output, target), on_epoch=True, prog_bar=True) @@ -255,12 +255,12 @@ def test_step(self, batch): target = target.unsqueeze(1) output = self.network(data) - loss = self.configure_loss() + loss_func = self.configure_loss() if self.hparams["loss"] == "huber": # Huber loss - loss = loss(output, target, delta=self.hparams["huber_delta"], reduction='mean') + loss = loss_func(output, target, delta=self.hparams["huber_delta"], reduction='mean') else: # MSE - loss = loss(output, target) + loss = loss_func(output, target) self.log('loss/test', loss, on_step=False, on_epoch=True, prog_bar=True) self.log('mse/test', self.mse(output, target), on_epoch=True, prog_bar=True)