Skip to content

Commit

Permalink
fix torch.no_grad for validation (#1156)
Browse files Browse the repository at this point in the history
bump version
  • Loading branch information
vince62s authored Jan 2, 2019
1 parent 27c6fd5 commit f240346
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
### New features

### Fixes and improvements
## [0.7.0](https://github.com/OpenNMT/OpenNMT-py/tree/0.7.0) (2019-01-02)
* Many fixes and code refactoring thanks @benopeters
* Migrated to Pytorch 1.0

## [0.6.0](https://github.com/OpenNMT/OpenNMT-py/tree/0.6.0) (2018-11-28)
* Many fixes and code improvements
Expand Down
2 changes: 1 addition & 1 deletion onmt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
__all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models,
onmt.utils, onmt.modules, "Trainer"]

__version__ = "0.6.0"
__version__ = "0.7.0"
36 changes: 19 additions & 17 deletions onmt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
users of this library) for the strategy things we do.
"""

import torch
import onmt.inputters as inputters
import onmt.utils

Expand Down Expand Up @@ -216,28 +217,29 @@ def validate(self, valid_iter):
# Set model in validating mode.
self.model.eval()

stats = onmt.utils.Statistics()
with torch.no_grad():
stats = onmt.utils.Statistics()

for batch in valid_iter:
src = inputters.make_features(batch, 'src', self.data_type)
if self.data_type == 'text':
_, src_lengths = batch.src
elif self.data_type == 'audio':
src_lengths = batch.src_lengths
else:
src_lengths = None
for batch in valid_iter:
src = inputters.make_features(batch, 'src', self.data_type)
if self.data_type == 'text':
_, src_lengths = batch.src
elif self.data_type == 'audio':
src_lengths = batch.src_lengths
else:
src_lengths = None

tgt = inputters.make_features(batch, 'tgt')
tgt = inputters.make_features(batch, 'tgt')

# F-prop through the model.
outputs, attns = self.model(src, tgt, src_lengths)
# F-prop through the model.
outputs, attns = self.model(src, tgt, src_lengths)

# Compute loss.
batch_stats = self.valid_loss.monolithic_compute_loss(
batch, outputs, attns)
# Compute loss.
batch_stats = self.valid_loss.monolithic_compute_loss(
batch, outputs, attns)

# Update statistics.
stats.update(batch_stats)
# Update statistics.
stats.update(batch_stats)

# Set model back to training mode.
self.model.train()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(name='OpenNMT-py',
description='A python implementation of OpenNMT',
version='0.6.0',
version='0.7.0',

packages=['onmt', 'onmt.encoders', 'onmt.modules', 'onmt.tests',
'onmt.translate', 'onmt.decoders', 'onmt.inputters',
Expand Down

0 comments on commit f240346

Please sign in to comment.