From f35b34e91c8b2e8918276e993e5add49f92e7f1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 30 Jan 2020 18:39:37 +0100 Subject: [PATCH 01/12] add lambda_cosine, move normalization to compute_loss, adapt stats --- onmt/models/model.py | 16 +++++++++++-- onmt/opts.py | 2 ++ onmt/trainer.py | 22 +++++++++++------ onmt/utils/loss.py | 52 +++++++++++++++++++++++++++++++--------- onmt/utils/statistics.py | 18 ++++++++++++-- 5 files changed, 88 insertions(+), 22 deletions(-) diff --git a/onmt/models/model.py b/onmt/models/model.py index 920adcc981..fa41e5a699 100644 --- a/onmt/models/model.py +++ b/onmt/models/model.py @@ -1,5 +1,6 @@ """ Onmt NMT Model base class definition """ import torch.nn as nn +import torch class NMTModel(nn.Module): @@ -17,7 +18,8 @@ def __init__(self, encoder, decoder): self.encoder = encoder self.decoder = decoder - def forward(self, src, tgt, lengths, bptt=False, with_align=False): + def forward(self, src, tgt, lengths, bptt=False, + with_align=False, encode_tgt=False): """Forward propagate a `src` and `tgt` pair for training. Possible initialized with a beginning decoder state. @@ -44,12 +46,22 @@ def forward(self, src, tgt, lengths, bptt=False, with_align=False): enc_state, memory_bank, lengths = self.encoder(src, lengths) + if encode_tgt: + # tgt for zero shot alignment loss + tgt_lengths = torch.Tensor(tgt.size(1))\ + .type_as(memory_bank) \ + .long() \ + .fill_(tgt.size(0)) + embs_tgt, memory_bank_tgt, ltgt = self.encoder(tgt, tgt_lengths) + else: + memory_bank_tgt = None + if bptt is False: self.decoder.init_state(src, memory_bank, enc_state) dec_out, attns = self.decoder(dec_in, memory_bank, memory_lengths=lengths, with_align=with_align) - return dec_out, attns + return dec_out, attns, memory_bank, memory_bank_tgt def update_dropout(self, dropout): self.encoder.update_dropout(dropout) diff --git a/onmt/opts.py b/onmt/opts.py index af47f79836..e54e6353b4 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -193,6 +193,8 @@ def model_opts(parser): help='Train a coverage attention layer.') group.add('--lambda_coverage', '-lambda_coverage', type=float, default=0.0, help='Lambda value for coverage loss of See et al (2017)') + group.add('--lambda_cosine', '-lambda_cosine', type=float, default=0.0, + help='Lambda value for cosine alignment loss #TODO cite') group.add('--loss_scale', '-loss_scale', type=float, default=0, help="For FP16 training, the static loss scale to use. If not " "set, the loss scale is dynamically computed.") diff --git a/onmt/trainer.py b/onmt/trainer.py index 4328ca52ea..60653f72cc 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -70,7 +70,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): model_dtype=opt.model_dtype, earlystopper=earlystopper, dropout=dropout, - dropout_steps=dropout_steps) + dropout_steps=dropout_steps, + encode_tgt=True if opt.lambda_cosine > 0 else False) return trainer @@ -107,7 +108,8 @@ def __init__(self, model, train_loss, valid_loss, optim, n_gpu=1, gpu_rank=1, gpu_verbose_level=0, report_manager=None, with_align=False, model_saver=None, average_decay=0, average_every=1, model_dtype='fp32', - earlystopper=None, dropout=[0.3], dropout_steps=[0]): + earlystopper=None, dropout=[0.3], dropout_steps=[0], + encode_tgt=False): # Basic attributes. 
self.model = model self.train_loss = train_loss @@ -132,6 +134,7 @@ def __init__(self, model, train_loss, valid_loss, optim, self.earlystopper = earlystopper self.dropout = dropout self.dropout_steps = dropout_steps + self.encode_tgt = encode_tgt for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 @@ -314,11 +317,13 @@ def validate(self, valid_iter, moving_average=None): tgt = batch.tgt # F-prop through the model. - outputs, attns = valid_model(src, tgt, src_lengths, - with_align=self.with_align) + outputs, attns, enc_src, enc_tgt = valid_model( + src, tgt, src_lengths, + with_align=self.with_align) # Compute loss. - _, batch_stats = self.valid_loss(batch, outputs, attns) + _, batch_stats = self.valid_loss( + batch, outputs, attns, enc_src, enc_tgt) # Update statistics. stats.update(batch_stats) @@ -361,8 +366,9 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats, if self.accum_count == 1: self.optim.zero_grad() - outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt, - with_align=self.with_align) + outputs, attns, enc_src, enc_tgt = self.model( + src, tgt, src_lengths, bptt=bptt, + with_align=self.with_align, encode_tgt=self.encode_tgt) bptt = True # 3. Compute loss. @@ -371,6 +377,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats, batch, outputs, attns, + enc_src, + enc_tgt, normalization=normalization, shard_size=self.shard_size, trunc_start=j, diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index c48f0d3d21..71f1cf455d 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -58,7 +58,7 @@ def build_loss_compute(model, tgt_field, opt, train=True): else: compute = NMTLossCompute( criterion, loss_gen, lambda_coverage=opt.lambda_coverage, - lambda_align=opt.lambda_align) + lambda_align=opt.lambda_align, lambda_cosine=opt.lambda_cosine) compute.to(device) return compute @@ -123,6 +123,8 @@ def __call__(self, batch, output, attns, + enc_src, + enc_tgt, normalization=1.0, shard_size=0, trunc_start=0, @@ -157,18 +159,19 @@ def __call__(self, if trunc_size is None: trunc_size = batch.tgt.size(0) - trunc_start trunc_range = (trunc_start, trunc_start + trunc_size) - shard_state = self._make_shard_state(batch, output, trunc_range, attns) + shard_state = self._make_shard_state( + batch, output, enc_src, enc_tgt, trunc_range, attns) if shard_size == 0: - loss, stats = self._compute_loss(batch, **shard_state) - return loss / float(normalization), stats + loss, stats = self._compute_loss(batch, normalization, **shard_state) + return loss, stats batch_stats = onmt.utils.Statistics() for shard in shards(shard_state, shard_size): loss, stats = self._compute_loss(batch, **shard) - loss.div(float(normalization)).backward() + loss.backward() batch_stats.update(stats) return None, batch_stats - def _stats(self, loss, scores, target): + def _stats(self, loss, cosine_loss, scores, target, num_ex): """ Args: loss (:obj:`FloatTensor`): the loss computed by the loss criterion. 
@@ -182,7 +185,9 @@ def _stats(self, loss, scores, target): non_padding = target.ne(self.padding_idx) num_correct = pred.eq(target).masked_select(non_padding).sum().item() num_non_padding = non_padding.sum().item() - return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct) + return onmt.utils.Statistics( + loss.item(), cosine_loss.item() if cosine_loss is not None else 0, + num_non_padding, num_correct, num_ex) def _bottle(self, _v): return _v.view(-1, _v.size(2)) @@ -227,15 +232,18 @@ class NMTLossCompute(LossComputeBase): """ def __init__(self, criterion, generator, normalization="sents", - lambda_coverage=0.0, lambda_align=0.0): + lambda_coverage=0.0, lambda_align=0.0, lambda_cosine=0.0): super(NMTLossCompute, self).__init__(criterion, generator) self.lambda_coverage = lambda_coverage self.lambda_align = lambda_align + self.lambda_cosine = lambda_cosine - def _make_shard_state(self, batch, output, range_, attns=None): + def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns=None): shard_state = { "output": output, "target": batch.tgt[range_[0] + 1: range_[1], :, 0], + "enc_src": enc_src, + "enc_tgt": enc_tgt } if self.lambda_coverage != 0.0: coverage = attns.get("coverage", None) @@ -275,7 +283,7 @@ def _make_shard_state(self, batch, output, range_, attns=None): }) return shard_state - def _compute_loss(self, batch, output, target, std_attn=None, + def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt, std_attn=None, coverage_attn=None, align_head=None, ref_align=None): bottled_output = self._bottle(output) @@ -284,6 +292,7 @@ def _compute_loss(self, batch, output, target, std_attn=None, gtruth = target.view(-1) loss = self.criterion(scores, gtruth) + if self.lambda_coverage != 0.0: coverage_loss = self._compute_coverage_loss( std_attn=std_attn, coverage_attn=coverage_attn) @@ -296,7 +305,28 @@ def _compute_loss(self, batch, output, target, std_attn=None, align_loss = self._compute_alignement_loss( align_head=align_head, ref_align=ref_align) loss += align_loss - stats = self._stats(loss.clone(), scores, gtruth) + + loss = loss/float(normalization) + + if self.lambda_cosine != 0.0: + max_src = enc_src.max(axis=0)[0] + max_tgt = enc_tgt.max(axis=0)[0] + cosine_loss = torch.nn.functional.cosine_similarity( + max_src.float(), max_tgt.float(), dim=1) + ones = torch.ones(cosine_loss.size()).to(cosine_loss.device) + cosine_loss = ones - cosine_loss + num_ex = cosine_loss.size(0) + cosine_loss = cosine_loss.sum() + loss += self.lambda_cosine * (cosine_loss / num_ex) + else: + cosine_loss = None + num_ex = 0 + + + stats = self._stats(loss.clone() * normalization, + cosine_loss.clone() if cosine_loss is not None + else cosine_loss, + scores, gtruth, num_ex) return loss, stats diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index 896d98c74d..a6dc9583ef 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -17,12 +17,14 @@ class Statistics(object): * elapsed time """ - def __init__(self, loss=0, n_words=0, n_correct=0): + def __init__(self, loss=0, cosine_loss=0, n_words=0, n_correct=0, num_ex=0): self.loss = loss self.n_words = n_words self.n_correct = n_correct self.n_src_words = 0 self.start_time = time.time() + self.cosine_loss = cosine_loss + self.num_ex = num_ex @staticmethod def all_gather_stats(stat, max_size=4096): @@ -81,6 +83,10 @@ def update(self, stat, update_n_src_words=False): self.loss += stat.loss self.n_words += stat.n_words self.n_correct += stat.n_correct + # print("LOSS update", 
stat.loss) + # print("ZS_LOSS update", stat.zs_loss) + self.cosine_loss += stat.cosine_loss + self.num_ex += stat.num_ex if update_n_src_words: self.n_src_words += stat.n_src_words @@ -97,6 +103,10 @@ def ppl(self): """ compute perplexity """ return math.exp(min(self.loss / self.n_words, 100)) + def cos(self): + # print("ZS LOSS", self.zs_loss) + return self.cosine_loss / self.num_ex + def elapsed_time(self): """ compute elapsed time """ return time.time() - self.start_time @@ -113,8 +123,12 @@ def output(self, step, num_steps, learning_rate, start): step_fmt = "%2d" % step if num_steps > 0: step_fmt = "%s/%5d" % (step_fmt, num_steps) + if self.cosine_loss != 0: + cos_log = "cos: %4.2f; " % (self.cos()) + else: + cos_log = "" logger.info( - ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + + ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + cos_log + "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") % (step_fmt, self.accuracy(), From 845c989e832322b99ea643f46b05d2537b3100c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 30 Jan 2020 19:08:27 +0100 Subject: [PATCH 02/12] fix some flake --- onmt/utils/loss.py | 12 +++++++----- onmt/utils/statistics.py | 3 ++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index 71f1cf455d..e55ad15f26 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -162,7 +162,8 @@ def __call__(self, shard_state = self._make_shard_state( batch, output, enc_src, enc_tgt, trunc_range, attns) if shard_size == 0: - loss, stats = self._compute_loss(batch, normalization, **shard_state) + loss, stats = self._compute_loss(batch, normalization, + **shard_state) return loss, stats batch_stats = onmt.utils.Statistics() for shard in shards(shard_state, shard_size): @@ -238,7 +239,8 @@ def __init__(self, criterion, generator, normalization="sents", self.lambda_align = lambda_align self.lambda_cosine = lambda_cosine - def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns=None): + def _make_shard_state(self, batch, output, enc_src, enc_tgt, + range_, attns=None): shard_state = { "output": output, "target": batch.tgt[range_[0] + 1: range_[1], :, 0], @@ -283,7 +285,8 @@ def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns=None) }) return shard_state - def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt, std_attn=None, + def _compute_loss(self, batch, normalization, output, target, + enc_src, enc_tgt, std_attn=None, coverage_attn=None, align_head=None, ref_align=None): bottled_output = self._bottle(output) @@ -312,7 +315,7 @@ def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt, max_src = enc_src.max(axis=0)[0] max_tgt = enc_tgt.max(axis=0)[0] cosine_loss = torch.nn.functional.cosine_similarity( - max_src.float(), max_tgt.float(), dim=1) + max_src.float(), max_tgt.float(), dim=1) ones = torch.ones(cosine_loss.size()).to(cosine_loss.device) cosine_loss = ones - cosine_loss num_ex = cosine_loss.size(0) @@ -322,7 +325,6 @@ def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt, cosine_loss = None num_ex = 0 - stats = self._stats(loss.clone() * normalization, cosine_loss.clone() if cosine_loss is not None else cosine_loss, diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index a6dc9583ef..345f1a918f 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -17,7 +17,8 @@ class Statistics(object): * elapsed time """ - def __init__(self, loss=0, 
cosine_loss=0, n_words=0, n_correct=0, num_ex=0): + def __init__(self, loss=0, cosine_loss=0, n_words=0, + n_correct=0, num_ex=0): self.loss = loss self.n_words = n_words self.n_correct = n_correct From 84e472f3113befa9a3ed78639b1cce7fd990ebc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Mon, 3 Feb 2020 12:45:44 +0100 Subject: [PATCH 03/12] fix forward tests --- onmt/tests/test_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onmt/tests/test_models.py b/onmt/tests/test_models.py index 76dc5b48b1..638f4fc84c 100644 --- a/onmt/tests/test_models.py +++ b/onmt/tests/test_models.py @@ -134,7 +134,7 @@ def nmtmodel_forward(self, opt, source_l=3, bsize=1): test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize) - outputs, attn = model(test_src, test_tgt, test_length) + outputs, attn, _, _ = model(test_src, test_tgt, test_length) outputsize = torch.zeros(source_l - 1, bsize, opt.dec_rnn_size) # Make sure that output has the correct size and type self.assertEqual(outputs.size(), outputsize.size()) @@ -168,7 +168,7 @@ def imagemodel_forward(self, opt, tgt_l=2, bsize=1, h=15, w=17): h=h, w=w, bsize=bsize, tgt_l=tgt_l) - outputs, attn = model(test_src, test_tgt, test_length) + outputs, attn, _, _ = model(test_src, test_tgt, test_length) outputsize = torch.zeros(tgt_l - 1, bsize, opt.dec_rnn_size) # Make sure that output has the correct size and type self.assertEqual(outputs.size(), outputsize.size()) @@ -206,7 +206,7 @@ def audiomodel_forward(self, opt, tgt_l=7, bsize=3, t=37): sample_rate=opt.sample_rate, window_size=opt.window_size, t=t, tgt_l=tgt_l) - outputs, attn = model(test_src, test_tgt, test_length) + outputs, attn, _, _ = model(test_src, test_tgt, test_length) outputsize = torch.zeros(tgt_l - 1, bsize, opt.dec_rnn_size) # Make sure that output has the correct size and type self.assertEqual(outputs.size(), outputsize.size()) From 1a6e7373423d35b0fd029cbbce6f3931ca347193 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Mon, 3 Feb 2020 19:18:19 +0100 Subject: [PATCH 04/12] disable sharded loss if lambda_cosine --- onmt/utils/loss.py | 15 +++++++++------ onmt/utils/parse.py | 4 ++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index e55ad15f26..2a37b35594 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -167,7 +167,7 @@ def __call__(self, return loss, stats batch_stats = onmt.utils.Statistics() for shard in shards(shard_state, shard_size): - loss, stats = self._compute_loss(batch, **shard) + loss, stats = self._compute_loss(batch, normalization, **shard) loss.backward() batch_stats.update(stats) return None, batch_stats @@ -243,9 +243,7 @@ def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns=None): shard_state = { "output": output, - "target": batch.tgt[range_[0] + 1: range_[1], :, 0], - "enc_src": enc_src, - "enc_tgt": enc_tgt + "target": batch.tgt[range_[0] + 1: range_[1], :, 0] } if self.lambda_coverage != 0.0: coverage = attns.get("coverage", None) @@ -283,10 +281,15 @@ def _make_shard_state(self, batch, output, enc_src, enc_tgt, "align_head": attn_align, "ref_align": ref_align[:, range_[0] + 1: range_[1], :] }) + if self.lambda_cosine != 0.0: + shard_state.update({ + "enc_src": enc_src, + "enc_tgt": enc_tgt + }) return shard_state def _compute_loss(self, batch, normalization, output, target, - enc_src, enc_tgt, std_attn=None, + enc_src=None, enc_tgt=None, std_attn=None, coverage_attn=None, 
align_head=None, ref_align=None): bottled_output = self._bottle(output) @@ -400,7 +403,7 @@ def shards(state, shard_size, eval_only=False): # over the shards, not over the keys: therefore, the values need # to be re-zipped by shard and then each shard can be paired # with the keys. - for shard_tensors in zip(*values): + for i, shard_tensors in enumerate(zip(*values)): yield dict(zip(keys, shard_tensors)) # Assumed backprop'd diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 273dae3dba..993460f075 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -120,6 +120,10 @@ def validate_train_opts(cls, opt): assert len(opt.attention_dropout) == len(opt.dropout_steps), \ "Number of attention_dropout values must match accum_steps values" + assert not(opt.max_generator_batches > 0 and opt.lambda_cosine != 0), \ + "-lambda_cosine loss is not implemented for max_generator_batches > 0." + + @classmethod def validate_translate_opts(cls, opt): if opt.beam_size != 1 and opt.random_sampling_topk != 1: From e99daafae3d36e4250dd6370846c92e1950db8b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Mon, 3 Feb 2020 19:22:34 +0100 Subject: [PATCH 05/12] fix some flake --- onmt/utils/parse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 993460f075..ac6ddf6820 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -121,8 +121,8 @@ def validate_train_opts(cls, opt): "Number of attention_dropout values must match accum_steps values" assert not(opt.max_generator_batches > 0 and opt.lambda_cosine != 0), \ - "-lambda_cosine loss is not implemented for max_generator_batches > 0." - + "-lambda_cosine loss is not implemented " \ + "for max_generator_batches > 0." @classmethod def validate_translate_opts(cls, opt): From 1faea3cc27381fa079f6a50f15c88274aaf007ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Tue, 4 Feb 2020 11:37:17 +0100 Subject: [PATCH 06/12] move cosine loss compute to function, fix some args --- onmt/modules/copy_generator.py | 19 +++++++++++++++---- onmt/utils/loss.py | 22 +++++++++++++--------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/onmt/modules/copy_generator.py b/onmt/modules/copy_generator.py index 900096cf4d..503509e732 100644 --- a/onmt/modules/copy_generator.py +++ b/onmt/modules/copy_generator.py @@ -186,14 +186,14 @@ def __init__(self, criterion, generator, tgt_vocab, normalize_by_length, self.tgt_vocab = tgt_vocab self.normalize_by_length = normalize_by_length - def _make_shard_state(self, batch, output, range_, attns): + def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns): """See base class for args description.""" if getattr(batch, "alignment", None) is None: raise AssertionError("using -copy_attn you need to pass in " "-dynamic_dict during preprocess stage.") shard_state = super(CopyGeneratorLossCompute, self)._make_shard_state( - batch, output, range_, attns) + batch, output, enc_src, enc_tgt, range_, attns) shard_state.update({ "copy_attn": attns.get("copy"), @@ -201,7 +201,8 @@ def _make_shard_state(self, batch, output, range_, attns): }) return shard_state - def _compute_loss(self, batch, output, target, copy_attn, align, + def _compute_loss(self, batch, normalization, output, target, + copy_attn, align, enc_src=None, enc_tgt=None, std_attn=None, coverage_attn=None): """Compute the loss. 
@@ -244,8 +245,18 @@ def _compute_loss(self, batch, output, target, copy_attn, align, offset_align = align[correct_mask] + len(self.tgt_vocab) target_data[correct_mask] += offset_align + if self.lambda_cosine != 0.0: + cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt) + loss += self.lambda_cosine * (cosine_loss / num_ex) + else: + cosine_loss = None + num_ex = 0 + # Compute sum of perplexities for stats - stats = self._stats(loss.sum().clone(), scores_data, target_data) + stats = self._stats(loss.sum().clone(), + cosine_loss.clone() if cosine_loss is not None + else cosine_loss, + scores_data, target_data, num_ex) # this part looks like it belongs in CopyGeneratorLoss if self.normalize_by_length: diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index 2a37b35594..cef8326305 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -92,7 +92,8 @@ def __init__(self, criterion, generator): def padding_idx(self): return self.criterion.ignore_index - def _make_shard_state(self, batch, output, range_, attns=None): + def _make_shard_state(self, batch, enc_src, enc_tgt, + output, range_, attns=None): """ Make shard state dictionary for shards() to return iterable shards for efficient loss computation. Subclass must define @@ -315,14 +316,7 @@ def _compute_loss(self, batch, normalization, output, target, loss = loss/float(normalization) if self.lambda_cosine != 0.0: - max_src = enc_src.max(axis=0)[0] - max_tgt = enc_tgt.max(axis=0)[0] - cosine_loss = torch.nn.functional.cosine_similarity( - max_src.float(), max_tgt.float(), dim=1) - ones = torch.ones(cosine_loss.size()).to(cosine_loss.device) - cosine_loss = ones - cosine_loss - num_ex = cosine_loss.size(0) - cosine_loss = cosine_loss.sum() + cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt) loss += self.lambda_cosine * (cosine_loss / num_ex) else: cosine_loss = None @@ -340,6 +334,16 @@ def _compute_coverage_loss(self, std_attn, coverage_attn): covloss *= self.lambda_coverage return covloss + def _compute_cosine_loss(self, enc_src, enc_tgt): + max_src = enc_src.max(axis=0)[0] + max_tgt = enc_tgt.max(axis=0)[0] + cosine_loss = torch.nn.functional.cosine_similarity( + max_src.float(), max_tgt.float(), dim=1) + ones = torch.ones(cosine_loss.size()).to(cosine_loss.device) + cosine_loss = ones - cosine_loss + num_ex = cosine_loss.size(0) + return cosine_loss.sum(), num_ex + def _compute_alignement_loss(self, align_head, ref_align): """Compute loss between 2 partial alignment matrix.""" # align_head contains value in [0, 1) presenting attn prob, From f8bc7f4142f92071de393bce0b2a358f3e0fcdf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Tue, 4 Feb 2020 11:46:38 +0100 Subject: [PATCH 07/12] fix flake --- onmt/modules/copy_generator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onmt/modules/copy_generator.py b/onmt/modules/copy_generator.py index 503509e732..b5959ce92e 100644 --- a/onmt/modules/copy_generator.py +++ b/onmt/modules/copy_generator.py @@ -186,7 +186,8 @@ def __init__(self, criterion, generator, tgt_vocab, normalize_by_length, self.tgt_vocab = tgt_vocab self.normalize_by_length = normalize_by_length - def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns): + def _make_shard_state(self, batch, output, enc_src, enc_tgt, + range_, attns): """See base class for args description.""" if getattr(batch, "alignment", None) is None: raise AssertionError("using -copy_attn you need to pass in " From 9d263606251d7632ec3131ea4f1b26458c53173d Mon 
Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Fri, 7 Feb 2020 15:59:00 +0100 Subject: [PATCH 08/12] add arxiv link --- onmt/opts.py | 3 ++- onmt/utils/statistics.py | 4 +--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/onmt/opts.py b/onmt/opts.py index 609a9d9be5..191fc5c151 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -194,7 +194,8 @@ def model_opts(parser): group.add('--lambda_coverage', '-lambda_coverage', type=float, default=0.0, help='Lambda value for coverage loss of See et al (2017)') group.add('--lambda_cosine', '-lambda_cosine', type=float, default=0.0, - help='Lambda value for cosine alignment loss #TODO cite') + help='Lambda value for cosine alignment loss ' + 'of https://arxiv.org/abs/1903.07091 ') group.add('--loss_scale', '-loss_scale', type=float, default=0, help="For FP16 training, the static loss scale to use. If not " "set, the loss scale is dynamically computed.") diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index 345f1a918f..87a1e7f8f1 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -84,8 +84,6 @@ def update(self, stat, update_n_src_words=False): self.loss += stat.loss self.n_words += stat.n_words self.n_correct += stat.n_correct - # print("LOSS update", stat.loss) - # print("ZS_LOSS update", stat.zs_loss) self.cosine_loss += stat.cosine_loss self.num_ex += stat.num_ex @@ -105,7 +103,7 @@ def ppl(self): return math.exp(min(self.loss / self.n_words, 100)) def cos(self): - # print("ZS LOSS", self.zs_loss) + """ normalize cosine distance per example""" return self.cosine_loss / self.num_ex def elapsed_time(self): From 306b2e59aa2f344291205170f2c3a3d330025dfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Fri, 7 Feb 2020 17:38:02 +0100 Subject: [PATCH 09/12] broadcast instead of explicitly create ones --- onmt/utils/loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index cef8326305..bf1abf4caf 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -339,8 +339,7 @@ def _compute_cosine_loss(self, enc_src, enc_tgt): max_tgt = enc_tgt.max(axis=0)[0] cosine_loss = torch.nn.functional.cosine_similarity( max_src.float(), max_tgt.float(), dim=1) - ones = torch.ones(cosine_loss.size()).to(cosine_loss.device) - cosine_loss = ones - cosine_loss + cosine_loss = 1 - cosine_loss num_ex = cosine_loss.size(0) return cosine_loss.sum(), num_ex From 8616b99f96a58a22b485a2dbd03c9901c6d92f9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Mon, 17 Feb 2020 18:22:59 +0100 Subject: [PATCH 10/12] return encoder representations only if necessary --- onmt/models/model.py | 17 +++++++++-------- onmt/trainer.py | 33 ++++++++++++++++++++++++--------- onmt/utils/loss.py | 4 ++-- 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/onmt/models/model.py b/onmt/models/model.py index fa41e5a699..a3ce348413 100644 --- a/onmt/models/model.py +++ b/onmt/models/model.py @@ -46,6 +46,13 @@ def forward(self, src, tgt, lengths, bptt=False, enc_state, memory_bank, lengths = self.encoder(src, lengths) + if bptt is False: + self.decoder.init_state(src, memory_bank, enc_state) + + dec_out, attns = self.decoder(dec_in, memory_bank, + memory_lengths=lengths, + with_align=with_align) + if encode_tgt: # tgt for zero shot alignment loss tgt_lengths = torch.Tensor(tgt.size(1))\ @@ -53,15 +60,9 @@ def forward(self, src, tgt, lengths, bptt=False, .long() \ .fill_(tgt.size(0)) embs_tgt, memory_bank_tgt, 
ltgt = self.encoder(tgt, tgt_lengths) - else: - memory_bank_tgt = None + return dec_out, attns, memory_bank, memory_bank_tgt - if bptt is False: - self.decoder.init_state(src, memory_bank, enc_state) - dec_out, attns = self.decoder(dec_in, memory_bank, - memory_lengths=lengths, - with_align=with_align) - return dec_out, attns, memory_bank, memory_bank_tgt + return dec_out, attns def update_dropout(self, dropout): self.encoder.update_dropout(dropout) diff --git a/onmt/trainer.py b/onmt/trainer.py index 60653f72cc..973a5e1848 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -317,13 +317,21 @@ def validate(self, valid_iter, moving_average=None): tgt = batch.tgt # F-prop through the model. - outputs, attns, enc_src, enc_tgt = valid_model( - src, tgt, src_lengths, - with_align=self.with_align) + if self.encode_tgt: + outputs, attns, enc_src, enc_tgt = valid_model( + src, tgt, src_lengths, + with_align=self.with_align, + encode_tgt=self.encode_tgt) + else: + output, attns = valid_model( + src, tgt, src_lengths, + with_align=self.with_align) + enc_src, enc_tgt = None, None # Compute loss. _, batch_stats = self.valid_loss( - batch, outputs, attns, enc_src, enc_tgt) + batch, outputs, attns, + enc_src=enc_src, enc_tgt=enc_tgt) # Update statistics. stats.update(batch_stats) @@ -366,9 +374,16 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats, if self.accum_count == 1: self.optim.zero_grad() - outputs, attns, enc_src, enc_tgt = self.model( - src, tgt, src_lengths, bptt=bptt, - with_align=self.with_align, encode_tgt=self.encode_tgt) + if self.encode_tgt: + outputs, attns, enc_src, enc_tgt = self.model( + src, tgt, src_lengths, bptt=bptt, + with_align=self.with_align, encode_tgt=self.encode_tgt) + else: + output, attns = self.model( + src, tgt, src_lengths, bptt=bptt, + with_align=self.with_align) + enc_src, enc_tgt = None, None + bptt = True # 3. Compute loss. 
@@ -377,8 +392,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats, batch, outputs, attns, - enc_src, - enc_tgt, + enc_src=enc_src, + enc_tgt=enc_tgt, normalization=normalization, shard_size=self.shard_size, trunc_start=j, diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index bf1abf4caf..f185e1a567 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -124,8 +124,8 @@ def __call__(self, batch, output, attns, - enc_src, - enc_tgt, + enc_src=None, + enc_tgt=None, normalization=1.0, shard_size=0, trunc_start=0, From 75d645ab3f5d1f93fa7faf48a7cf5cfd74c31768 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Mon, 17 Feb 2020 18:33:07 +0100 Subject: [PATCH 11/12] roll back tests for model forward --- onmt/tests/test_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onmt/tests/test_models.py b/onmt/tests/test_models.py index 638f4fc84c..76dc5b48b1 100644 --- a/onmt/tests/test_models.py +++ b/onmt/tests/test_models.py @@ -134,7 +134,7 @@ def nmtmodel_forward(self, opt, source_l=3, bsize=1): test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize) - outputs, attn, _, _ = model(test_src, test_tgt, test_length) + outputs, attn = model(test_src, test_tgt, test_length) outputsize = torch.zeros(source_l - 1, bsize, opt.dec_rnn_size) # Make sure that output has the correct size and type self.assertEqual(outputs.size(), outputsize.size()) @@ -168,7 +168,7 @@ def imagemodel_forward(self, opt, tgt_l=2, bsize=1, h=15, w=17): h=h, w=w, bsize=bsize, tgt_l=tgt_l) - outputs, attn, _, _ = model(test_src, test_tgt, test_length) + outputs, attn = model(test_src, test_tgt, test_length) outputsize = torch.zeros(tgt_l - 1, bsize, opt.dec_rnn_size) # Make sure that output has the correct size and type self.assertEqual(outputs.size(), outputsize.size()) @@ -206,7 +206,7 @@ def audiomodel_forward(self, opt, tgt_l=7, bsize=3, t=37): sample_rate=opt.sample_rate, window_size=opt.window_size, t=t, tgt_l=tgt_l) - outputs, attn, _, _ = model(test_src, test_tgt, test_length) + outputs, attn = model(test_src, test_tgt, test_length) outputsize = torch.zeros(tgt_l - 1, bsize, opt.dec_rnn_size) # Make sure that output has the correct size and type self.assertEqual(outputs.size(), outputsize.size()) From f53ea4db2201a702503440b3a313f6e71935b0b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Mon, 17 Feb 2020 19:02:11 +0100 Subject: [PATCH 12/12] fix typo --- onmt/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onmt/trainer.py b/onmt/trainer.py index 973a5e1848..d60c1a0ddb 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -323,7 +323,7 @@ def validate(self, valid_iter, moving_average=None): with_align=self.with_align, encode_tgt=self.encode_tgt) else: - output, attns = valid_model( + outputs, attns = valid_model( src, tgt, src_lengths, with_align=self.with_align) enc_src, enc_tgt = None, None @@ -379,7 +379,7 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats, src, tgt, src_lengths, bptt=bptt, with_align=self.with_align, encode_tgt=self.encode_tgt) else: - output, attns = self.model( + outputs, attns = self.model( src, tgt, src_lengths, bptt=bptt, with_align=self.with_align) enc_src, enc_tgt = None, None