Skip to content
This repository was archived by the owner on May 11, 2024. It is now read-only.

Commit 00d591f

Browse files
committed
Merge branch 'gm_r1.1' into 'master'.
Merge r1.1 code into master branch. See merge request intelai/tools!42.
2 parents d88dede + 4c5d8ff commit 00d591f

22 files changed: +1120 and −316 lines changed

api/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ of specific models in Model Zoo as examples.
4646

4747
```bash
4848
$ cd ~
49-
$ git clone https://github.com/IntelAI/tools.git quantization && cd quantization
49+
$ git clone https://github.com/IntelAI/tools.git quantization && cd quantization
5050
```
5151

5252
## Step-by-step Procedure for ResNet-50 Quantization

api/examples/quantize_cmd.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ def main(_):
4141
outputs = []
4242

4343
if args.excluded_ops:
44-
excluded_ops = args.exclude_ops.split(',')
44+
excluded_ops = args.excluded_ops.split(',')
4545
else:
4646
excluded_ops = []
4747

4848
if args.excluded_nodes:
49-
excluded_nodes = args.exclude_nodes.split(',')
49+
excluded_nodes = args.excluded_nodes.split(',')
5050
else:
5151
excluded_nodes = []
5252

@@ -61,7 +61,7 @@ def main(_):
6161
callback_cmd = prefix + 'input_graph={} ' + postfix
6262
else:
6363
callback_cmd = args.callback
64-
qt.gen_calib_data_cmds = args.callback
64+
qt.gen_calib_data_cmds = callback_cmd
6565
qt.convert()
6666

6767

api/intel_quantization/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
#
18+
__version__="1.1"

api/intel_quantization/graph_converter.py

+121-41
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,22 @@
2525
# from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
2626

2727
# from intel_quantization.quantize_graph import GraphRewriter
28-
from intel_quantization.transform_graph.strip_unused import StripUnusedNodes
29-
from intel_quantization.transform_graph.fold_batch_norm import FoldBatchNormNodes
30-
from intel_quantization.transform_graph.insert_logging import InsertLogging
31-
from intel_quantization.transform_graph.freeze_max_min import freeze_max
32-
from intel_quantization.transform_graph.freeze_max_min import freeze_min
33-
from intel_quantization.transform_graph.freeze_max_min import freeze_requantization_range
34-
from intel_quantization.transform_graph.fuse_quantized_conv_and_requantize import fuse_quantized_conv_and_requantize
35-
from intel_quantization.transform_graph.fuse_column_wise_mul import FuseColumnWiseMul
36-
from intel_quantization.transform_graph.rerange_quantized_concat import RerangeQuantizedConcat
37-
from intel_quantization.util import read_graph, write_graph
38-
from intel_quantization.quantize_graph.quantize_graph_for_intel_cpu import QuantizeGraphForIntel
28+
from .transform_graph.strip_unused import StripUnusedNodes
29+
from .transform_graph.fold_batch_norm import FoldBatchNormNodes
30+
from .transform_graph.insert_logging import InsertLogging
31+
from .transform_graph.freeze_max_min import freeze_max
32+
from .transform_graph.freeze_max_min import freeze_min
33+
from .transform_graph.freeze_max_min import freeze_requantization_range
34+
from .transform_graph.freeze_max_min import get_all_fp32_data, get_tensor_histogram, combine_histogram
35+
from .transform_graph.fuse_quantized_conv_and_requantize import fuse_quantized_conv_and_requantize
36+
from .transform_graph.fuse_column_wise_mul import FuseColumnWiseMul
37+
from .transform_graph.rerange_quantized_concat import RerangeQuantizedConcat
38+
from .util import read_graph, write_graph
39+
from .quantize_graph.quantize_graph_for_intel_cpu import QuantizeGraphForIntel
40+
from .quantize_graph.quantize_graph_common import QuantizeGraphHelper
3941
import os
4042
import shlex
4143
import subprocess
42-
import sys
4344
import logging
4445

4546
logging.getLogger().setLevel(level=logging.INFO)
@@ -52,7 +53,7 @@
5253

5354
class GraphConverter:
5455
def __init__(self, input_graph, output_graph, inputs=[], outputs=[], excluded_ops=[], excluded_nodes=[],
55-
per_channel=False, input_graph_is_binary=True):
56+
per_channel=False, input_graph_is_binary=True, algo='DIRECT'):
5657
"""Convert graph.
5758
5859
:param input_graph: input graph pb file.
@@ -73,13 +74,18 @@ def __init__(self, input_graph, output_graph, inputs=[], outputs=[], excluded_op
7374
self.per_channel = per_channel
7475
self.excluded_ops = excluded_ops
7576
self.excluded_nodes = excluded_nodes
77+
self.algo = algo
7678
self._low_precision_mode = 'eightbit'
77-
79+
self._calibration_data = []
80+
self._fp32_print_data = []
7881
self.gen_calib_data_cmds = None
7982
self.debug = False
8083
self._check_tf_version()
8184
self._check_args()
8285
self._gen_tmp_filenames()
86+
self._kl_op_dict = {}
87+
self._kl_keys = []
88+
self._print_node_mapping = {}
8389

8490
def _check_tf_version(self):
8591
is_supported_version = False
@@ -113,7 +119,7 @@ def _gen_tmp_filenames(self):
113119
self._fp32_optimized_graph = os.path.join(self._output_path, 'fp32_optimized_graph.pb')
114120
self._int8_dynamic_range_graph = os.path.join(self._output_path, 'int8_dynamic_range_graph.pb')
115121
self._int8_logged_graph = os.path.join(self._output_path, 'int8_logged_graph.pb')
116-
self._requant_min_max_log = os.path.join(self._output_path, 'requant_min_max_log.txt')
122+
self._fp32_logged_graph = os.path.join(self._output_path, 'fp32_logged_graph.pb')
117123
self._int8_frozen_range_graph = os.path.join(self._output_path, 'int8_frozen_range_graph.pb')
118124
if not self.output_graph:
119125
self.output_graph = os.path.join(self._output_path, 'int8_final_fused_graph.pb')
@@ -137,6 +143,58 @@ def convert(self):
137143
else:
138144
self.quantize()
139145

146+
def _get_fp32_print_node_names(self):
147+
offset_map = {
148+
"QuantizedConv2DWithBiasSumAndRelu": 3,
149+
"QuantizedConv2DWithBiasAndRelu": 2,
150+
"QuantizedConv2DWithBias": 1,
151+
}
152+
target_conv_op = []
153+
sorted_graph = QuantizeGraphHelper().get_sorted_graph(
154+
self._fp32_origin_graph, self.outputs)
155+
156+
node_name_mapping = {
157+
node.name: node
158+
for node in self._tmp_graph_def.node if node.op != "Const"
159+
}
160+
161+
for node in self._tmp_graph_def.node:
162+
if node.op in offset_map:
163+
target_conv_op.append(node.name.split('_eightbit_')[0])
164+
fp32_node_name_mapping = {
165+
node.name: node
166+
for node in sorted_graph.node if node.op != "Const"
167+
}
168+
sorted_node_names = [i.name for i in sorted_graph.node if i.op != "Const"]
169+
170+
output_node_names = []
171+
for i in target_conv_op:
172+
if node_name_mapping[
173+
i + "_eightbit_quantized_conv"].op == 'QuantizedConv2DWithBiasSumAndRelu':
174+
start_index = sorted_node_names.index(i)
175+
for index, value in enumerate(sorted_node_names[start_index:]):
176+
if fp32_node_name_mapping[value].op.startswith(
177+
"Add") and fp32_node_name_mapping[
178+
sorted_node_names[start_index + index + 1]].op == "Relu":
179+
output_node_names.append(
180+
sorted_node_names[start_index + index + 1])
181+
self._print_node_mapping[sorted_node_names[start_index + index + 1]] = i
182+
elif i in sorted_node_names:
183+
start_index = sorted_node_names.index(i)
184+
end_index = start_index + offset_map[node_name_mapping[
185+
i + "_eightbit_quantized_conv"].op]
186+
output_node_names.append(sorted_node_names[end_index])
187+
self._print_node_mapping[sorted_node_names[end_index]] = i
188+
189+
for i in output_node_names:
190+
self._kl_keys.append(';' + i + '__print__;__KL')
191+
192+
InsertLogging(self._fp32_origin_graph,
193+
node_name_list=output_node_names,
194+
message="__KL:",
195+
summarize=-1, dump_fp32=True).do_transformation()
196+
write_graph(self._fp32_origin_graph, self._fp32_logged_graph)
197+
140198
def quantize(self):
141199
"""Quantize graph only (without optimizing fp32 graph), including:
142200
1) quantize graph,
@@ -150,9 +208,14 @@ def quantize(self):
150208
'to generate calibration data.')
151209
try:
152210
self._quantize_graph()
211+
if self.algo == "KL":
212+
self._get_fp32_print_node_names()
213+
self._generate_calibration_data(self._fp32_logged_graph,
214+
self._fp32_print_data, True)
215+
153216
self._insert_logging()
154-
self._generate_calibration_data()
155-
self._freeze_requantization_ranges()
217+
self._generate_calibration_data(self._int8_logged_graph, self._calibration_data)
218+
self._freeze_requantization_ranges(self._kl_op_dict)
156219
self._fuse_requantize_with_fused_quantized_conv()
157220
except Exception as e:
158221
logging.error('Failed to quantize graph due to: %s', str(e))
@@ -172,6 +235,7 @@ def _optimize_frozen_fp32_graph(self):
172235
self._tmp_graph_def = graph_util.remove_training_nodes(self._tmp_graph_def, self.outputs)
173236
self._tmp_graph_def = FoldBatchNormNodes(self._tmp_graph_def).do_transform()
174237
write_graph(self._tmp_graph_def, self._fp32_optimized_graph)
238+
self._fp32_origin_graph = self._tmp_graph_def
175239

176240
def _quantize_graph(self):
177241
"""quantize graph."""
@@ -199,32 +263,50 @@ def _insert_logging(self):
199263
ops=["RequantizationRange{}".format("PerChannel" if self.per_channel else "")],
200264
message="__requant_min_max:").do_transformation()
201265
InsertLogging(self._tmp_graph_def, ops=["Min"], message="__min:").do_transformation()
202-
InsertLogging(self._tmp_graph_def, ops=["Max"], message="__max:").do_transformation()
266+
InsertLogging(self._tmp_graph_def, ops=["Max"],
267+
message="__max:").do_transformation()
268+
# InsertLogging(
269+
# self._tmp_graph_def,
270+
# ops=["QuantizedConv2DWithBiasAndRelu",
271+
# "QuantizedConv2DWithBias"
272+
# ],
273+
# message="__KL:",
274+
# summarize=-1).do_transformation()
275+
203276
write_graph(self._tmp_graph_def, self._int8_logged_graph)
204277
self._tmp_graph_def.CopyFrom(int8_dynamic_range_graph_def)
205278

206-
def _generate_calibration_data(self):
279+
def _generate_calibration_data(self, graph, output, enable_kl_algo=False):
207280
cmd = self.gen_calib_data_cmds
208-
cmd = cmd.format(self._int8_logged_graph)
209-
f = open(self._requant_min_max_log, 'w', buffering=1)
210-
p = subprocess.Popen(shlex.split(cmd), stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
211-
try:
212-
for line in p.stdout:
213-
line_str = line.decode(sys.stdout.encoding)
214-
sys.stdout.write(line_str)
215-
f.write(line_str)
216-
p.communicate()
217-
except Exception:
218-
p.kill()
219-
p.wait()
220-
raise
221-
if p.poll():
222-
raise SystemExit('ERROR generating calibration data, command: \n{}'.format(cmd))
223-
224-
def _freeze_requantization_ranges(self):
225-
self._tmp_graph_def = freeze_max(self._tmp_graph_def, self._requant_min_max_log)
226-
self._tmp_graph_def = freeze_min(self._tmp_graph_def, self._requant_min_max_log)
227-
self._tmp_graph_def = freeze_requantization_range(self._tmp_graph_def, self._requant_min_max_log)
281+
cmd = cmd.format(graph)
282+
p = subprocess.Popen(shlex.split(cmd),
283+
stderr=subprocess.STDOUT,
284+
stdout=subprocess.PIPE)
285+
while p.poll() is None:
286+
line = p.stdout.readline().strip().decode()
287+
if line and line.startswith(';'):
288+
if not enable_kl_algo:
289+
output.append(line)
290+
291+
if enable_kl_algo and line.rsplit(':')[0] in self._kl_keys:
292+
fp32_data = get_all_fp32_data(line.rsplit(':')[-1])
293+
key = self._print_node_mapping[line[1:].split('__print')[0]] + '_eightbit_requant_range'
294+
if key not in self._kl_op_dict:
295+
self._kl_op_dict[key] = get_tensor_histogram(fp32_data)
296+
else:
297+
self._kl_op_dict[key] = combine_histogram(self._kl_op_dict[key], fp32_data)
298+
299+
def _freeze_requantization_ranges(self, additional_data=None):
300+
use_moving_average = self.algo == "MA"
301+
self._tmp_graph_def = freeze_max(self._tmp_graph_def,
302+
self._calibration_data,
303+
use_moving_average)
304+
self._tmp_graph_def = freeze_min(self._tmp_graph_def,
305+
self._calibration_data,
306+
use_moving_average)
307+
self._tmp_graph_def = freeze_requantization_range(
308+
self._tmp_graph_def, self._calibration_data, use_moving_average,
309+
additional_data)
228310
if self.debug:
229311
write_graph(self._tmp_graph_def, self._int8_frozen_range_graph)
230312

@@ -256,5 +338,3 @@ def _post_clean(self):
256338
"""
257339
if gfile.Exists(self._int8_logged_graph):
258340
os.remove(self._int8_logged_graph)
259-
if gfile.Exists(self._requant_min_max_log):
260-
os.remove(self._requant_min_max_log)

api/intel_quantization/quantize_graph/quantize_graph_concatv2.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
from tensorflow.python.framework import dtypes
33
from tensorflow.core.framework import node_def_pb2
4-
from intel_quantization.quantize_graph.quantize_graph_base import QuantizeNodeBase
5-
from intel_quantization.quantize_graph.quantize_graph_common import QuantizeGraphHelper as helper
4+
from .quantize_graph_base import QuantizeNodeBase
5+
from .quantize_graph_common import QuantizeGraphHelper as helper
66

77
import re
88

@@ -53,9 +53,15 @@ def _apply_concatv2_transform(self, original_node):
5353
self._add_dequantize_result_node(quantized_concat_name,
5454
original_node.name)
5555

56+
def _quantizable_concat(self, node):
57+
for input_node_name in node.input[:node.attr['N'].i]:
58+
if self.node_name_mapping[helper.node_name_from_input(input_node_name)].node.op != "Dequantize":
59+
return False
60+
return True
61+
5662
def _apply_concatv2_quantization(self):
5763
for _, v in self.node_name_mapping.items():
58-
if v.node.op in ("ConcatV2") and not re.search(
64+
if v.node.op in ("ConcatV2") and self._quantizable_concat(v.node) and not re.search(
5965
r'map(_\d+)?/while', v.node.name) and dtypes.as_dtype(
6066
v.node.attr["T"].type) == dtypes.float32:
6167
self._apply_concatv2_transform(v.node)

api/intel_quantization/quantize_graph/quantize_graph_conv.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from tensorflow.core.framework import node_def_pb2
44
from tensorflow.python.framework import dtypes
55

6-
from intel_quantization.quantize_graph.quantize_graph_common import QuantizeGraphHelper as helper
7-
from intel_quantization.quantize_graph.quantize_graph_base import QuantizeNodeBase
6+
from .quantize_graph_common import QuantizeGraphHelper as helper
7+
from .quantize_graph_base import QuantizeNodeBase
88

99
import logging
1010

@@ -232,7 +232,7 @@ def apply_conv_biasadd_fusion(self, match_node_name):
232232
helper.set_attr_dtype(quantized_conv_node, "out_type",
233233
dtypes.qint32)
234234
self.add_output_graph_node(quantized_conv_node)
235-
requantize_type = dtypes.qint8 if self.per_channel else dtypes.quint8
235+
requantize_type = dtypes.qint8
236236

237237
quantize_down_name = self._add_quantize_down_nodes(
238238
node, quantized_node_name, requantize_type, False)

api/intel_quantization/quantize_graph/quantize_graph_for_intel_cpu.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from tensorflow.python.platform import gfile
44
from tensorflow.python.framework import graph_util
55

6-
from intel_quantization.quantize_graph.quantize_graph_base import QuantizeGraphBase
7-
from intel_quantization.quantize_graph.quantize_graph_common import QuantizeGraphHelper
8-
from intel_quantization.quantize_graph.quantize_graph_conv import FuseNodeStartWithConv2d
9-
from intel_quantization.quantize_graph.quantize_graph_concatv2 import FuseNodeStartWithConcatV2
10-
from intel_quantization.quantize_graph.quantize_graph_matmul import FuseNodeStartWithMatmul
11-
from intel_quantization.quantize_graph.quantize_graph_pooling import FuseNodeStartWithPooling
12-
from intel_quantization.quantize_graph.quantize_graph_pad import FuseNodeStartWithPad
6+
from .quantize_graph_base import QuantizeGraphBase
7+
from .quantize_graph_common import QuantizeGraphHelper
8+
from .quantize_graph_conv import FuseNodeStartWithConv2d
9+
from .quantize_graph_concatv2 import FuseNodeStartWithConcatV2
10+
from .quantize_graph_matmul import FuseNodeStartWithMatmul
11+
from .quantize_graph_pooling import FuseNodeStartWithPooling
12+
from .quantize_graph_pad import FuseNodeStartWithPad
1313

1414

1515
class QuantizeGraphForIntel(QuantizeGraphBase):

api/intel_quantization/quantize_graph/quantize_graph_matmul.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from tensorflow.core.framework import node_def_pb2
44
from tensorflow.python.framework import dtypes
55

6-
from intel_quantization.quantize_graph.quantize_graph_common import QuantizeGraphHelper as helper
7-
from intel_quantization.quantize_graph.quantize_graph_base import QuantizeNodeBase
6+
from .quantize_graph_common import QuantizeGraphHelper as helper
7+
from .quantize_graph_base import QuantizeNodeBase
88

99
import logging
1010

api/intel_quantization/quantize_graph/quantize_graph_pad.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from tensorflow.core.framework import node_def_pb2
44
from tensorflow.python.framework import tensor_util
55

6-
from intel_quantization.quantize_graph.quantize_graph_base import QuantizeNodeBase
7-
from intel_quantization.quantize_graph.quantize_graph_common import QuantizeGraphHelper as helper
6+
from .quantize_graph_base import QuantizeNodeBase
7+
from .quantize_graph_common import QuantizeGraphHelper as helper
88

99

1010
class FuseNodeStartWithPad(QuantizeNodeBase):

api/intel_quantization/quantize_graph/quantize_graph_pooling.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from tensorflow.core.framework import node_def_pb2
33
from tensorflow.python.framework import dtypes
44

5-
from intel_quantization.quantize_graph.quantize_graph_base import QuantizeNodeBase
6-
from intel_quantization.quantize_graph.quantize_graph_common import QuantizeGraphHelper as helper
5+
from .quantize_graph_base import QuantizeNodeBase
6+
from .quantize_graph_common import QuantizeGraphHelper as helper
77

88

99
class FuseNodeStartWithPooling(QuantizeNodeBase):
@@ -24,7 +24,8 @@ def _add_pool_function(self, original_node, quantized_op_node):
2424

2525
def _apply_pool_quantization(self):
2626
for _, v in self.node_name_mapping.items():
27-
if v.node.op in ("AvgPool", "MaxPool"):
27+
if v.node.op in ("AvgPool", "MaxPool") and self._find_relu_node(
28+
v.node):
2829
self.eightbitize_single_input_tensor_node(
2930
v.node, self._add_pool_function)
3031
else:

0 commit comments

Comments (0)