
Commit 7361a87

mdfaijul authored and karthikvadla committed

[Urgent for mlperf] Fixed issues and cleaned up fuse_quantized_convolution.cc (#75)

1 parent 90433ef, commit 7361a87

File tree

1 file changed: +177 -135 lines changed


tensorflow_quantization/graph_transforms/fuse_quantized_convolution.cc

+177 -135
@@ -97,22 +97,16 @@ Status FuseQuantizedConvolutionAndRequantize(
         AddNodeInput(const_requantize_range_min_node.name(), &fused_conv);
         AddNodeInput(const_requantize_range_max_node.name(), &fused_conv);
 
-        // Add additional inputs to
-        // QuantizedConv2DWithBiasSumAndReluAndRequantize
+        // Ensure QuantizedConv2DWithBiasSumAndReluAndRequantize receives
+        // integer summand. Because requantization fusion is registered
+        // for integer summand only.
         if (quantized_conv2D_op_name.compare(
             "QuantizedConv2DWithBiasSumAndRelu") == 0) {
-          const NodeDef *in_requantize = node_map[node_map[
-              quantized_conv2D_node.input(n_input)]->input(0)];
           const NodeDef *summand_node = node_map[quantized_conv2D_node.input(
               n_input)];
-          bool quantized_summand = str_util::StrContains(
-              in_requantize->op(), "Quantized");
-          // If the summand is not quantized, we need to quantize it since the
-          // convolution kernel assumes that the summand is always quanitzed.
-          if (!quantized_summand &&
-              !is_perchannel &&
-              in_requantize->op() != "Requantize" &&
-              in_requantize->op() != "QuantizeV2") {
+          NodeDef* new_summand_node = nullptr;
+          NodeDef quantize_node;
+          if (summand_node->op() != "Dequantize") {
             // Quantizing the summand.
             // Add some common constants we need for reshaping inputs.
             NodeDef reshape_dims;
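
The new comments in this hunk state that the requantize fusion is registered only for an integer summand, so the pass now branches on whether the summand is already produced by a Dequantize node: if not, a QuantizeV2 node is inserted over it; if so, the already-quantized producer feeding the Dequantize is reused. A minimal standalone sketch of that branch, using plain stand-in structs rather than the TensorFlow graph_transforms API (Node and ResolveSummand are illustrative names only):

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::string op;
  std::vector<const Node*> inputs;  // resolved input nodes
};

// Returns the node whose integer output should feed the fused conv as the
// summand: either a freshly inserted QuantizeV2 or the quantized producer
// sitting behind an existing Dequantize.
const Node* ResolveSummand(const Node& summand, std::vector<Node>& new_nodes) {
  if (summand.op != "Dequantize") {
    // Summand is still float: insert a QuantizeV2 over it (its min/max
    // inputs would be wired up as in the hunk above).
    new_nodes.push_back({summand.name + "/quantize", "QuantizeV2", {&summand}});
    return &new_nodes.back();
  }
  // Summand is a Dequantize: input 0 is already an integer tensor
  // (QuantizeV2 or Requantize{PerChannel}); reuse that producer directly.
  return summand.inputs[0];
}

int main() {
  Node requant{"conv1/requantize", "RequantizePerChannel", {}};
  Node dequant{"conv1/dequantize", "Dequantize", {&requant}};
  std::vector<Node> new_nodes;
  std::cout << ResolveSummand(dequant, new_nodes)->name << "\n";  // conv1/requantize
}
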
@@ -156,10 +150,20 @@ Status FuseQuantizedConvolutionAndRequantize(
             AddNodeInput(reshape_node.name(), &max_node);
             AddNodeInput(reduction_dims.name(), &max_node);
 
-            NodeDef quantize_node;
+            // NodeDef quantize_node;
             quantize_node.set_op("QuantizeV2");
             quantize_node.set_name(summand_node->name() + "/quantize");
-            SetNodeAttr("T", DT_QUINT8, &quantize_node);
+            // Decide data type of quantize op
+            std::vector<string> relu_ops = {
+                "Relu",
+                "Relu6"
+            };
+            bool is_relu = std::find(relu_ops.begin(), relu_ops.end(),
+                summand_node->op()) != relu_ops.end();
+            if (is_relu)
+              SetNodeAttr("T", DT_QUINT8, &quantize_node);
+            else
+              SetNodeAttr("T", DT_QINT8, &quantize_node);
             SetNodeAttr("mode", "SCALED", &quantize_node);
 
             AddNodeInput(summand_node->name(), &reshape_node);
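
The added lines pick the quantized type of the inserted QuantizeV2 from the summand's producer: Relu/Relu6 outputs are non-negative, so they fit an unsigned quint8 range, while any other producer may emit negative values and gets qint8. A small self-contained sketch of that decision, with an illustrative QType enum and helper name (not part of the real pass):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

enum class QType { QUINT8, QINT8 };

// Relu/Relu6 outputs are non-negative, so the summand fits an unsigned
// 8-bit range; anything else is quantized as signed.
QType SummandQuantizeType(const std::string& producer_op) {
  static const std::vector<std::string> relu_ops = {"Relu", "Relu6"};
  bool is_relu =
      std::find(relu_ops.begin(), relu_ops.end(), producer_op) != relu_ops.end();
  return is_relu ? QType::QUINT8 : QType::QINT8;
}

int main() {
  std::cout << (SummandQuantizeType("Relu6") == QType::QUINT8) << "\n";   // 1
  std::cout << (SummandQuantizeType("BiasAdd") == QType::QINT8) << "\n";  // 1
}
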
@@ -169,41 +173,71 @@ Status FuseQuantizedConvolutionAndRequantize(
             AddNodeInput(min_node.name(), &quantize_node);
             AddNodeInput(max_node.name(), &quantize_node);
 
-            AddNodeInput(quantize_node.name(), &fused_conv);
-            AddNodeInput(quantize_node.name() + ":1", &fused_conv);
-            AddNodeInput(quantize_node.name() + ":2", &fused_conv);
-
             new_nodes->push_back(reshape_dims);
             new_nodes->push_back(reduction_dims);
             new_nodes->push_back(reshape_node);
             new_nodes->push_back(min_node);
             new_nodes->push_back(max_node);
             new_nodes->push_back(quantize_node);
+            // Set the new summand node for fused_conv
+            new_summand_node = &quantize_node;
           } else {
-            string summand(in_requantize->name());
-            string min_summand(in_requantize->name() + ":1");
-            string max_summand(in_requantize->name() + ":2");
-            AddNodeInput(summand, &fused_conv);
-            AddNodeInput(min_summand, &fused_conv);
-            AddNodeInput(max_summand, &fused_conv);
+            // If summand node is "Dequantize" then either "QuantizeV2" or
+            // "Requantize{PerChannel}" is feeding Dequantize op.
+            // Set new_summand_node as the input of summand node.
+            new_summand_node = const_cast<NodeDef*>(node_map[
+                summand_node->input(0)]);
           }
-
-          // Signed version QuantizedConv2DWithBiasSumAndReluAndRequantize
-          // if Relu does not follow the convolution operation
-          std::vector<string> signed_ops = {
-            "QuantizedConv2DWithBias",
-            "QuantizedConv2D"
-          };
-          bool is_signed_summand =
+          string summand(new_summand_node->name());
+          string min_summand(new_summand_node->name() + ":1");
+          string max_summand(new_summand_node->name() + ":2");
+          AddNodeInput(summand, &fused_conv);
+          AddNodeInput(min_summand, &fused_conv);
+          AddNodeInput(max_summand, &fused_conv);
+
+          DataType summand_type = DT_QUINT8;
+          // New summand node should be QuantizeV2 or
+          // Requantize{PerChannel}
+          if (new_summand_node->op() == "QuantizeV2") {
+            TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
+                "T", &summand_type));
+          } else if (new_summand_node->op() == "RequantizePerChannel") {
+            TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
+                "out_type", &summand_type));
+          } else if (new_summand_node->op() == "Requantize") {
+            // Requantize op is Eigen kernel that does non-SCALED quantization
+            // and always maps into quint8. However, for MKLDNN fusion, which is
+            // SCALED quantization, the summand fused requantize op may have
+            // qint8 or quint8 as its output type. Therefore, it is needed to
+            // set the summand_type correctly.
+            std::vector<string> signed_ops = {
+              "QuantizedConv2DWithBias",
+              "QuantizedConv2D"
+            };
+            bool is_signed_summand =
                 std::find(signed_ops.begin(), signed_ops.end(),
-                node_map[in_requantize->input(0)]->op()) != signed_ops.end();
-          if (is_signed_summand) {
-            fused_conv.set_op(
-                "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize");
-            SetNodeAttr("Tsummand", DT_QINT8, &fused_conv);
+                node_map[new_summand_node->input(0)]->op())
+                != signed_ops.end();
+            summand_type = is_signed_summand ? DT_QINT8 : DT_QUINT8;
+          } else if (str_util::StartsWith(new_summand_node->op(),
+                     "Quantized")) {
+            if (HasNodeAttr(*new_summand_node, "T")) {
+              TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
+                  "T", &summand_type));
+            } else if (HasNodeAttr(*new_summand_node, "out_type")) {
+              TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
+                  "out_type", &summand_type));
+            }
           } else {
-            SetNodeAttr("Tsummand", DT_QUINT8, &fused_conv);
+            return Status(error::Code::FAILED_PRECONDITION,
+                "Fusion is not supported, a fix is required.");
           }
+          SetNodeAttr("Tsummand", summand_type, &fused_conv);
+          // Decide whether signed version of
+          // QuantizedConv2DWithBiasSumAndReluAndRequantize or not
+          if (summand_type == DT_QINT8)
+            fused_conv.set_op(
+                "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize");
         }
 
         // Add control input to the very end of the input list
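
The new summand-type resolution walks the possible producers in a fixed order: QuantizeV2 exposes its type in "T", RequantizePerChannel in "out_type", a plain (Eigen) Requantize is inferred from whether the conv feeding it is a signed variant, and any other Quantized* op is checked for either attribute; anything else fails the fusion. When the resolved type is qint8, the fused op flips to the Signed variant. A standalone sketch of that lookup order, assuming simplified stand-in types rather than the NodeDef/GetNodeAttr API (FakeNode and ResolveSummandType are illustrative names):

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

enum class QType { QUINT8, QINT8 };

struct FakeNode {
  std::string op;
  std::map<std::string, QType> attrs;  // e.g. "T" or "out_type"
  std::string producer_op;             // op feeding input 0
};

// Mirrors the attribute-lookup order of the hunk above; returns false when
// the summand producer is unsupported (the pass returns FAILED_PRECONDITION).
bool ResolveSummandType(const FakeNode& n, QType* out) {
  if (n.op == "QuantizeV2" && n.attrs.count("T")) {
    *out = n.attrs.at("T");
    return true;
  }
  if (n.op == "RequantizePerChannel" && n.attrs.count("out_type")) {
    *out = n.attrs.at("out_type");
    return true;
  }
  if (n.op == "Requantize") {
    // Eigen Requantize always emits quint8, but under SCALED (MKL-DNN style)
    // fusion the effective summand type follows the producing conv op.
    static const std::vector<std::string> signed_ops = {
        "QuantizedConv2DWithBias", "QuantizedConv2D"};
    bool is_signed = std::find(signed_ops.begin(), signed_ops.end(),
                               n.producer_op) != signed_ops.end();
    *out = is_signed ? QType::QINT8 : QType::QUINT8;
    return true;
  }
  if (n.op.rfind("Quantized", 0) == 0) {  // starts with "Quantized"
    if (n.attrs.count("T")) { *out = n.attrs.at("T"); return true; }
    if (n.attrs.count("out_type")) { *out = n.attrs.at("out_type"); return true; }
  }
  return false;
}

int main() {
  FakeNode req{"Requantize", {}, "QuantizedConv2DWithBias"};
  QType t;
  if (ResolveSummandType(req, &t))
    std::cout << (t == QType::QINT8 ? "qint8 (use Signed fused op)" : "quint8")
              << "\n";
}
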
@@ -216,32 +250,21 @@ Status FuseQuantizedConvolutionAndRequantize(
         CopyNodeAttr(quantized_conv2D_node, "strides", "strides", &fused_conv);
         CopyNodeAttr(quantized_conv2D_node, "padding", "padding", &fused_conv);
 
-        if (is_perchannel) {
-          std::vector<std::string> fused_quantized_bias_ops = {
-            "QuantizedConv2DWithBias",
-            "QuantizedConv2DWithBiasAndRelu",
-            "QuantizedDepthwiseConv2DWithBias",
-            "QuantizedDepthwiseConv2DWithBiasAndRelu",
-            "QuantizedConv2DWithBiasSumAndRelu",
-            "QuantizedConv2DWithBiasSignedSumAndRelu"
-          };
-
-          if (std::find(fused_quantized_bias_ops.begin(),
-              fused_quantized_bias_ops.end(),
-              quantized_conv2D_node.op()) != fused_quantized_bias_ops.end()) {
-            SetNodeAttr("Tbias", DT_FLOAT, &fused_conv);
-          }
+        std::vector<std::string> fused_quantized_bias_ops = {
+          "QuantizedConv2DWithBias",
+          "QuantizedConv2DWithBiasAndRelu",
+          "QuantizedDepthwiseConv2DWithBias",
+          "QuantizedDepthwiseConv2DWithBiasAndRelu",
+          "QuantizedConv2DWithBiasSumAndRelu",
+        };
+        if (std::find(fused_quantized_bias_ops.begin(),
+            fused_quantized_bias_ops.end(),
+            quantized_conv2D_node.op()) != fused_quantized_bias_ops.end()) {
+          SetNodeAttr("Tbias", DT_FLOAT, &fused_conv);
         }
-
-        CopyNodeAttr(quantized_conv2D_node, "Tinput", "Tinput", &fused_conv);
-        CopyNodeAttr(quantized_conv2D_node, "Tfilter", "Tfilter", &fused_conv);
-        CopyNodeAttr(quantized_conv2D_node, "strides", "strides", &fused_conv);
-        CopyNodeAttr(quantized_conv2D_node, "padding", "padding", &fused_conv);
-
         if (HasNodeAttr(quantized_conv2D_node, "padding_list"))
           CopyNodeAttr(quantized_conv2D_node, "padding_list",
               "padding_list", &fused_conv);
-
         // Copy dilation attribute if exsit in the orginal node
         if (HasNodeAttr(quantized_conv2D_node, "dilations"))
           CopyNodeAttr(quantized_conv2D_node, "dilations",
@@ -259,93 +282,112 @@ Status FuseQuantizedConvolutionAndRequantize(
       },
       {}, &replaced_graph_def));
 
-  if (!is_perchannel) {
-    // Convert bias float -> int32 on replaced_graph_def
-    std::vector<std::string> fused_requantized_bias_ops = {
-      "QuantizedConv2DWithBiasAndRequantize",
-      "QuantizedConv2DWithBiasAndReluAndRequantize",
-      "QuantizedConv2DWithBiasSumAndReluAndRequantize",
-      "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"
-    };
-
-    node_map.clear();
-    MapNamesToNodes(replaced_graph_def, &node_map);
-    for (auto& node_pair : node_map) {
-      const NodeDef *node = node_pair.second;
-      if (str_util::StartsWith(node->op(), "Dequantize")) {
-        // dequant node should accept DT_QINT8 if the input node is
-        // "QuantizedConv2DAndRequantize" and
-        // "QuantizedConv2DWithBiasAndRequantize"
-        std::string input_node_op =
-            node_map[NodeNameFromInput(node->input(0))]->op();
-        if (str_util::StartsWith(input_node_op,
-            "QuantizedConv2DAndRequantize") ||
-            str_util::StartsWith(input_node_op,
-            "QuantizedConv2DWithBiasAndRequantize")) {
-          SetNodeAttr("T", DT_QINT8, const_cast<NodeDef*>(node));
-          SetNodeAttr("mode", "SCALED", const_cast<NodeDef*>(node));
-        }
+  // After Requantize op fusion, fix attributes for nodes in the graph,
+  // if threre is some discrepency. And also quantize the bias (float -> int32)
+  // List of requantize fused ops that have biases.
+  std::vector<std::string> fused_requantized_bias_ops = {
+    "QuantizedConv2DWithBiasAndRequantize",
+    "QuantizedConv2DWithBiasAndReluAndRequantize",
+    "QuantizedConv2DWithBiasSumAndReluAndRequantize",
+    "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
+    "QuantizedDepthwiseConv2DWithBiasAndRequantize",
+    "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"
+  };
+
+  node_map.clear();
+  MapNamesToNodes(replaced_graph_def, &node_map);
+  for (auto& node_pair : node_map) {
+    const NodeDef *node = node_pair.second;
+    // An workaround to fix attributes of "Dequantize" op with non-perchannel
+    // quantization. "Dequantize" node should accept DT_QINT8 if the input node
+    // is "QuantizedConv2DAndRequantize" or
+    // "QuantizedConv2DWithBiasAndRequantize".
+    if (str_util::StartsWith(node->op(), "Dequantize")) {
+      std::string input_node_op =
+          node_map[NodeNameFromInput(node->input(0))]->op();
+      if (str_util::StartsWith(input_node_op,
+          "QuantizedConv2DAndRequantize") ||
+          str_util::StartsWith(input_node_op,
+          "QuantizedConv2DWithBiasAndRequantize")) {
+        SetNodeAttr("T", DT_QINT8, const_cast<NodeDef*>(node));
+        SetNodeAttr("mode", "SCALED", const_cast<NodeDef*>(node));
+      }
       continue;
     }
-
+    // Quantize bias to int32 if input min-max values are constants.
+    // This is guaranteed if the preceeding op is a fused requantize op.
    bool is_fused_requantized_conv_op =
-        std::find(fused_requantized_bias_ops.begin(),
-        fused_requantized_bias_ops.end(), node->op())
-        != fused_requantized_bias_ops.end();
-    if (is_fused_requantized_conv_op) {
-      // If the op is feed by Quantize op then we keep bias as float
-      std::string input_op = node_map[NodeNameFromInput(
-          node->input(0))]->op();
-      if (str_util::StartsWith(input_op, "QuantizedConv2D") &&
-          str_util::EndsWith(input_op, "AndRequantize")) {
-        NodeDef *bias_node = const_cast<NodeDef*>(node_map[NodeNameFromInput(
-            node->input(2))]);
-        const NodeDef *min_input_node = node_map[NodeNameFromInput(
+        std::find(fused_requantized_bias_ops.begin(),
+        fused_requantized_bias_ops.end(), node->op())
+        != fused_requantized_bias_ops.end();
+    if (is_fused_requantized_conv_op) {
+      std::string preceeding_op = node_map[NodeNameFromInput(
+          node->input(0))]->op();
+      if (str_util::StartsWith(preceeding_op, "Quantized") &&
+          str_util::StrContains(preceeding_op, "Conv2D") &&
+          str_util::EndsWith(preceeding_op, "AndRequantize")) {
+        NodeDef *bias_node = const_cast<NodeDef*>(node_map[NodeNameFromInput(
+            node->input(2))]);
+        const NodeDef *min_input_node = node_map[NodeNameFromInput(
            node_map[node->input(0)]->input(7))];
-        const NodeDef *max_input_node = node_map[NodeNameFromInput(
+        const NodeDef *max_input_node = node_map[NodeNameFromInput(
            node_map[node->input(0)]->input(8))];
-        const NodeDef *min_filter_node = node_map[NodeNameFromInput(
+        const NodeDef *min_filter_node = node_map[NodeNameFromInput(
            node->input(5))];
-        const NodeDef *max_filter_node = node_map[NodeNameFromInput(
+        const NodeDef *max_filter_node = node_map[NodeNameFromInput(
            node->input(6))];
-        const float min_input =
+        const float min_input =
            GetNodeTensorAttr(*min_input_node, "value").flat<float>()(0);
-        const float max_input =
+        const float max_input =
            GetNodeTensorAttr(*max_input_node, "value").flat<float>()(0);
-        const float min_filter =
-            GetNodeTensorAttr(*min_filter_node, "value").flat<float>()(0);
-        const float max_filter =
-            GetNodeTensorAttr(*max_filter_node, "value").flat<float>()(0);
-
-        TensorProto float_tensor_proto =
-            bias_node->attr().at("value").tensor();
-        Tensor float_tensor;
-        CHECK(float_tensor.FromProto(float_tensor_proto));
-        CHECK_EQ(float_tensor.dtype(), DT_FLOAT);
-        float *p_bias_float = float_tensor.flat<float>().data();
-
-        Tensor int32_tensor = Tensor(DT_QINT32, float_tensor.shape());
-        qint32 *p_bias_int32 = int32_tensor.flat<qint32>().data();
-
-        float bias_scale = 255.0 * 127.0 /
+        const Tensor& min_filter_tensor =
+            GetNodeTensorAttr(*min_filter_node, "value");
+        const Tensor& max_filter_tensor =
+            GetNodeTensorAttr(*max_filter_node, "value");
+        const float* min_filter = min_filter_tensor.flat<float>().data();
+        const float* max_filter = max_filter_tensor.flat<float>().data();
+        size_t num_scale_factors = min_filter_tensor.NumElements();
+
+        TensorProto float_tensor_proto =
+            bias_node->attr().at("value").tensor();
+        Tensor float_bias_tensor;
+        CHECK(float_bias_tensor.FromProto(float_tensor_proto));
+        CHECK_EQ(float_bias_tensor.dtype(), DT_FLOAT);
+        float *float_bias = float_bias_tensor.flat<float>().data();
+
+        Tensor int32_bias_tensor = Tensor(DT_QINT32, float_bias_tensor.shape());
+        qint32 *int32_bias = int32_bias_tensor.flat<qint32>().data();
+        std::vector<float> scales(num_scale_factors);
+        for (size_t i = 0; i < num_scale_factors; ++i) {
+          scales[i] = 255.0 * 127.0 /
            (std::max(std::abs(max_input), std::abs(min_input)) *
-            std::max(std::abs(max_filter), std::abs(min_filter)));
-        int64 nelems = float_tensor.NumElements();
-        for (int64 n = 0; n < nelems; n++)
-          p_bias_int32[n] = (int32_t) (p_bias_float[n] * bias_scale);
-
-        bias_node->clear_attr();
-        AttrValue attr_type;
-        attr_type.set_type(int32_tensor.dtype());
-        bias_node->mutable_attr()->insert({"dtype", attr_type});
-        AttrValue attr_tensor;
-        TensorProto* t = attr_tensor.mutable_tensor();
-        int32_tensor.AsProtoTensorContent(t);
-        bias_node->mutable_attr()->insert({"value", attr_tensor});
-        SetNodeAttr("Tbias", DT_QINT32, const_cast<NodeDef*>(node));
+            std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
+        }
+        int64 bias_length = float_bias_tensor.NumElements();
+        if (num_scale_factors > 1) {
+          if (bias_length != num_scale_factors) {
+            return Status(error::Code::FAILED_PRECONDITION,
+                "Number of filter output channels is not"
+                "equal to bias size");
+          } else {
+            for (int64 i = 0; i < bias_length; i++)
+              int32_bias[i] = (int32_t) (float_bias[i] * scales[i]);
+          }
       } else {
-        SetNodeAttr("Tbias", DT_FLOAT, const_cast<NodeDef*>(node));
+          for (int64 i = 0; i < bias_length; i++)
+            int32_bias[i] = (int32_t) (float_bias[i] * scales[0]);
      }
+        bias_node->clear_attr();
+        AttrValue attr_type;
+        attr_type.set_type(int32_bias_tensor.dtype());
+        bias_node->mutable_attr()->insert({"dtype", attr_type});
+        AttrValue attr_tensor;
+        TensorProto* t = attr_tensor.mutable_tensor();
+        int32_bias_tensor.AsProtoTensorContent(t);
+        bias_node->mutable_attr()->insert({"value", attr_tensor});
+        SetNodeAttr("Tbias", DT_QINT32, const_cast<NodeDef*>(node));
+      } else {
+        SetNodeAttr("Tbias", DT_FLOAT, const_cast<NodeDef*>(node));
      }
    }
  }
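
The new bias handling computes one scale per filter output channel, scale_i = 255 * 127 / (max(|min_input|, |max_input|) * max(|min_filter_i|, |max_filter_i|)), multiplies the float bias into qint32, and falls back to a single shared scale when the filter min/max tensors hold only one element; with multiple scales the hunk also checks that the bias length matches the channel count. A self-contained worked example of that arithmetic with made-up ranges (all numbers below are illustrative, not taken from the commit):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical quantization ranges (in the pass these come from the
  // const min/max input and filter nodes).
  const float min_input = 0.0f, max_input = 5.0f;          // quint8 input range
  const std::vector<float> min_filter = {-0.5f, -0.25f};   // per-channel filter mins
  const std::vector<float> max_filter = {0.5f, 0.25f};     // per-channel filter maxs
  const std::vector<float> float_bias = {0.1f, -0.02f};    // float bias, one per channel

  const float input_range = std::max(std::abs(min_input), std::abs(max_input));
  std::vector<int32_t> int32_bias(float_bias.size());
  for (size_t i = 0; i < float_bias.size(); ++i) {
    const float filter_range =
        std::max(std::abs(min_filter[i]), std::abs(max_filter[i]));
    // Same formula as the hunk: 255*127 / (input_range * filter_range).
    const float scale = 255.0f * 127.0f / (input_range * filter_range);
    int32_bias[i] = static_cast<int32_t>(float_bias[i] * scale);
    std::cout << "channel " << i << ": scale=" << scale
              << " int32_bias=" << int32_bias[i] << "\n";
  }
  // channel 0: scale = 32385 / (5.0 * 0.5)  = 12954 -> bias  0.10 ->  1295
  // channel 1: scale = 32385 / (5.0 * 0.25) = 25908 -> bias -0.02 ->  -518
  return 0;
}
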
