From f240346274925743379b72c221a47208cc4bd1d0 Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Wed, 2 Jan 2019 18:47:54 +0100 Subject: [PATCH] fix torch.no_grad for validation (#1156) bump version --- CHANGELOG.md | 3 +++ onmt/__init__.py | 2 +- onmt/trainer.py | 36 +++++++++++++++++++----------------- setup.py | 2 +- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8f9458932..86095333e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ ### New features ### Fixes and improvements +## [0.7.0](https://github.com/OpenNMT/OpenNMT-py/tree/0.7.0) (2019-01-02) +* Many fixes and code refactoring thanks @benopeters +* Migrated to Pytorch 1.0 ## [0.6.0](https://github.com/OpenNMT/OpenNMT-py/tree/0.6.0) (2018-11-28) * Many fixes and code improvements diff --git a/onmt/__init__.py b/onmt/__init__.py index 27acd2eada..cd42f1bf1d 100644 --- a/onmt/__init__.py +++ b/onmt/__init__.py @@ -17,4 +17,4 @@ __all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models, onmt.utils, onmt.modules, "Trainer"] -__version__ = "0.6.0" +__version__ = "0.7.0" diff --git a/onmt/trainer.py b/onmt/trainer.py index 61b53188bc..4077255e52 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -9,6 +9,7 @@ users of this library) for the strategy things we do. """ +import torch import onmt.inputters as inputters import onmt.utils @@ -216,28 +217,29 @@ def validate(self, valid_iter): # Set model in validating mode. self.model.eval() - stats = onmt.utils.Statistics() + with torch.no_grad(): + stats = onmt.utils.Statistics() - for batch in valid_iter: - src = inputters.make_features(batch, 'src', self.data_type) - if self.data_type == 'text': - _, src_lengths = batch.src - elif self.data_type == 'audio': - src_lengths = batch.src_lengths - else: - src_lengths = None + for batch in valid_iter: + src = inputters.make_features(batch, 'src', self.data_type) + if self.data_type == 'text': + _, src_lengths = batch.src + elif self.data_type == 'audio': + src_lengths = batch.src_lengths + else: + src_lengths = None - tgt = inputters.make_features(batch, 'tgt') + tgt = inputters.make_features(batch, 'tgt') - # F-prop through the model. - outputs, attns = self.model(src, tgt, src_lengths) + # F-prop through the model. + outputs, attns = self.model(src, tgt, src_lengths) - # Compute loss. - batch_stats = self.valid_loss.monolithic_compute_loss( - batch, outputs, attns) + # Compute loss. + batch_stats = self.valid_loss.monolithic_compute_loss( + batch, outputs, attns) - # Update statistics. - stats.update(batch_stats) + # Update statistics. + stats.update(batch_stats) # Set model back to training mode. self.model.train() diff --git a/setup.py b/setup.py index 943c501dbb..ac54606c0f 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup(name='OpenNMT-py', description='A python implementation of OpenNMT', - version='0.6.0', + version='0.7.0', packages=['onmt', 'onmt.encoders', 'onmt.modules', 'onmt.tests', 'onmt.translate', 'onmt.decoders', 'onmt.inputters',