-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecoder.py
102 lines (86 loc) · 4.39 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import tensorflow as tf
import time, os, collections
import data_loader
from wavenet_model import Wavenet_Model
from ops import *
from tensorflow.python.client import timeline
class DECODER():
    """Restore a trained Wavenet checkpoint and greedy-decode the test set.

    Builds the test input pipeline and model, loads the latest checkpoint,
    runs one decoded batch, converts the CTC output indices back to text,
    and dumps profiling data (tf.profiler scope report + Chrome timeline).

    Args:
        args: parsed hyper-parameter namespace (batch_size, num_classes,
            checkpoint_dir, num_blocks, num_wavenet_layers, filter_width,
            dilated_activation, shuffle, ...).
        sess: an open tf.Session; ownership stays with the caller.
    """

    def __init__(self, args, sess):
        self.args = args
        self.sess = sess
        self.global_step = tf.Variable(0, trainable=False)
        # Get test data. The input pipeline is pinned to CPU; decoding is
        # single-tower, so force num_gpu to 1 regardless of training config.
        # NOTE(review): original indentation was lost in transit — assuming
        # only the data loading sits inside the /cpu:0 device scope and the
        # model build is outside it; confirm against training-side code.
        with tf.device('/cpu:0'):
            print('\tLoading test data')
            self.args.num_gpu = 1
            test_wave, test_label, test_seq_len = data_loader.get_batches(
                data_category='test',
                shuffle=self.args.shuffle,
                batch_size=self.args.batch_size,
                num_gpu=self.args.num_gpu,
                num_threads=1)
        self.test_net = Wavenet_Model(
            self.args, test_wave, test_label, test_seq_len,
            self.global_step, name='test')
        self.test_net.build_model()
        # Saver must be created after the model graph exists so it captures
        # all model variables for checkpoint restore.
        self.saver = tf.train.Saver()
        self.decode()

    def decode(self):
        """Run one batch through the network and report decoded transcripts.

        Side effects: writes summaries/run metadata to 'test_log/',
        a Chrome trace to 'timeline.json', and prints the label error
        rate plus each decoded utterance.

        Raises:
            Exception: if no checkpoint could be restored.
        """
        merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter('test_log', self.sess.graph)
        self.sess.run(tf.global_variables_initializer())
        if self.load():
            print('Load checkpoint')
        else:
            raise Exception('No ckpt!')
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
        try:
            # Full trace so the profiler/timeline below has per-op timings.
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            char_prob, decoded, ler_, summary_ = self.sess.run(
                [self.test_net.probs, self.test_net.dcd,
                 self.test_net.ler, merged],
                options=run_options, run_metadata=run_metadata)
            writer.add_run_metadata(run_metadata, 'check')
            writer.add_summary(summary_)
            print('Label Error rate : %3.4f' % ler_)
            # decoded is the CTC decoder output (a list of SparseTensorValue);
            # turn the first path back into dense per-utterance index lists.
            decoded_original = reverse_sparse_tensor(decoded[0][0])
            # char_prob: [batch size, number of steps, number of classes]
            char_prob = np.asarray(char_prob)
            str_decoded = list()
            for i in range(len(decoded_original)):
                # Shift class indices back into ASCII character codes.
                # NOTE(review): SpeechLoader is not imported in this file —
                # presumably exposed via `from ops import *`; confirm.
                str_decoded.append(''.join(
                    [chr(x) for x in
                     np.asarray(decoded_original[i]) + SpeechLoader.FIRST_INDEX]))
                if self.args.num_classes == 30:
                    # 27:Space, 28:Apstr, 29:<EOS>, last class:blank
                    str_decoded[i] = str_decoded[i].replace(chr(ord('z') + 4), "")
                    str_decoded[i] = str_decoded[i].replace(chr(ord('z') + 3), '.')
                    str_decoded[i] = str_decoded[i].replace(chr(ord('z') + 2), "'")
                    str_decoded[i] = str_decoded[i].replace(chr(ord('z') + 1), ' ')
                elif self.args.num_classes == 29:
                    # 27:Space, 28:Apstr, last class:blank
                    str_decoded[i] = str_decoded[i].replace(chr(ord('z') + 3), "")
                    str_decoded[i] = str_decoded[i].replace(chr(ord('z') + 2), "'")
                    str_decoded[i] = str_decoded[i].replace(chr(ord('z') + 1), ' ')
                print(str_decoded[i])
            # Per-op time/memory profile printed to stdout.
            options = tf.profiler.ProfileOptionBuilder.time_and_memory()
            options["min_bytes"] = 0
            options["min_micros"] = 0
            options["select"] = ("bytes", "peak_bytes", "output_bytes",
                                 "residual_bytes", "micros")
            tf.profiler.profile(tf.get_default_graph(), run_meta=run_metadata,
                                cmd="scope", options=options)
            # Chrome-trace timeline for chrome://tracing.
            tl = timeline.Timeline(run_metadata.step_stats)
            ctf = tl.generate_chrome_trace_format()
            with open('timeline.json', 'w') as f:
                f.write(ctf)
        except KeyboardInterrupt:
            print('Keyboard')
        finally:
            coord.request_stop()
            coord.join(threads)
            # Flush and release the event-file handle (was leaked before).
            writer.close()

    @property
    def model_dir(self):
        """Checkpoint subdirectory name derived from the architecture config."""
        return '{}blocks_{}layers_{}width_{}'.format(
            self.args.num_blocks, self.args.num_wavenet_layers,
            self.args.filter_width, self.args.dilated_activation)

    def load(self):
        """Restore the latest checkpoint if one exists.

        Sets self.init_epoch from the checkpoint filename suffix
        (format: name-EPOCH), or 0 when nothing is restored.

        Returns:
            True if a checkpoint was restored, False otherwise.
        """
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            # Epoch number is encoded after the last '-' in the ckpt filename.
            self.init_epoch = int(ckpt_name.split('-')[-1])
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            return True
        else:
            self.init_epoch = 0
            return False