Skip to content

Commit

Permalink
Merge pull request #926 from vince62s/fix-bptt
Browse files Browse the repository at this point in the history
Fix BPTT (cf. #891)
  • Loading branch information
vince62s authored Aug 28, 2018
2 parents 57df743 + a085d85 commit e723f2a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
33 changes: 22 additions & 11 deletions onmt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _gradient_accumulation(self, true_batchs, normalization, total_stats,

for batch in true_batchs:
target_size = batch.tgt.size(0)
# Truncated BPTT
# Truncated BPTT: reminder not compatible with accum > 1
if self.trunc_size:
trunc_size = self.trunc_size
else:
Expand Down Expand Up @@ -287,20 +287,31 @@ def _gradient_accumulation(self, true_batchs, normalization, total_stats,
total_stats.update(batch_stats)
report_stats.update(batch_stats)

# 4. Update the parameters and statistics.
if self.grad_accum_count == 1:
# Multi GPU gradient gather
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
and p.grad is not None]
onmt.utils.distributed.all_reduce_and_rescale_tensors(
grads, float(1))
self.optim.step()

# If truncated, don't backprop fully.
if dec_state is not None:
dec_state.detach()

# 3.bis Multi GPU gradient gather
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
and p.grad is not None]
onmt.utils.distributed.all_reduce_and_rescale_tensors(
grads, float(1))

# 4. Update the parameters and statistics.
self.optim.step()
# in case of multi step gradient accumulation,
# update only after accum batches
if self.grad_accum_count > 1:
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
and p.grad is not None]
onmt.utils.distributed.all_reduce_and_rescale_tensors(
grads, float(1))
self.optim.step()

def _start_report_manager(self, start_time=None):
"""
Expand Down
3 changes: 3 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def main(opt):
if opt.epochs:
raise AssertionError("-epochs is deprecated please use -train_steps.")

if opt.truncated_decoder > 0 and opt.accum_count > 1:
raise AssertionError("BPTT is not compatible with -accum > 1")

if len(opt.gpuid) > 1:
multi_main(opt)
else:
Expand Down

0 comments on commit e723f2a

Please sign in to comment.