Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP v1 - deprecated] Add learning of target features #1710

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions onmt/bin/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,9 @@ def preprocess(opt):

src_nfeats = 0
tgt_nfeats = 0
for src, tgt in zip(opt.train_src, opt.train_tgt):
src_nfeats += count_features(src) if opt.data_type == 'text' \
else 0
tgt_nfeats += count_features(tgt) # tgt always text so far
src_nfeats = count_features(opt.train_src[0]) if opt.data_type == 'text' \
else 0
tgt_nfeats = count_features(opt.train_tgt[0]) # tgt always text so far
logger.info(" * number of source features: %d." % src_nfeats)
logger.info(" * number of target features: %d." % tgt_nfeats)

Expand Down
51 changes: 41 additions & 10 deletions onmt/inputters/inputter.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,20 +582,44 @@ def _pool(data, batch_size, batch_size_fn, batch_size_multiple,
yield b


class OrderedIterator(torchtext.data.Iterator):
class OnmtBatch(torchtext.data.Batch):
    """torchtext ``Batch`` that optionally time-shifts target features.

    When the target tensor carries feature columns in addition to the
    token column (last dimension > 1), the feature columns are rotated
    one step along the time axis, so that the features seen at step t
    are those of step t-1.  Pass ``feat_no_time_shift=True`` to keep the
    features aligned with their own token instead.
    """

    def __init__(self, data=None, dataset=None,
                 device=None, feat_no_time_shift=False):
        super(OnmtBatch, self).__init__(data, dataset, device)
        if feat_no_time_shift:
            return
        # Nothing to shift when there is no tgt side or no feature columns.
        if not hasattr(self, 'tgt') or self.tgt.size(-1) <= 1:
            return
        # tgt: [len x batch x (1 + num_feats)]; column 0 holds the tokens.
        words = self.tgt[:, :, :1]
        # Rotate features one step to the right along the time axis
        # (the last time step wraps around to position 0).
        shifted_feats = torch.roll(self.tgt[:, :, 1:], shifts=1, dims=0)
        self.tgt = torch.cat((words, shifted_feats), dim=-1)


class OrderedIterator(torchtext.data.Iterator):
def __init__(self,
dataset,
batch_size,
pool_factor=1,
batch_size_multiple=1,
yield_raw_example=False,
feat_no_time_shift=False,
**kwargs):
super(OrderedIterator, self).__init__(dataset, batch_size, **kwargs)
self.batch_size_multiple = batch_size_multiple
self.yield_raw_example = yield_raw_example
self.dataset = dataset
self.pool_factor = pool_factor
self.feat_no_time_shift = feat_no_time_shift

def create_batches(self):
if self.train:
Expand Down Expand Up @@ -627,7 +651,7 @@ def __iter__(self):
"""
Extended version of the definition in torchtext.data.Iterator.
Added yield_raw_example behaviour to yield a torchtext.data.Example
instead of a torchtext.data.Batch object.
instead of an OnmtBatch object.
"""
while True:
self.init_epoch()
Expand All @@ -648,10 +672,11 @@ def __iter__(self):
if self.yield_raw_example:
yield minibatch[0]
else:
yield torchtext.data.Batch(
yield OnmtBatch(
minibatch,
self.dataset,
self.device)
self.device,
feat_no_time_shift=self.feat_no_time_shift)
if not self.repeat:
return

Expand Down Expand Up @@ -683,6 +708,7 @@ def __init__(self,
self.sort_key = temp_dataset.sort_key
self.random_shuffler = RandomShuffler()
self.pool_factor = opt.pool_factor
self.feat_no_time_shift = opt.feat_no_time_shift
del temp_dataset

def _iter_datasets(self):
Expand All @@ -709,9 +735,10 @@ def __iter__(self):
self.random_shuffler,
self.pool_factor):
minibatch = sorted(minibatch, key=self.sort_key, reverse=True)
yield torchtext.data.Batch(minibatch,
self.iterables[0].dataset,
self.device)
yield OnmtBatch(minibatch,
self.iterables[0].dataset,
self.device,
feat_no_time_shift=self.feat_no_time_shift)


class DatasetLazyIter(object):
Expand All @@ -729,7 +756,8 @@ class DatasetLazyIter(object):

def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
batch_size_multiple, device, is_train, pool_factor,
repeat=True, num_batches_multiple=1, yield_raw_example=False):
repeat=True, num_batches_multiple=1, feat_no_time_shift=False,
yield_raw_example=False):
self._paths = dataset_paths
self.fields = fields
self.batch_size = batch_size
Expand All @@ -741,6 +769,7 @@ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
self.num_batches_multiple = num_batches_multiple
self.yield_raw_example = yield_raw_example
self.pool_factor = pool_factor
self.feat_no_time_shift = feat_no_time_shift

def _iter_dataset(self, path):
logger.info('Loading dataset from %s' % path)
Expand All @@ -758,7 +787,8 @@ def _iter_dataset(self, path):
sort=False,
sort_within_batch=True,
repeat=False,
yield_raw_example=self.yield_raw_example
yield_raw_example=self.yield_raw_example,
feat_no_time_shift=self.feat_no_time_shift
)
for batch in cur_iter:
self.dataset = cur_iter.dataset
Expand Down Expand Up @@ -852,7 +882,8 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False):
opt.pool_factor,
repeat=not opt.single_pass,
num_batches_multiple=max(opt.accum_count) * opt.world_size,
yield_raw_example=multi)
yield_raw_example=multi,
feat_no_time_shift=opt.feat_no_time_shift)


def build_dataset_iter_multiple(train_shards, fields, opt):
Expand Down
53 changes: 31 additions & 22 deletions onmt/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

from onmt.decoders import str2dec

from onmt.modules import Embeddings, VecEmbedding, CopyGenerator
from onmt.modules.util_class import Cast
from onmt.modules import Embeddings, VecEmbedding, Generator
from onmt.utils.misc import use_gpu
from onmt.utils.logging import logger
from onmt.utils.parse import ArgumentParser
Expand Down Expand Up @@ -88,6 +87,35 @@ def build_decoder(opt, embeddings):
return str2dec[dec_type].from_opt(opt, embeddings)


def build_generator(model_opt, fields, decoder):
    """Build the multi-headed output generator.

    One head is created per target stream: the first for target tokens,
    plus one per target feature (the sub-fields of ``fields['tgt']``).

    Args:
        model_opt: model options (reads ``rnn_size``, ``feat_vec_size``,
            ``share_decoder_embeddings``, ``generator_function``,
            ``copy_attn``).
        fields: dict of torchtext fields; ``fields['tgt']`` is a
            multi-field whose sub-fields carry the vocabularies.
        decoder: decoder whose embedding tables may be weight-tied to
            the generator projections.

    Returns:
        A ``Generator`` module.
    """
    gen_sizes = [len(subfield.vocab) for _, subfield in fields['tgt'].fields]
    # NOTE(review): uses model_opt.rnn_size where the previous generator
    # used dec_rnn_size — confirm these agree for all configurations.
    if model_opt.share_decoder_embeddings:
        # With tied weights the decoder output is sliced per head:
        # each feature head gets feat_vec_size, the token head the rest.
        n_feats = len(gen_sizes) - 1
        rnn_sizes = ([model_opt.rnn_size - model_opt.feat_vec_size * n_feats]
                     + [model_opt.feat_vec_size] * n_feats)
    else:
        rnn_sizes = [model_opt.rnn_size] * len(gen_sizes)

    if model_opt.generator_function == "sparsemax":
        gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
    else:
        gen_func = nn.LogSoftmax(dim=-1)

    tgt_base_field = fields["tgt"].base_field
    pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
    generator = Generator(rnn_sizes, gen_sizes, gen_func,
                          shared=model_opt.share_decoder_embeddings,
                          copy_attn=model_opt.copy_attn,
                          pad_idx=pad_idx)

    if model_opt.share_decoder_embeddings:
        # Tie each head's projection weights to the matching embedding
        # table.  The first head is a CopyGenerator when copy_attn is
        # set; it exposes its projection as `.linear` and is not
        # index-subscriptable, so `gen[0]` would raise here (the
        # pre-refactor code used `generator.linear.weight` for it).
        for gen, emb in zip(generator.generators,
                            decoder.embeddings.emb_luts):
            linear = gen.linear if hasattr(gen, 'linear') else gen[0]
            linear.weight = emb.weight

    return generator


def load_test_model(opt, model_path=None):
if model_path is None:
model_path = opt.models[0]
Expand Down Expand Up @@ -172,26 +200,7 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
model = onmt.models.NMTModel(encoder, decoder)

# Build Generator.
if not model_opt.copy_attn:
if model_opt.generator_function == "sparsemax":
gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
else:
gen_func = nn.LogSoftmax(dim=-1)
generator = nn.Sequential(
nn.Linear(model_opt.dec_rnn_size,
len(fields["tgt"].base_field.vocab)),
Cast(torch.float32),
gen_func
)
if model_opt.share_decoder_embeddings:
generator[0].weight = decoder.embeddings.word_lut.weight
else:
tgt_base_field = fields["tgt"].base_field
vocab_size = len(tgt_base_field.vocab)
pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx)
if model_opt.share_decoder_embeddings:
generator.linear.weight = decoder.embeddings.word_lut.weight
generator = build_generator(model_opt, fields, decoder)

# Load the model states from checkpoint or initialize them.
if checkpoint is not None:
Expand Down
3 changes: 2 additions & 1 deletion onmt/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from onmt.modules.gate import context_gate_factory, ContextGate
from onmt.modules.global_attention import GlobalAttention
from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention
from onmt.modules.generator import Generator
from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \
CopyGeneratorLossCompute
from onmt.modules.multi_headed_attn import MultiHeadedAttention
Expand All @@ -13,6 +14,6 @@

__all__ = ["Elementwise", "context_gate_factory", "ContextGate",
"GlobalAttention", "ConvMultiStepAttention", "CopyGenerator",
"CopyGeneratorLoss", "CopyGeneratorLossCompute",
"Generator", "CopyGeneratorLoss", "CopyGeneratorLossCompute",
"MultiHeadedAttention", "Embeddings", "PositionalEncoding",
"WeightNormConv2d", "AverageAttention", "VecEmbedding"]
53 changes: 53 additions & 0 deletions onmt/modules/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Multi-headed output generator for target tokens and target features."""
import torch
import torch.nn as nn

from onmt.modules.util_class import Cast

from onmt.modules.copy_generator import CopyGenerator


class Generator(nn.Module):
    """Multi-headed output generator: one head per target stream.

    The first head projects decoder states to the token vocabulary
    (either a plain linear+activation stack or a ``CopyGenerator`` when
    ``copy_attn`` is set); each additional head handles one target
    feature vocabulary.  With ``shared=True`` the decoder output is
    sliced column-wise so each head reads only its own ``rnn_sizes``
    segment (matching shared embedding widths).
    """

    def __init__(self, rnn_sizes, gen_sizes, gen_func,
                 shared=False, copy_attn=False, pad_idx=None):
        super(Generator, self).__init__()
        self.shared = shared
        self.rnn_sizes = rnn_sizes
        self.gen_sizes = gen_sizes
        self.generators = nn.ModuleList()

        def make_head(in_size, out_size):
            # Plain projection head: linear -> cast to fp32 -> gen_func.
            return nn.Sequential(
                nn.Linear(in_size, out_size),
                Cast(torch.float32),
                gen_func)

        # Token head: copy-generator or plain head.
        if copy_attn:
            self.generators.append(
                CopyGenerator(rnn_sizes[0], gen_sizes[0], pad_idx))
        else:
            self.generators.append(make_head(rnn_sizes[0], gen_sizes[0]))

        # One extra head per target feature.
        for in_size, out_size in zip(rnn_sizes[1:], gen_sizes[1:]):
            self.generators.append(make_head(in_size, out_size))

    def forward(self, dec_out):
        """Return one output per head, as a list.

        With shared embeddings, each head sees only its slice of the
        decoder output; otherwise every head sees the full dec_out.
        """
        if not self.shared:
            return [head(dec_out) for head in self.generators]
        outs = []
        offset = 0
        for head, width in zip(self.generators, self.rnn_sizes):
            # assumes dec_out is 2-D [n x rnn_size] at this point
            # (flattened over time/batch) — TODO confirm with callers.
            outs.append(head(dec_out[:, offset:offset + width]))
            offset += width
        return outs

    def __getitem__(self, i):
        # Backward-compat: `generator[i]` indexes into the token head,
        # as the old nn.Sequential generator did.
        # NOTE(review): this would raise if the token head is a
        # CopyGenerator (not subscriptable) — confirm callers only use
        # it in the non-copy case.
        return self.generators[0][i]
4 changes: 4 additions & 0 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def model_opts(parser):
help="If -feat_merge_size is not set, feature "
"embedding sizes will be set to N^feat_vec_exponent "
"where N is the number of values the feature takes.")
group.add('--feat_no_time_shift', '-feat_no_time_shift',
action='store_true',
help="If set, do not shift the target features one step "
"to the right.")

# Encoder-Decoder Options
group = parser.add_argument_group('Model- Encoder-Decoder')
Expand Down
6 changes: 3 additions & 3 deletions onmt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
used to save the model
"""

tgt_field = dict(fields)["tgt"].base_field
train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt)
tgt_fields = dict(fields)["tgt"]
train_loss = onmt.utils.loss.build_loss_compute(model, tgt_fields, opt)
valid_loss = onmt.utils.loss.build_loss_compute(
model, tgt_field, opt, train=False)
model, tgt_fields, opt, train=False)

trunc_size = opt.truncated_decoder # Badly named...
shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0
Expand Down
Loading