# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-Thoughts model for learning sentence vectors.
The model is based on the paper:
"Skip-Thought Vectors"
Ryan Kiros, Yukun Zhu, Ruslan Salakhutdinov, Richard S. Zemel,
Antonio Torralba, Raquel Urtasun, Sanja Fidler.
https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf
Layer normalization is applied based on the paper:
"Layer Normalization"
Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
https://arxiv.org/abs/1607.06450
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import tensorflow as tf

from skip_thoughts.ops import gru_cell
from skip_thoughts.ops import input_ops


def random_orthonormal_initializer(shape, dtype=tf.float32,
                                   partition_info=None):  # pylint: disable=unused-argument
  """Variable initializer that produces a random orthonormal matrix."""
  if len(shape) != 2 or shape[0] != shape[1]:
    raise ValueError("Expecting square shape, got %s" % shape)
  _, u, _ = tf.svd(tf.random_normal(shape, dtype=dtype), full_matrices=True)
  return u
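

# A minimal sanity check for the initializer above (an addition for
# illustration, not part of the original file; `_check_orthonormal` is a
# hypothetical helper): an orthonormal matrix u satisfies u * u^T = I.
def _check_orthonormal(dim=4):
  """Returns True if the initializer yields an approximately orthonormal matrix."""
  u = random_orthonormal_initializer([dim, dim])
  # Maximum absolute deviation of u * u^T from the identity.
  error = tf.reduce_max(tf.abs(tf.matmul(u, u, transpose_b=True) - tf.eye(dim)))
  with tf.Session() as sess:
    return sess.run(error) < 1e-5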


class SkipThoughtsModel(object):
  """Skip-thoughts model."""

  def __init__(self, config, mode="train", input_reader=None):
    """Basic setup. The actual TensorFlow graph is constructed in build().

    Args:
      config: Object containing configuration parameters.
      mode: "train", "eval" or "encode".
      input_reader: Subclass of tf.ReaderBase for reading the input serialized
        tf.Example protocol buffers. Defaults to TFRecordReader.

    Raises:
      ValueError: If mode is invalid.
    """
    if mode not in ["train", "eval", "encode"]:
      raise ValueError("Unrecognized mode: %s" % mode)

    self.config = config
    self.mode = mode
    self.reader = input_reader if input_reader else tf.TFRecordReader()

    # Initializer used for non-recurrent weights.
    self.uniform_initializer = tf.random_uniform_initializer(
        minval=-self.config.uniform_init_scale,
        maxval=self.config.uniform_init_scale)

    # Input sentences represented as sequences of word ids. "encode" is the
    # source sentence, "decode_pre" is the previous sentence and "decode_post"
    # is the next sentence.
    # Each is an int64 Tensor with shape [batch_size, padded_length].
    self.encode_ids = None
    self.decode_pre_ids = None
    self.decode_post_ids = None

    # Boolean masks distinguishing real words (1) from padded words (0).
    # Each is an int32 Tensor with shape [batch_size, padded_length].
    self.encode_mask = None
    self.decode_pre_mask = None
    self.decode_post_mask = None

    # Input sentences represented as sequences of word embeddings.
    # Each is a float32 Tensor with shape [batch_size, padded_length, emb_dim].
    self.encode_emb = None
    self.decode_pre_emb = None
    self.decode_post_emb = None

    # The output from the sentence encoder.
    # A float32 Tensor with shape [batch_size, num_gru_units].
    self.thought_vectors = None

    # The cross entropy losses and corresponding weights of the decoders. Used
    # for evaluation.
    self.target_cross_entropy_losses = []
    self.target_cross_entropy_loss_weights = []

    # The total loss to optimize.
    self.total_loss = None

  def build_inputs(self):
    """Builds the ops for reading input data.

    Outputs:
      self.encode_ids
      self.decode_pre_ids
      self.decode_post_ids
      self.encode_mask
      self.decode_pre_mask
      self.decode_post_mask
    """
    if self.mode == "encode":
      # Word embeddings are fed from an external vocabulary which has possibly
      # been expanded (see vocabulary_expansion.py).
      encode_ids = None
      decode_pre_ids = None
      decode_post_ids = None

      encode_mask = tf.placeholder(tf.int8, (None, None), name="encode_mask")
      decode_pre_mask = None
      decode_post_mask = None
    else:
      # Prefetch serialized tf.Example protos.
      input_queue = input_ops.prefetch_input_data(
          self.reader,
          self.config.input_file_pattern,
          shuffle=self.config.shuffle_input_data,
          capacity=self.config.input_queue_capacity,
          num_reader_threads=self.config.num_input_reader_threads)

      # Deserialize a batch.
      serialized = input_queue.dequeue_many(self.config.batch_size)
      encode, decode_pre, decode_post = input_ops.parse_example_batch(
          serialized)

      encode_ids = encode.ids
      decode_pre_ids = decode_pre.ids
      decode_post_ids = decode_post.ids

      encode_mask = encode.mask
      decode_pre_mask = decode_pre.mask
      decode_post_mask = decode_post.mask

    self.encode_ids = encode_ids
    self.decode_pre_ids = decode_pre_ids
    self.decode_post_ids = decode_post_ids

    self.encode_mask = encode_mask
    self.decode_pre_mask = decode_pre_mask
    self.decode_post_mask = decode_post_mask

  def build_word_embeddings(self):
    """Builds the word embeddings.

    Inputs:
      self.encode_ids
      self.decode_pre_ids
      self.decode_post_ids

    Outputs:
      self.encode_emb
      self.decode_pre_emb
      self.decode_post_emb
    """
    if self.mode == "encode":
      # Word embeddings are fed from an external vocabulary which has possibly
      # been expanded (see vocabulary_expansion.py).
      encode_emb = tf.placeholder(tf.float32, (
          None, None, self.config.word_embedding_dim), "encode_emb")
      # No sequences to decode.
      decode_pre_emb = None
      decode_post_emb = None
    else:
      word_emb = tf.get_variable(
          name="word_embedding",
          shape=[self.config.vocab_size, self.config.word_embedding_dim],
          initializer=self.uniform_initializer)

      encode_emb = tf.nn.embedding_lookup(word_emb, self.encode_ids)
      decode_pre_emb = tf.nn.embedding_lookup(word_emb, self.decode_pre_ids)
      decode_post_emb = tf.nn.embedding_lookup(word_emb, self.decode_post_ids)

    self.encode_emb = encode_emb
    self.decode_pre_emb = decode_pre_emb
    self.decode_post_emb = decode_post_emb

  def _initialize_gru_cell(self, num_units):
    """Initializes a GRU cell.

    The Variables of the GRU cell are initialized in a way that exactly matches
    the skip-thoughts paper: recurrent weights are initialized from random
    orthonormal matrices and non-recurrent weights are initialized from random
    uniform matrices.

    Args:
      num_units: Number of output units.

    Returns:
      cell: An instance of RNNCell with variable initializers that match the
        skip-thoughts paper.
    """
    return gru_cell.LayerNormGRUCell(
        num_units,
        w_initializer=self.uniform_initializer,
        u_initializer=random_orthonormal_initializer,
        b_initializer=tf.constant_initializer(0.0))

  def build_encoder(self):
    """Builds the sentence encoder.

    Inputs:
      self.encode_emb
      self.encode_mask

    Outputs:
      self.thought_vectors

    Raises:
      ValueError: if config.bidirectional_encoder is True and config.encoder_dim
        is odd.
    """
    with tf.variable_scope("encoder") as scope:
      length = tf.to_int32(tf.reduce_sum(self.encode_mask, 1), name="length")

      if self.config.bidirectional_encoder:
        if self.config.encoder_dim % 2:
          raise ValueError(
              "encoder_dim must be even when using a bidirectional encoder.")
        num_units = self.config.encoder_dim // 2
        cell_fw = self._initialize_gru_cell(num_units)  # Forward encoder
        cell_bw = self._initialize_gru_cell(num_units)  # Backward encoder
        _, states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            inputs=self.encode_emb,
            sequence_length=length,
            dtype=tf.float32,
            scope=scope)
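        # Each direction yields a final state of size encoder_dim / 2;
        # concatenating them restores a full encoder_dim thought vector.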
        thought_vectors = tf.concat(states, 1, name="thought_vectors")
      else:
        cell = self._initialize_gru_cell(self.config.encoder_dim)
        _, state = tf.nn.dynamic_rnn(
            cell=cell,
            inputs=self.encode_emb,
            sequence_length=length,
            dtype=tf.float32,
            scope=scope)
        # Use an identity operation to name the Tensor in the Graph.
        thought_vectors = tf.identity(state, name="thought_vectors")

    self.thought_vectors = thought_vectors

  def _build_decoder(self, name, embeddings, targets, mask, initial_state,
                     reuse_logits):
    """Builds a sentence decoder.

    Args:
      name: Decoder name.
      embeddings: Batch of sentences to decode; a float32 Tensor with shape
        [batch_size, padded_length, emb_dim].
      targets: Batch of target word ids; an int64 Tensor with shape
        [batch_size, padded_length].
      mask: A 0/1 Tensor with shape [batch_size, padded_length].
      initial_state: Initial state of the GRU. A float32 Tensor with shape
        [batch_size, num_gru_cells].
      reuse_logits: Whether to reuse the logits weights.
    """
    # Decoder RNN.
    cell = self._initialize_gru_cell(self.config.encoder_dim)
    with tf.variable_scope(name) as scope:
      # Add a padding word at the start of each sentence (to correspond to the
      # prediction of the first word) and remove the last word.
      decoder_input = tf.pad(
          embeddings[:, :-1, :], [[0, 0], [1, 0], [0, 0]], name="input")
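      # Illustration of the shift (hypothetical values): for a target sentence
      # [w1, w2, w3], the decoder receives [<pad>, w1, w2] as inputs and is
      # trained to predict [w1, w2, w3], one word per timestep.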
      length = tf.reduce_sum(mask, 1, name="length")
      decoder_output, _ = tf.nn.dynamic_rnn(
          cell=cell,
          inputs=decoder_input,
          sequence_length=length,
          initial_state=initial_state,
          scope=scope)

    # Stack batch vertically.
    decoder_output = tf.reshape(decoder_output, [-1, self.config.encoder_dim])
    targets = tf.reshape(targets, [-1])
    weights = tf.to_float(tf.reshape(mask, [-1]))

    # Logits.
    with tf.variable_scope("logits", reuse=reuse_logits) as scope:
      logits = tf.contrib.layers.fully_connected(
          inputs=decoder_output,
          num_outputs=self.config.vocab_size,
          activation_fn=None,
          weights_initializer=self.uniform_initializer,
          scope=scope)

    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=logits)
    batch_loss = tf.reduce_sum(losses * weights)
    tf.losses.add_loss(batch_loss)

    tf.summary.scalar("losses/" + name, batch_loss)

    self.target_cross_entropy_losses.append(losses)
    self.target_cross_entropy_loss_weights.append(weights)

  def build_decoders(self):
    """Builds the sentence decoders.

    Inputs:
      self.decode_pre_emb
      self.decode_post_emb
      self.decode_pre_ids
      self.decode_post_ids
      self.decode_pre_mask
      self.decode_post_mask
      self.thought_vectors

    Outputs:
      self.target_cross_entropy_losses
      self.target_cross_entropy_loss_weights
    """
    if self.mode != "encode":
      # Pre-sentence decoder.
      self._build_decoder("decoder_pre", self.decode_pre_emb,
                          self.decode_pre_ids, self.decode_pre_mask,
                          self.thought_vectors, False)

      # Post-sentence decoder. Logits weights are reused.
      self._build_decoder("decoder_post", self.decode_post_emb,
                          self.decode_post_ids, self.decode_post_mask,
                          self.thought_vectors, True)

  def build_loss(self):
    """Builds the loss Tensor.

    Outputs:
      self.total_loss
    """
    if self.mode != "encode":
      total_loss = tf.losses.get_total_loss()
      tf.summary.scalar("losses/total", total_loss)

      self.total_loss = total_loss

  def build_global_step(self):
    """Builds the global step Tensor.

    Outputs:
      self.global_step
    """
    self.global_step = tf.contrib.framework.create_global_step()

  def build(self):
    """Creates all ops for training, evaluation or encoding."""
    self.build_inputs()
    self.build_word_embeddings()
    self.build_encoder()
    self.build_decoders()
    self.build_loss()
    self.build_global_step()
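

# A minimal usage sketch (an addition for illustration, assuming the companion
# configuration.py in the skip_thoughts package, whose model_config() builds
# the `config` object this constructor expects):
#
#   from skip_thoughts import configuration
#
#   model = SkipThoughtsModel(configuration.model_config(), mode="train")
#   model.build()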