@@ -97,22 +97,16 @@ Status FuseQuantizedConvolutionAndRequantize(
      AddNodeInput(const_requantize_range_min_node.name(), &fused_conv);
      AddNodeInput(const_requantize_range_max_node.name(), &fused_conv);

-      // Add additional inputs to
-      // QuantizedConv2DWithBiasSumAndReluAndRequantize
+      // Ensure that QuantizedConv2DWithBiasSumAndReluAndRequantize
+      // receives an integer summand, because requantization fusion is
+      // registered for integer summands only.
      if (quantized_conv2D_op_name.compare(
          "QuantizedConv2DWithBiasSumAndRelu") == 0) {
-        const NodeDef *in_requantize = node_map[node_map[
-            quantized_conv2D_node.input(n_input)]->input(0)];
        const NodeDef *summand_node = node_map[quantized_conv2D_node.input(
            n_input)];
-        bool quantized_summand = str_util::StrContains(
-            in_requantize->op(), "Quantized");
-        // If the summand is not quantized, we need to quantize it since the
-        // convolution kernel assumes that the summand is always quanitzed.
-        if (!quantized_summand &&
-            !is_perchannel &&
-            in_requantize->op() != "Requantize" &&
-            in_requantize->op() != "QuantizeV2") {
+        NodeDef* new_summand_node = nullptr;
+        NodeDef quantize_node;
+        if (summand_node->op() != "Dequantize") {
          // Quantizing the summand.
          // Add some common constants we need for reshaping inputs.
          NodeDef reshape_dims;
@@ -156,10 +150,20 @@ Status FuseQuantizedConvolutionAndRequantize(
          AddNodeInput(reshape_node.name(), &max_node);
          AddNodeInput(reduction_dims.name(), &max_node);

-          NodeDef quantize_node;
+          // NodeDef quantize_node;
          quantize_node.set_op("QuantizeV2");
          quantize_node.set_name(summand_node->name() + "/quantize");
-          SetNodeAttr("T", DT_QUINT8, &quantize_node);
+          // Decide data type of quantize op
+          std::vector<string> relu_ops = {
+              "Relu",
+              "Relu6"
+          };
+          bool is_relu = std::find(relu_ops.begin(), relu_ops.end(),
+              summand_node->op()) != relu_ops.end();
+          if (is_relu)
+            SetNodeAttr("T", DT_QUINT8, &quantize_node);
+          else
+            SetNodeAttr("T", DT_QINT8, &quantize_node);
          SetNodeAttr("mode", "SCALED", &quantize_node);

          AddNodeInput(summand_node->name(), &reshape_node);
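The relu_ops check added above decides which element type the inserted QuantizeV2 node should emit when the summand arrives as float. A minimal standalone sketch of that decision (the helper name is invented for illustration, not an API from this pass): Relu/Relu6 outputs are non-negative, so quint8 suffices; anything else may carry negative values and gets qint8.

    // Standalone illustration of the quantize-type choice; SummandQuantizeType
    // is a made-up helper, not part of the patch.
    #include <algorithm>
    #include <string>
    #include <vector>

    std::string SummandQuantizeType(const std::string& summand_op) {
      const std::vector<std::string> relu_ops = {"Relu", "Relu6"};
      const bool is_relu = std::find(relu_ops.begin(), relu_ops.end(),
                                     summand_op) != relu_ops.end();
      // Relu/Relu6 never produce negatives, so an unsigned 8-bit type is safe;
      // otherwise use the signed 8-bit type.
      return is_relu ? "DT_QUINT8" : "DT_QINT8";
    }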
@@ -169,41 +173,71 @@ Status FuseQuantizedConvolutionAndRequantize(
          AddNodeInput(min_node.name(), &quantize_node);
          AddNodeInput(max_node.name(), &quantize_node);

-          AddNodeInput(quantize_node.name(), &fused_conv);
-          AddNodeInput(quantize_node.name() + ":1", &fused_conv);
-          AddNodeInput(quantize_node.name() + ":2", &fused_conv);
-
          new_nodes->push_back(reshape_dims);
          new_nodes->push_back(reduction_dims);
          new_nodes->push_back(reshape_node);
          new_nodes->push_back(min_node);
          new_nodes->push_back(max_node);
          new_nodes->push_back(quantize_node);
+          // Set the new summand node for fused_conv.
+          new_summand_node = &quantize_node;
        } else {
-          string summand(in_requantize->name());
-          string min_summand(in_requantize->name() + ":1");
-          string max_summand(in_requantize->name() + ":2");
-          AddNodeInput(summand, &fused_conv);
-          AddNodeInput(min_summand, &fused_conv);
-          AddNodeInput(max_summand, &fused_conv);
+          // If the summand node is "Dequantize", then either "QuantizeV2" or
+          // "Requantize{PerChannel}" is feeding the Dequantize op;
+          // set new_summand_node to the input of the summand node.
+          new_summand_node = const_cast<NodeDef*>(node_map[
+              summand_node->input(0)]);
        }
-
-        // Signed version QuantizedConv2DWithBiasSumAndReluAndRequantize
-        // if Relu does not follow the convolution operation
-        std::vector<string> signed_ops = {
-          "QuantizedConv2DWithBias",
-          "QuantizedConv2D"
-        };
-        bool is_signed_summand =
+        string summand(new_summand_node->name());
+        string min_summand(new_summand_node->name() + ":1");
+        string max_summand(new_summand_node->name() + ":2");
+        AddNodeInput(summand, &fused_conv);
+        AddNodeInput(min_summand, &fused_conv);
+        AddNodeInput(max_summand, &fused_conv);
+
+        DataType summand_type = DT_QUINT8;
+        // The new summand node should be QuantizeV2 or
+        // Requantize{PerChannel}.
+        if (new_summand_node->op() == "QuantizeV2") {
+          TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
+                                         "T", &summand_type));
+        } else if (new_summand_node->op() == "RequantizePerChannel") {
+          TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
+                                         "out_type", &summand_type));
+        } else if (new_summand_node->op() == "Requantize") {
+          // The Requantize op is an Eigen kernel that does non-SCALED
+          // quantization and always outputs quint8. However, the MKLDNN
+          // fusion uses SCALED quantization, so the requantize op feeding
+          // the summand may have qint8 or quint8 as its output type;
+          // summand_type must therefore be set accordingly.
+          std::vector<string> signed_ops = {
+            "QuantizedConv2DWithBias",
+            "QuantizedConv2D"
+          };
+          bool is_signed_summand =
              std::find(signed_ops.begin(), signed_ops.end(),
-            node_map[in_requantize->input(0)]->op()) != signed_ops.end();
-        if (is_signed_summand) {
-          fused_conv.set_op(
-              "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize");
-          SetNodeAttr("Tsummand", DT_QINT8, &fused_conv);
+              node_map[new_summand_node->input(0)]->op())
+              != signed_ops.end();
+          summand_type = is_signed_summand ? DT_QINT8 : DT_QUINT8;
+        } else if (str_util::StartsWith(new_summand_node->op(),
+                                        "Quantized")) {
+          if (HasNodeAttr(*new_summand_node, "T")) {
+            TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
+                                           "T", &summand_type));
+          } else if (HasNodeAttr(*new_summand_node, "out_type")) {
+            TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
+                                           "out_type", &summand_type));
+          }
        } else {
-          SetNodeAttr("Tsummand", DT_QUINT8, &fused_conv);
+          return Status(error::Code::FAILED_PRECONDITION,
+                        "Fusion is not supported, a fix is required.");
        }
+        SetNodeAttr("Tsummand", summand_type, &fused_conv);
+        // Decide whether to use the signed version of
+        // QuantizedConv2DWithBiasSumAndReluAndRequantize.
+        if (summand_type == DT_QINT8)
+          fused_conv.set_op(
+              "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize");
      }

      // Add control input to the very end of the input list
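The chain above resolves Tsummand from whichever op now produces the summand: QuantizeV2 and RequantizePerChannel expose the type directly through their "T"/"out_type" attributes, while a plain Requantize always emits quint8 under Eigen, so its signedness has to be inferred from the op that feeds it. A rough standalone sketch of that last case (the helper name is invented for illustration):

    // Sketch of the signedness inference for a summand fed by "Requantize";
    // assumes the producer op name is already known.
    #include <algorithm>
    #include <string>
    #include <vector>

    bool SummandIsSigned(const std::string& requantize_producer_op) {
      // A convolution without a fused Relu can produce negative values, so its
      // requantized output is treated as signed (qint8); otherwise quint8.
      const std::vector<std::string> signed_ops = {"QuantizedConv2DWithBias",
                                                   "QuantizedConv2D"};
      return std::find(signed_ops.begin(), signed_ops.end(),
                       requantize_producer_op) != signed_ops.end();
    }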
@@ -216,32 +250,21 @@ Status FuseQuantizedConvolutionAndRequantize(
      CopyNodeAttr(quantized_conv2D_node, "strides", "strides", &fused_conv);
      CopyNodeAttr(quantized_conv2D_node, "padding", "padding", &fused_conv);

-      if (is_perchannel) {
-        std::vector<std::string> fused_quantized_bias_ops = {
-          "QuantizedConv2DWithBias",
-          "QuantizedConv2DWithBiasAndRelu",
-          "QuantizedDepthwiseConv2DWithBias",
-          "QuantizedDepthwiseConv2DWithBiasAndRelu",
-          "QuantizedConv2DWithBiasSumAndRelu",
-          "QuantizedConv2DWithBiasSignedSumAndRelu"
-        };
-
-        if (std::find(fused_quantized_bias_ops.begin(),
-                      fused_quantized_bias_ops.end(),
-                      quantized_conv2D_node.op()) != fused_quantized_bias_ops.end()) {
-          SetNodeAttr("Tbias", DT_FLOAT, &fused_conv);
-        }
+      std::vector<std::string> fused_quantized_bias_ops = {
+        "QuantizedConv2DWithBias",
+        "QuantizedConv2DWithBiasAndRelu",
+        "QuantizedDepthwiseConv2DWithBias",
+        "QuantizedDepthwiseConv2DWithBiasAndRelu",
+        "QuantizedConv2DWithBiasSumAndRelu",
+      };
+      if (std::find(fused_quantized_bias_ops.begin(),
+                    fused_quantized_bias_ops.end(),
+                    quantized_conv2D_node.op()) != fused_quantized_bias_ops.end()) {
+        SetNodeAttr("Tbias", DT_FLOAT, &fused_conv);
      }
-
-      CopyNodeAttr(quantized_conv2D_node, "Tinput", "Tinput", &fused_conv);
-      CopyNodeAttr(quantized_conv2D_node, "Tfilter", "Tfilter", &fused_conv);
-      CopyNodeAttr(quantized_conv2D_node, "strides", "strides", &fused_conv);
-      CopyNodeAttr(quantized_conv2D_node, "padding", "padding", &fused_conv);
-
      if (HasNodeAttr(quantized_conv2D_node, "padding_list"))
        CopyNodeAttr(quantized_conv2D_node, "padding_list",
                     "padding_list", &fused_conv);
-
      // Copy dilation attribute if exsit in the orginal node
      if (HasNodeAttr(quantized_conv2D_node, "dilations"))
        CopyNodeAttr(quantized_conv2D_node, "dilations",
@@ -259,93 +282,112 @@ Status FuseQuantizedConvolutionAndRequantize(
      },
      {}, &replaced_graph_def));

-  if (!is_perchannel) {
-    // Convert bias float -> int32 on replaced_graph_def
-    std::vector<std::string> fused_requantized_bias_ops = {
-      "QuantizedConv2DWithBiasAndRequantize",
-      "QuantizedConv2DWithBiasAndReluAndRequantize",
-      "QuantizedConv2DWithBiasSumAndReluAndRequantize",
-      "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"
-    };
-
-    node_map.clear();
-    MapNamesToNodes(replaced_graph_def, &node_map);
-    for (auto& node_pair : node_map) {
-      const NodeDef *node = node_pair.second;
-      if (str_util::StartsWith(node->op(), "Dequantize")) {
-        // dequant node should accept DT_QINT8 if the input node is
-        // "QuantizedConv2DAndRequantize" and
-        // "QuantizedConv2DWithBiasAndRequantize"
-        std::string input_node_op =
-            node_map[NodeNameFromInput(node->input(0))]->op();
-        if (str_util::StartsWith(input_node_op,
-            "QuantizedConv2DAndRequantize") ||
-            str_util::StartsWith(input_node_op,
-            "QuantizedConv2DWithBiasAndRequantize")) {
-          SetNodeAttr("T", DT_QINT8, const_cast<NodeDef*>(node));
-          SetNodeAttr("mode", "SCALED", const_cast<NodeDef*>(node));
-        }
+  // After Requantize op fusion, fix attributes of nodes in the graph
+  // if there is some discrepancy, and also quantize the bias (float -> int32).
+  // List of requantize-fused ops that have biases.
+  std::vector<std::string> fused_requantized_bias_ops = {
+    "QuantizedConv2DWithBiasAndRequantize",
+    "QuantizedConv2DWithBiasAndReluAndRequantize",
+    "QuantizedConv2DWithBiasSumAndReluAndRequantize",
+    "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
+    "QuantizedDepthwiseConv2DWithBiasAndRequantize",
+    "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"
+  };
+
+  node_map.clear();
+  MapNamesToNodes(replaced_graph_def, &node_map);
+  for (auto& node_pair : node_map) {
+    const NodeDef *node = node_pair.second;
+    // A workaround to fix attributes of a "Dequantize" op with non-perchannel
+    // quantization. The "Dequantize" node should accept DT_QINT8 if the input
+    // node is "QuantizedConv2DAndRequantize" or
+    // "QuantizedConv2DWithBiasAndRequantize".
+    if (str_util::StartsWith(node->op(), "Dequantize")) {
+      std::string input_node_op =
+          node_map[NodeNameFromInput(node->input(0))]->op();
+      if (str_util::StartsWith(input_node_op,
+          "QuantizedConv2DAndRequantize") ||
+          str_util::StartsWith(input_node_op,
+          "QuantizedConv2DWithBiasAndRequantize")) {
+        SetNodeAttr("T", DT_QINT8, const_cast<NodeDef*>(node));
+        SetNodeAttr("mode", "SCALED", const_cast<NodeDef*>(node));
+      }
      continue;
    }
-
+    // Quantize the bias to int32 if the input min-max values are constants.
+    // This is guaranteed if the preceding op is a fused requantize op.
    bool is_fused_requantized_conv_op =
-      std::find(fused_requantized_bias_ops.begin(),
-                fused_requantized_bias_ops.end(), node->op())
-      != fused_requantized_bias_ops.end();
-    if (is_fused_requantized_conv_op) {
-      // If the op is feed by Quantize op then we keep bias as float
-      std::string input_op = node_map[NodeNameFromInput(
-          node->input(0))]->op();
-      if (str_util::StartsWith(input_op, "QuantizedConv2D") &&
-          str_util::EndsWith(input_op, "AndRequantize")) {
-        NodeDef *bias_node = const_cast<NodeDef*>(node_map[NodeNameFromInput(
-            node->input(2))]);
-        const NodeDef *min_input_node = node_map[NodeNameFromInput(
+        std::find(fused_requantized_bias_ops.begin(),
+                  fused_requantized_bias_ops.end(), node->op())
+        != fused_requantized_bias_ops.end();
+    if (is_fused_requantized_conv_op) {
+      std::string preceeding_op = node_map[NodeNameFromInput(
+          node->input(0))]->op();
+      if (str_util::StartsWith(preceeding_op, "Quantized") &&
+          str_util::StrContains(preceeding_op, "Conv2D") &&
+          str_util::EndsWith(preceeding_op, "AndRequantize")) {
+        NodeDef *bias_node = const_cast<NodeDef*>(node_map[NodeNameFromInput(
+            node->input(2))]);
+        const NodeDef *min_input_node = node_map[NodeNameFromInput(
            node_map[node->input(0)]->input(7))];
-        const NodeDef *max_input_node = node_map[NodeNameFromInput(
+        const NodeDef *max_input_node = node_map[NodeNameFromInput(
            node_map[node->input(0)]->input(8))];
-        const NodeDef *min_filter_node = node_map[NodeNameFromInput(
+        const NodeDef *min_filter_node = node_map[NodeNameFromInput(
            node->input(5))];
-        const NodeDef *max_filter_node = node_map[NodeNameFromInput(
+        const NodeDef *max_filter_node = node_map[NodeNameFromInput(
            node->input(6))];
-        const float min_input =
+        const float min_input =
            GetNodeTensorAttr(*min_input_node, "value").flat<float>()(0);
-        const float max_input =
+        const float max_input =
            GetNodeTensorAttr(*max_input_node, "value").flat<float>()(0);
-        const float min_filter =
-            GetNodeTensorAttr(*min_filter_node, "value").flat<float>()(0);
-        const float max_filter =
-            GetNodeTensorAttr(*max_filter_node, "value").flat<float>()(0);
-
-        TensorProto float_tensor_proto =
-            bias_node->attr().at("value").tensor();
-        Tensor float_tensor;
-        CHECK(float_tensor.FromProto(float_tensor_proto));
-        CHECK_EQ(float_tensor.dtype(), DT_FLOAT);
-        float *p_bias_float = float_tensor.flat<float>().data();
-
-        Tensor int32_tensor = Tensor(DT_QINT32, float_tensor.shape());
-        qint32 *p_bias_int32 = int32_tensor.flat<qint32>().data();
-
-        float bias_scale = 255.0 * 127.0 /
+        const Tensor& min_filter_tensor =
+            GetNodeTensorAttr(*min_filter_node, "value");
+        const Tensor& max_filter_tensor =
+            GetNodeTensorAttr(*max_filter_node, "value");
+        const float* min_filter = min_filter_tensor.flat<float>().data();
+        const float* max_filter = max_filter_tensor.flat<float>().data();
+        size_t num_scale_factors = min_filter_tensor.NumElements();
+
+        TensorProto float_tensor_proto =
+            bias_node->attr().at("value").tensor();
+        Tensor float_bias_tensor;
+        CHECK(float_bias_tensor.FromProto(float_tensor_proto));
+        CHECK_EQ(float_bias_tensor.dtype(), DT_FLOAT);
+        float *float_bias = float_bias_tensor.flat<float>().data();
+
+        Tensor int32_bias_tensor = Tensor(DT_QINT32, float_bias_tensor.shape());
+        qint32 *int32_bias = int32_bias_tensor.flat<qint32>().data();
+        std::vector<float> scales(num_scale_factors);
+        for (size_t i = 0; i < num_scale_factors; ++i) {
+          scales[i] = 255.0 * 127.0 /
            (std::max(std::abs(max_input), std::abs(min_input)) *
-            std::max(std::abs(max_filter), std::abs(min_filter)));
-        int64 nelems = float_tensor.NumElements();
-        for (int64 n = 0; n < nelems; n++)
-          p_bias_int32[n] = (int32_t) (p_bias_float[n] * bias_scale);
-
-        bias_node->clear_attr();
-        AttrValue attr_type;
-        attr_type.set_type(int32_tensor.dtype());
-        bias_node->mutable_attr()->insert({"dtype", attr_type});
-        AttrValue attr_tensor;
-        TensorProto* t = attr_tensor.mutable_tensor();
-        int32_tensor.AsProtoTensorContent(t);
-        bias_node->mutable_attr()->insert({"value", attr_tensor});
-        SetNodeAttr("Tbias", DT_QINT32, const_cast<NodeDef*>(node));
+            std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
+        }
+        int64 bias_length = float_bias_tensor.NumElements();
+        if (num_scale_factors > 1) {
+          if (bias_length != num_scale_factors) {
+            return Status(error::Code::FAILED_PRECONDITION,
+                          "Number of filter output channels is not"
+                          " equal to bias size");
+          } else {
+            for (int64 i = 0; i < bias_length; i++)
+              int32_bias[i] = (int32_t) (float_bias[i] * scales[i]);
+          }
        } else {
-          SetNodeAttr("Tbias", DT_FLOAT, const_cast<NodeDef*>(node));
+          for (int64 i = 0; i < bias_length; i++)
+            int32_bias[i] = (int32_t) (float_bias[i] * scales[0]);
        }
+        bias_node->clear_attr();
+        AttrValue attr_type;
+        attr_type.set_type(int32_bias_tensor.dtype());
+        bias_node->mutable_attr()->insert({"dtype", attr_type});
+        AttrValue attr_tensor;
+        TensorProto* t = attr_tensor.mutable_tensor();
+        int32_bias_tensor.AsProtoTensorContent(t);
+        bias_node->mutable_attr()->insert({"value", attr_tensor});
+        SetNodeAttr("Tbias", DT_QINT32, const_cast<NodeDef*>(node));
+      } else {
+        SetNodeAttr("Tbias", DT_FLOAT, const_cast<NodeDef*>(node));
      }
    }
  }
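The bias conversion above scales each float bias element by 255 * 127 divided by the product of the input range and the per-channel filter range, then casts to int32. A minimal self-contained sketch of that arithmetic with made-up ranges (the real pass reads these values from const nodes in the graph):

    // Illustrative only: per-channel bias requantization with invented values.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    int main() {
      const float min_input = -1.0f, max_input = 1.0f;        // activation range
      const std::vector<float> min_filter = {-0.5f, -0.25f};  // per-channel
      const std::vector<float> max_filter = {0.5f, 0.25f};
      const std::vector<float> float_bias = {0.1f, -0.2f};

      std::vector<int32_t> int32_bias(float_bias.size());
      for (size_t i = 0; i < float_bias.size(); ++i) {
        // 255 quint8 steps for the input times 127 qint8 steps for the filter.
        const float scale = 255.0f * 127.0f /
            (std::max(std::abs(max_input), std::abs(min_input)) *
             std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
        int32_bias[i] = static_cast<int32_t>(float_bias[i] * scale);
      }
      return 0;
    }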