diff --git a/train.py b/train.py index 306242e2..1c5c4fdd 100644 --- a/train.py +++ b/train.py @@ -142,6 +142,7 @@ def validation(self, epoch): output = self.model(image) loss = self.criterion(output, target) test_loss += loss.item() + tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1))) pred = output.data.cpu().numpy() target = target.cpu().numpy() pred = np.argmax(pred, axis=1)