forked from tensorflow/nmt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathiterator_utils.py
248 lines (214 loc) · 9.37 KB
/
iterator_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""For loading data into NMT models."""
from __future__ import print_function
import collections
import tensorflow as tf
from ..utils import vocab_utils
__all__ = ["BatchedInput", "get_iterator", "get_infer_iterator"]
# NOTE(ebrevdo): When we subclass this, instances' __dict__ becomes empty.
class BatchedInput(
        collections.namedtuple(
            "BatchedInput",
            [
                "initializer",
                "source",
                "target_input",
                "target_output",
                "source_sequence_length",
                "target_sequence_length",
            ])):
    """Immutable record bundling one batch of inputs with its iterator op.

    Fields, as filled in by the iterator builders in this module:
      initializer: the `initializer` of the underlying initializable
        iterator.
      source: batched source ids.
      target_input: batched target ids prefixed with the sos id, or None
        when built for inference.
      target_output: batched target ids suffixed with the eos id, or None
        when built for inference.
      source_sequence_length: per-example source lengths.
      target_sequence_length: per-example target lengths, or None when
        built for inference.
    """
def get_infer_iterator(src_dataset,
                       src_vocab_table,
                       batch_size,
                       eos,
                       src_max_len=None,
                       use_char_encode=False):
    """Build a source-only, padded-batch iterator for inference.

    Each element of `src_dataset` is split with `tf.string_split`,
    optionally truncated, converted to integer ids, paired with its length
    and finally grouped with `padded_batch`.

    Args:
      src_dataset: dataset whose elements are split into tokens via
        `tf.string_split([element])` (presumably one text line per
        element -- confirm against the caller).
      src_vocab_table: object whose `lookup` maps token strings to ids;
        only used when `use_char_encode` is false.
      batch_size: number of examples per padded batch.
      eos: end-of-sentence token; its id pads the source sequences when
        `use_char_encode` is false.
      src_max_len: if truthy, keep only the first `src_max_len` tokens.
      use_char_encode: if true, tokens are converted with
        `vocab_utils.tokens_to_bytes` and flattened to 1-D, padding uses
        `vocab_utils.EOS_CHAR_ID`, and the reported length is the
        flattened size divided by `vocab_utils.DEFAULT_CHAR_MAXLEN`.

    Returns:
      A `BatchedInput` whose target-related fields are all None.
    """
    # Id used to pad the source sequences inside a batch.
    if use_char_encode:
        pad_id = vocab_utils.EOS_CHAR_ID
    else:
        pad_id = tf.cast(
            src_vocab_table.lookup(tf.constant(eos)), tf.int32)

    def split_tokens(line):
        return tf.string_split([line]).values

    tokens = src_dataset.map(split_tokens)

    if src_max_len:
        def truncate(toks):
            return toks[:src_max_len]

        tokens = tokens.map(truncate)

    if use_char_encode:
        def encode(toks):
            # Word strings -> character ids, flattened to one vector.
            return tf.reshape(vocab_utils.tokens_to_bytes(toks), [-1])

        def attach_length(ids):
            # Report the word count rather than the flattened size.
            word_count = tf.to_int32(
                tf.size(ids) / vocab_utils.DEFAULT_CHAR_MAXLEN)
            return ids, word_count
    else:
        def encode(toks):
            # Word strings -> word ids.
            return tf.cast(src_vocab_table.lookup(toks), tf.int32)

        def attach_length(ids):
            return ids, tf.size(ids)

    ids_with_len = tokens.map(encode).map(attach_length)

    # Sources are unknown-length vectors padded with `pad_id`; lengths are
    # scalars, so their padding value (0) is never actually used. Padding
    # with eos is mostly cosmetic since computation past the true sequence
    # length is expected to be masked out later on.
    batched = ids_with_len.padded_batch(
        batch_size,
        padded_shapes=(tf.TensorShape([None]), tf.TensorShape([])),
        padding_values=(pad_id, 0))

    iterator = batched.make_initializable_iterator()
    src_ids, src_seq_len = iterator.get_next()
    return BatchedInput(
        initializer=iterator.initializer,
        source=src_ids,
        target_input=None,
        target_output=None,
        source_sequence_length=src_seq_len,
        target_sequence_length=None)
def get_iterator(src_dataset,
                 tgt_dataset,
                 src_vocab_table,
                 tgt_vocab_table,
                 batch_size,
                 sos,
                 eos,
                 random_seed,
                 num_buckets,
                 src_max_len=None,
                 tgt_max_len=None,
                 num_parallel_calls=4,
                 output_buffer_size=None,
                 skip_count=None,
                 num_shards=1,
                 shard_index=0,
                 reshuffle_each_iteration=True,
                 use_char_encode=False):
    """Build a shuffled, optionally bucketed, padded-batch iterator.

    The source and target datasets are zipped, sharded, optionally skipped,
    shuffled, split into tokens, filtered for empty sequences, truncated,
    converted to ids, decorated with sos/eos and lengths, and batched.

    Args:
      src_dataset: dataset of source elements, split into tokens with
        `tf.string_split([element])`.
      tgt_dataset: dataset of target elements, zipped with `src_dataset`
        and split the same way.
      src_vocab_table: object whose `lookup` maps source tokens to ids;
        not used when `use_char_encode` is true.
      tgt_vocab_table: object whose `lookup` maps target tokens to ids.
      batch_size: number of examples per batch; also used as the window
        size when bucketing.
      sos: start-of-sentence token, looked up in `tgt_vocab_table` and
        prepended to every target input.
      eos: end-of-sentence token, looked up in the vocab tables; appended
        to every target output and used as the padding value.
      random_seed: seed passed to `Dataset.shuffle`.
      num_buckets: when greater than 1, examples are grouped into at most
        `num_buckets` length buckets before batching.
      src_max_len: if truthy, sources are truncated to this many tokens
        and it determines the bucket width.
      tgt_max_len: if truthy, targets are truncated to this many tokens
        (before sos/eos are added).
      num_parallel_calls: parallelism passed to the `map` stages.
      output_buffer_size: shuffle and prefetch buffer size; defaults to
        `batch_size * 1000` when falsy.
      skip_count: if not None, number of (sharded) elements to skip before
        shuffling.
      num_shards: total number of shards passed to `Dataset.shard`.
      shard_index: index of the shard this iterator reads.
      reshuffle_each_iteration: forwarded to `Dataset.shuffle`.
      use_char_encode: if true, source tokens are converted with
        `vocab_utils.tokens_to_bytes` and flattened to 1-D, source padding
        uses `vocab_utils.EOS_CHAR_ID`, and the source length is the
        flattened size divided by `vocab_utils.DEFAULT_CHAR_MAXLEN`.

    Returns:
      A `BatchedInput` with every field populated.
    """
    if not output_buffer_size:
        output_buffer_size = batch_size * 1000

    if use_char_encode:
        src_eos_id = vocab_utils.EOS_CHAR_ID
    else:
        src_eos_id = tf.cast(
            src_vocab_table.lookup(tf.constant(eos)), tf.int32)

    tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
    tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)

    src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))

    src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
    if skip_count is not None:
        src_tgt_dataset = src_tgt_dataset.skip(skip_count)

    src_tgt_dataset = src_tgt_dataset.shuffle(
        output_buffer_size, random_seed, reshuffle_each_iteration)

    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (
            tf.string_split([src]).values, tf.string_split([tgt]).values),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

    # Filter zero length input sequences.
    src_tgt_dataset = src_tgt_dataset.filter(
        lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))

    if src_max_len:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (src[:src_max_len], tgt),
            num_parallel_calls=num_parallel_calls).prefetch(
                output_buffer_size)
    if tgt_max_len:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (src, tgt[:tgt_max_len]),
            num_parallel_calls=num_parallel_calls).prefetch(
                output_buffer_size)

    # Convert the word strings to ids.  Word strings that are not in the
    # vocab get the lookup table's default_value integer.
    if use_char_encode:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (
                tf.reshape(vocab_utils.tokens_to_bytes(src), [-1]),
                tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
            num_parallel_calls=num_parallel_calls)
    else:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (
                tf.cast(src_vocab_table.lookup(src), tf.int32),
                tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
            num_parallel_calls=num_parallel_calls)

    src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size)

    # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with
    # <eos>.
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src,
                          tf.concat(([tgt_sos_id], tgt), 0),
                          tf.concat((tgt, [tgt_eos_id]), 0)),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

    # Add in sequence lengths.
    if use_char_encode:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt_in, tgt_out: (
                src, tgt_in, tgt_out,
                tf.to_int32(tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN),
                tf.size(tgt_in)),
            num_parallel_calls=num_parallel_calls)
    else:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt_in, tgt_out: (
                src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
            num_parallel_calls=num_parallel_calls)

    src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size)

    def batching_func(x):
        """Pad and batch a dataset of (src, tgt_in, tgt_out, lens) tuples."""
        return x.padded_batch(
            batch_size,
            # The first three entries are the source and target line rows;
            # these have unknown-length vectors.  The last two entries are
            # the source and target row sizes; these are scalars.
            padded_shapes=(
                tf.TensorShape([None]),  # src
                tf.TensorShape([None]),  # tgt_input
                tf.TensorShape([None]),  # tgt_output
                tf.TensorShape([]),  # src_len
                tf.TensorShape([])),  # tgt_len
            # Pad the source and target sequences with eos tokens.
            # (Though notice we don't generally need to do this since
            # later on we will be masking out calculations past the true
            # sequence.)
            padding_values=(
                src_eos_id,  # src
                tgt_eos_id,  # tgt_input
                tgt_eos_id,  # tgt_output
                0,  # src_len -- unused
                0))  # tgt_len -- unused

    if num_buckets > 1:
        # Bucket width is derived from the maximum source length when it is
        # known; otherwise buckets cover lengths 0-9, 10-19, ...
        # It only depends on Python values, so compute it once up front.
        if src_max_len:
            bucket_width = (src_max_len + num_buckets - 1) // num_buckets
        else:
            bucket_width = 10

        def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
            """Map an example to its bucket id in [0, num_buckets - 1]."""
            # Pairs with length [0, bucket_width) go to bucket 0, length
            # [bucket_width, 2 * bucket_width) go to bucket 1, etc.  Pairs
            # with length over ((num_buckets - 1) * bucket_width) words all
            # go into the last bucket.
            # Bucket sentence pairs by the longer of their source and
            # target sentence.
            bucket_id = tf.maximum(
                src_len // bucket_width, tgt_len // bucket_width)
            # Clamp to num_buckets - 1 (the last valid id).  Clamping to
            # num_buckets would silently create an extra bucket.
            return tf.to_int64(tf.minimum(num_buckets - 1, bucket_id))

        def reduce_func(unused_key, windowed_data):
            return batching_func(windowed_data)

        batched_dataset = src_tgt_dataset.apply(
            tf.contrib.data.group_by_window(
                key_func=key_func,
                reduce_func=reduce_func,
                window_size=batch_size))
    else:
        batched_dataset = batching_func(src_tgt_dataset)

    batched_iter = batched_dataset.make_initializable_iterator()
    (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len,
     tgt_seq_len) = batched_iter.get_next()
    return BatchedInput(
        initializer=batched_iter.initializer,
        source=src_ids,
        target_input=tgt_input_ids,
        target_output=tgt_output_ids,
        source_sequence_length=src_seq_len,
        target_sequence_length=tgt_seq_len)