#Copyright 2022 Hamidreza Sadeghi. All rights reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np


def decode_sequence_lstm_encoder_decoder(input_seq, encoder_model, decoder_model, bpet):
    """Greedily decode a single input sequence with the LSTM encoder-decoder."""
    # Encode the input and obtain the initial decoder state vectors.
    output_values, states_values = encoder_model.predict(input_seq)

    # Generate an empty target sequence of length 1 and populate its first
    # position with the start-of-sentence token.
    target_seq = np.zeros((1, 1), dtype='float32')
    target_seq[0, 0] = bpet.bpe2idx['__she']

    # Sampling loop for a single sequence (i.e. a batch of size 1).
    decoded_sentence = []
    while True:
        output_tokens, h1 = decoder_model.predict(
            [target_seq, output_values, states_values])

        # Sample the most likely token.
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_word = bpet.idx2bpe[sampled_token_index]

        # Exit condition: either hit the maximum length or find a stop token.
        if (sampled_word == '__ehe2' or sampled_word == '__ehe1' or
                len(decoded_sentence) > 100):
            break
        decoded_sentence += [sampled_token_index]

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1), dtype='float32')
        target_seq[0, 0] = sampled_token_index

        # Update the decoder states.
        states_values = h1

    return decoded_sentence
def decode_batch_lstm_encoder_decoder(X_input, encoder_model, decoder_model, bpet):
    """Greedily decode a whole batch of input sequences with the LSTM encoder-decoder."""
    # Encode the inputs and obtain the initial decoder state vectors.
    output_values, states_values = encoder_model.predict(X_input)

    # Generate an empty target sequence of length 1 for every batch element and
    # populate its first position with the start-of-sentence token.
    target_seq = np.zeros((len(X_input), 1), dtype='float32')
    target_seq[:, 0] = bpet.bpe2idx['__she']

    # Sampling loop over the whole batch at once.
    decoded_sentences = [[] for _ in range(len(X_input))]
    counter = 0
    while True:
        output_tokens, h1 = decoder_model.predict(
            [target_seq, output_values, states_values])

        # Sample the most likely token for every sequence in the batch.
        sampled_token_index = np.argmax(output_tokens, axis=2)
        sampled_words = [(x[-1], bpet.idx2bpe[x[-1]]) for x in sampled_token_index]
        acceptable_tokens = [i for i, x in enumerate(sampled_words)
                             if x[1] != '__ehe2' and x[1] != '__ehe1']

        # Exit condition: every sequence produced a stop token, or the maximum
        # length was reached.
        if len(acceptable_tokens) == 0 or counter > 40:
            break
        for i in range(len(decoded_sentences)):
            if sampled_words[i][1] != '__ehe2' and sampled_words[i][1] != '__ehe1':
                decoded_sentences[i].append(sampled_words[i][0])

        # Update the target sequences and the decoder states.
        target_seq = sampled_token_index[:, :]
        states_values = h1
        counter += 1

    return decoded_sentences
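

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# training pipeline). The real `encoder_model`, `decoder_model`, and `bpet`
# objects are built elsewhere in this repository; the stub classes below are
# hypothetical stand-ins that only mimic the interfaces the two decoding
# functions rely on (`predict`, `bpe2idx`, `idx2bpe`), so the decoding loops
# can be exercised end to end without a trained model.
# ---------------------------------------------------------------------------

class _StubBPETokenizer:
    """Hypothetical tokenizer stub exposing the token<->index maps used above."""
    def __init__(self):
        self.idx2bpe = {0: '__she', 1: 'token_a', 2: '__ehe1', 3: '__ehe2'}
        self.bpe2idx = {word: idx for idx, word in self.idx2bpe.items()}


class _StubEncoder:
    """Hypothetical encoder stub returning fake outputs and an initial state."""
    def predict(self, x):
        batch = len(x)
        return (np.zeros((batch, 5, 8), dtype='float32'),
                np.zeros((batch, 8), dtype='float32'))


class _StubDecoder:
    """Hypothetical decoder stub that always predicts an end-of-sentence token."""
    def predict(self, inputs):
        target_seq, _, state = inputs
        probs = np.zeros((target_seq.shape[0], 1, 4), dtype='float32')
        probs[:, 0, 2] = 1.0  # index 2 maps to '__ehe1', so decoding stops
        return probs, state


if __name__ == '__main__':
    bpet = _StubBPETokenizer()
    single = decode_sequence_lstm_encoder_decoder(
        np.zeros((1, 5), dtype='float32'), _StubEncoder(), _StubDecoder(), bpet)
    batch = decode_batch_lstm_encoder_decoder(
        np.zeros((2, 5), dtype='float32'), _StubEncoder(), _StubDecoder(), bpet)
    print(single, batch)  # both empty because the stub decoder stops immediately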