25
25
# from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
26
26
27
27
# from intel_quantization.quantize_graph import GraphRewriter
28
- from intel_quantization .transform_graph .strip_unused import StripUnusedNodes
29
- from intel_quantization .transform_graph .fold_batch_norm import FoldBatchNormNodes
30
- from intel_quantization .transform_graph .insert_logging import InsertLogging
31
- from intel_quantization .transform_graph .freeze_max_min import freeze_max
32
- from intel_quantization .transform_graph .freeze_max_min import freeze_min
33
- from intel_quantization .transform_graph .freeze_max_min import freeze_requantization_range
34
- from intel_quantization .transform_graph .fuse_quantized_conv_and_requantize import fuse_quantized_conv_and_requantize
35
- from intel_quantization .transform_graph .fuse_column_wise_mul import FuseColumnWiseMul
36
- from intel_quantization .transform_graph .rerange_quantized_concat import RerangeQuantizedConcat
37
- from intel_quantization .util import read_graph , write_graph
38
- from intel_quantization .quantize_graph .quantize_graph_for_intel_cpu import QuantizeGraphForIntel
28
+ from .transform_graph .strip_unused import StripUnusedNodes
29
+ from .transform_graph .fold_batch_norm import FoldBatchNormNodes
30
+ from .transform_graph .insert_logging import InsertLogging
31
+ from .transform_graph .freeze_max_min import freeze_max
32
+ from .transform_graph .freeze_max_min import freeze_min
33
+ from .transform_graph .freeze_max_min import freeze_requantization_range
34
+ from .transform_graph .freeze_max_min import get_all_fp32_data , get_tensor_histogram , combine_histogram
35
+ from .transform_graph .fuse_quantized_conv_and_requantize import fuse_quantized_conv_and_requantize
36
+ from .transform_graph .fuse_column_wise_mul import FuseColumnWiseMul
37
+ from .transform_graph .rerange_quantized_concat import RerangeQuantizedConcat
38
+ from .util import read_graph , write_graph
39
+ from .quantize_graph .quantize_graph_for_intel_cpu import QuantizeGraphForIntel
40
+ from .quantize_graph .quantize_graph_common import QuantizeGraphHelper
39
41
import os
40
42
import shlex
41
43
import subprocess
42
- import sys
43
44
import logging
44
45
45
46
logging .getLogger ().setLevel (level = logging .INFO )
52
53
53
54
class GraphConverter :
54
55
def __init__ (self , input_graph , output_graph , inputs = [], outputs = [], excluded_ops = [], excluded_nodes = [],
55
- per_channel = False , input_graph_is_binary = True ):
56
+ per_channel = False , input_graph_is_binary = True , algo = 'DIRECT' ):
56
57
"""Convert graph.
57
58
58
59
:param input_graph: input graph pb file.
@@ -73,13 +74,18 @@ def __init__(self, input_graph, output_graph, inputs=[], outputs=[], excluded_op
73
74
self .per_channel = per_channel
74
75
self .excluded_ops = excluded_ops
75
76
self .excluded_nodes = excluded_nodes
77
+ self .algo = algo
76
78
self ._low_precision_mode = 'eightbit'
77
-
79
+ self ._calibration_data = []
80
+ self ._fp32_print_data = []
78
81
self .gen_calib_data_cmds = None
79
82
self .debug = False
80
83
self ._check_tf_version ()
81
84
self ._check_args ()
82
85
self ._gen_tmp_filenames ()
86
+ self ._kl_op_dict = {}
87
+ self ._kl_keys = []
88
+ self ._print_node_mapping = {}
83
89
84
90
def _check_tf_version (self ):
85
91
is_supported_version = False
@@ -113,7 +119,7 @@ def _gen_tmp_filenames(self):
113
119
self ._fp32_optimized_graph = os .path .join (self ._output_path , 'fp32_optimized_graph.pb' )
114
120
self ._int8_dynamic_range_graph = os .path .join (self ._output_path , 'int8_dynamic_range_graph.pb' )
115
121
self ._int8_logged_graph = os .path .join (self ._output_path , 'int8_logged_graph.pb' )
116
- self ._requant_min_max_log = os .path .join (self ._output_path , 'requant_min_max_log.txt ' )
122
+ self ._fp32_logged_graph = os .path .join (self ._output_path , 'fp32_logged_graph.pb ' )
117
123
self ._int8_frozen_range_graph = os .path .join (self ._output_path , 'int8_frozen_range_graph.pb' )
118
124
if not self .output_graph :
119
125
self .output_graph = os .path .join (self ._output_path , 'int8_final_fused_graph.pb' )
@@ -137,6 +143,58 @@ def convert(self):
137
143
else :
138
144
self .quantize ()
139
145
146
def _get_fp32_print_node_names(self):
    """Locate the FP32 activation outputs that correspond to each quantized
    conv node, register their KL logging keys, and write out an FP32 graph
    instrumented with __KL print nodes.

    Side effects: fills self._print_node_mapping and self._kl_keys, and
    writes the instrumented graph to self._fp32_logged_graph.
    """
    # How many (non-Const) nodes past the conv the real output sits,
    # keyed by fused quantized-conv op type.
    op_output_offset = {
        "QuantizedConv2DWithBiasSumAndRelu": 3,
        "QuantizedConv2DWithBiasAndRelu": 2,
        "QuantizedConv2DWithBias": 1,
    }

    sorted_graph = QuantizeGraphHelper().get_sorted_graph(
        self._fp32_origin_graph, self.outputs)

    quantized_nodes = {
        node.name: node
        for node in self._tmp_graph_def.node if node.op != "Const"
    }

    # Original (pre-quantization) names of every fused quantized conv.
    target_convs = []
    for node in self._tmp_graph_def.node:
        if node.op in op_output_offset:
            target_convs.append(node.name.split('_eightbit_')[0])

    fp32_nodes = {
        node.name: node
        for node in sorted_graph.node if node.op != "Const"
    }
    ordered_names = [n.name for n in sorted_graph.node if n.op != "Const"]

    print_targets = []
    for conv_name in target_convs:
        quantized_op = quantized_nodes[
            conv_name + "_eightbit_quantized_conv"].op
        if quantized_op == 'QuantizedConv2DWithBiasSumAndRelu':
            # Sum-fused conv: the logged tensor is the Relu that directly
            # follows an Add* node downstream of the conv in sorted order.
            start = ordered_names.index(conv_name)
            for offset, candidate in enumerate(ordered_names[start:]):
                follower = ordered_names[start + offset + 1]
                if (fp32_nodes[candidate].op.startswith("Add")
                        and fp32_nodes[follower].op == "Relu"):
                    print_targets.append(follower)
                    self._print_node_mapping[follower] = conv_name
        elif conv_name in ordered_names:
            # Plain fused conv: output is a fixed offset past the conv.
            start = ordered_names.index(conv_name)
            end = start + op_output_offset[quantized_op]
            print_targets.append(ordered_names[end])
            self._print_node_mapping[ordered_names[end]] = conv_name

    for name in print_targets:
        self._kl_keys.append(';' + name + '__print__;__KL')

    InsertLogging(self._fp32_origin_graph,
                  node_name_list=print_targets,
                  message="__KL:",
                  summarize=-1, dump_fp32=True).do_transformation()
    write_graph(self._fp32_origin_graph, self._fp32_logged_graph)
140
198
def quantize (self ):
141
199
"""Quantize graph only (without optimizing fp32 graph), including:
142
200
1) quantize graph,
@@ -150,9 +208,14 @@ def quantize(self):
150
208
'to generate calibration data.' )
151
209
try :
152
210
self ._quantize_graph ()
211
+ if self .algo == "KL" :
212
+ self ._get_fp32_print_node_names ()
213
+ self ._generate_calibration_data (self ._fp32_logged_graph ,
214
+ self ._fp32_print_data , True )
215
+
153
216
self ._insert_logging ()
154
- self ._generate_calibration_data ()
155
- self ._freeze_requantization_ranges ()
217
+ self ._generate_calibration_data (self . _int8_logged_graph , self . _calibration_data )
218
+ self ._freeze_requantization_ranges (self . _kl_op_dict )
156
219
self ._fuse_requantize_with_fused_quantized_conv ()
157
220
except Exception as e :
158
221
logging .error ('Failed to quantize graph due to: %s' , str (e ))
@@ -172,6 +235,7 @@ def _optimize_frozen_fp32_graph(self):
172
235
self ._tmp_graph_def = graph_util .remove_training_nodes (self ._tmp_graph_def , self .outputs )
173
236
self ._tmp_graph_def = FoldBatchNormNodes (self ._tmp_graph_def ).do_transform ()
174
237
write_graph (self ._tmp_graph_def , self ._fp32_optimized_graph )
238
+ self ._fp32_origin_graph = self ._tmp_graph_def
175
239
176
240
def _quantize_graph (self ):
177
241
"""quantize graph."""
@@ -199,32 +263,50 @@ def _insert_logging(self):
199
263
ops = ["RequantizationRange{}" .format ("PerChannel" if self .per_channel else "" )],
200
264
message = "__requant_min_max:" ).do_transformation ()
201
265
InsertLogging (self ._tmp_graph_def , ops = ["Min" ], message = "__min:" ).do_transformation ()
202
- InsertLogging (self ._tmp_graph_def , ops = ["Max" ], message = "__max:" ).do_transformation ()
266
+ InsertLogging (self ._tmp_graph_def , ops = ["Max" ],
267
+ message = "__max:" ).do_transformation ()
268
+ # InsertLogging(
269
+ # self._tmp_graph_def,
270
+ # ops=["QuantizedConv2DWithBiasAndRelu",
271
+ # "QuantizedConv2DWithBias"
272
+ # ],
273
+ # message="__KL:",
274
+ # summarize=-1).do_transformation()
275
+
203
276
write_graph (self ._tmp_graph_def , self ._int8_logged_graph )
204
277
self ._tmp_graph_def .CopyFrom (int8_dynamic_range_graph_def )
205
278
206
- def _generate_calibration_data (self ):
279
+ def _generate_calibration_data (self , graph , output , enable_kl_algo = False ):
207
280
cmd = self .gen_calib_data_cmds
208
- cmd = cmd .format (self ._int8_logged_graph )
209
- f = open (self ._requant_min_max_log , 'w' , buffering = 1 )
210
- p = subprocess .Popen (shlex .split (cmd ), stderr = subprocess .STDOUT , stdout = subprocess .PIPE )
211
- try :
212
- for line in p .stdout :
213
- line_str = line .decode (sys .stdout .encoding )
214
- sys .stdout .write (line_str )
215
- f .write (line_str )
216
- p .communicate ()
217
- except Exception :
218
- p .kill ()
219
- p .wait ()
220
- raise
221
- if p .poll ():
222
- raise SystemExit ('ERROR generating calibration data, command: \n {}' .format (cmd ))
223
-
224
- def _freeze_requantization_ranges (self ):
225
- self ._tmp_graph_def = freeze_max (self ._tmp_graph_def , self ._requant_min_max_log )
226
- self ._tmp_graph_def = freeze_min (self ._tmp_graph_def , self ._requant_min_max_log )
227
- self ._tmp_graph_def = freeze_requantization_range (self ._tmp_graph_def , self ._requant_min_max_log )
281
+ cmd = cmd .format (graph )
282
+ p = subprocess .Popen (shlex .split (cmd ),
283
+ stderr = subprocess .STDOUT ,
284
+ stdout = subprocess .PIPE )
285
+ while p .poll () is None :
286
+ line = p .stdout .readline ().strip ().decode ()
287
+ if line and line .startswith (';' ):
288
+ if not enable_kl_algo :
289
+ output .append (line )
290
+
291
+ if enable_kl_algo and line .rsplit (':' )[0 ] in self ._kl_keys :
292
+ fp32_data = get_all_fp32_data (line .rsplit (':' )[- 1 ])
293
+ key = self ._print_node_mapping [line [1 :].split ('__print' )[0 ]] + '_eightbit_requant_range'
294
+ if key not in self ._kl_op_dict :
295
+ self ._kl_op_dict [key ] = get_tensor_histogram (fp32_data )
296
+ else :
297
+ self ._kl_op_dict [key ] = combine_histogram (self ._kl_op_dict [key ], fp32_data )
298
+
299
def _freeze_requantization_ranges(self, additional_data=None):
    """Freeze the collected calibration min/max and requantization-range
    values into the graph as constants.

    :param additional_data: optional per-node histogram data (KL algorithm)
        passed through to freeze_requantization_range.
    """
    # "MA" selects moving-average aggregation of the logged ranges.
    moving_average = self.algo == "MA"
    for freeze_fn in (freeze_max, freeze_min):
        self._tmp_graph_def = freeze_fn(self._tmp_graph_def,
                                        self._calibration_data,
                                        moving_average)
    self._tmp_graph_def = freeze_requantization_range(
        self._tmp_graph_def, self._calibration_data, moving_average,
        additional_data)
    if self.debug:
        write_graph(self._tmp_graph_def, self._int8_frozen_range_graph)
230
312
@@ -256,5 +338,3 @@ def _post_clean(self):
256
338
"""
257
339
if gfile .Exists (self ._int8_logged_graph ):
258
340
os .remove (self ._int8_logged_graph )
259
- if gfile .Exists (self ._requant_min_max_log ):
260
- os .remove (self ._requant_min_max_log )
0 commit comments