-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodel.py
executable file
·69 lines (51 loc) · 3.25 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from tensorflow.python.keras.layers import Input, GRU, Dense, Concatenate, TimeDistributed
from tensorflow.python.keras.models import Model
from attention import AttentionLayer
def summary_model(hidden_size, batch_size, ip_timesteps, ip_vsize, op_timesteps, op_vsize):
    """Build the training model plus the encoder/decoder inference models.

    The three models share the same encoder GRU, decoder GRU, attention
    layer and softmax Dense layer objects.

    Args:
        hidden_size: Number of units for both GRU layers; also the feature
            size of the inference-time encoder-output and decoder-state inputs.
        batch_size: If truthy, the training inputs are created with this fixed
            batch dimension (``batch_shape``); otherwise they are created with
            ``shape`` only, leaving the batch dimension unspecified.
        ip_timesteps: Time dimension of the encoder input.
        ip_vsize: Feature dimension of the encoder input.
        op_timesteps: The decoder training input has ``op_timesteps - 1`` steps.
        op_vsize: Feature dimension of the decoder input and number of units
            of the softmax output layer.

    Returns:
        A tuple ``(full_model, encoder_model, decoder_model)``:
        ``full_model`` maps ``[encoder_inputs, decoder_inputs]`` to the
        per-step softmax predictions and is compiled with the ``adam``
        optimizer and ``categorical_crossentropy`` loss;
        ``encoder_model`` maps an encoder input (batch size 1) to
        ``[encoder output sequence, encoder final state]``;
        ``decoder_model`` maps ``[encoder outputs, decoder initial state,
        one-step decoder input]`` (batch size 1) to
        ``[prediction, second attention output, decoder state]``.
    """
    # ---- Training graph -------------------------------------------------
    # Choose between a fixed-batch and a free-batch input specification.
    if batch_size:
        enc_spec = {'batch_shape': (batch_size, ip_timesteps, ip_vsize)}
        dec_spec = {'batch_shape': (batch_size, op_timesteps - 1, op_vsize)}
    else:
        enc_spec = {'shape': (ip_timesteps, ip_vsize)}
        dec_spec = {'shape': (op_timesteps - 1, op_vsize)}
    train_enc_in = Input(name='encoder_inputs', **enc_spec)
    train_dec_in = Input(name='decoder_inputs', **dec_spec)

    # Encoder: keep both the full output sequence and the final state.
    shared_encoder = GRU(hidden_size, return_sequences=True, return_state=True, name='encoder_gru')
    enc_seq, enc_final = shared_encoder(train_enc_in)

    # Decoder: seeded with the encoder's final state.
    shared_decoder = GRU(hidden_size, return_sequences=True, return_state=True, name='decoder_gru')
    dec_seq, _ = shared_decoder(train_dec_in, initial_state=enc_final)

    # Attention over [encoder outputs, decoder outputs]; the layer yields two
    # tensors, only the first of which is used during training.
    attention = AttentionLayer(name='attention_layer')
    context_seq, _ = attention([enc_seq, dec_seq])

    # Decoder outputs and attention outputs are joined on the last axis.
    merged = Concatenate(axis=-1, name='concat_layer')([dec_seq, context_seq])

    # Softmax projection applied independently at every time step.
    softmax_dense = Dense(op_vsize, activation='softmax', name='softmax_layer')
    train_pred = TimeDistributed(softmax_dense, name='time_distributed_layer')(merged)

    full_model = Model(inputs=[train_enc_in, train_dec_in], outputs=train_pred)
    full_model.compile(optimizer='adam', loss='categorical_crossentropy')

    # ---- Inference graphs (always batch size 1) -------------------------
    infer_batch = 1

    # Encoder inference model: reuses the trained encoder GRU.
    inf_enc_in = Input(batch_shape=(infer_batch, ip_timesteps, ip_vsize), name='encoder_inf_inputs')
    inf_enc_seq, inf_enc_final = shared_encoder(inf_enc_in)
    encoder_model = Model(inputs=inf_enc_in, outputs=[inf_enc_seq, inf_enc_final])

    # Decoder inference model: consumes a single time step per call, with the
    # encoder outputs and the decoder state supplied as explicit inputs.
    step_in = Input(batch_shape=(infer_batch, 1, op_vsize), name='decoder_word_inputs')
    enc_seq_in = Input(batch_shape=(infer_batch, ip_timesteps, hidden_size), name='encoder_inf_states')
    state_in = Input(batch_shape=(infer_batch, hidden_size), name='decoder_init')

    step_out, step_state = shared_decoder(step_in, initial_state=state_in)
    step_context, step_attn = attention([enc_seq_in, step_out])
    step_merged = Concatenate(axis=-1, name='concat')([step_out, step_context])
    step_pred = TimeDistributed(softmax_dense)(step_merged)

    decoder_model = Model(
        inputs=[enc_seq_in, state_in, step_in],
        outputs=[step_pred, step_attn, step_state],
    )

    return full_model, encoder_model, decoder_model
if __name__ == '__main__':
    # Build the models once with toy dimensions as a quick construction check;
    # the returned models are discarded.
    summary_model(
        hidden_size=64,
        batch_size=None,
        ip_timesteps=20,
        ip_vsize=30,
        op_timesteps=20,
        op_vsize=20,
    )