From 3a0c22d778c6d7c06e81ab1771957deba7a3d52a Mon Sep 17 00:00:00 2001 From: Jaehong Kim Date: Tue, 17 Nov 2020 21:37:20 -0800 Subject: [PATCH] Add public API classes for quantization scheme. tfmot.quantization.keras.QuantizeScheme tfmot.quantization.keras.QuantizeRegistry tfmot.quantization.keras.QuantizeLayoutTransform tfmot.quantization.keras.default_8bit.Default8BitQuantizeScheme tfmot.quantization.keras.default_8bit.Default8BitQuantizeRegistry tfmot.quantization.keras.default_8bit.Default8BitQuantizeLayoutTransform PiperOrigin-RevId: 343009136 --- .../python/core/api/BUILD | 1 + .../core/api/quantization/keras/__init__.py | 6 +++++ .../keras/default_8bit/__init__.py | 23 +++++++++++++++++++ .../default_8bit_quantize_layout_transform.py | 2 +- .../default_8bit_quantize_registry.py | 3 ++- .../default_8bit_quantize_registry_test.py | 4 ++-- .../default_8bit_quantize_scheme.py | 5 ++-- .../default_8bit/default_8bit_transforms.py | 4 ++-- .../keras/quantize_wrapper_test.py | 2 +- 9 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 tensorflow_model_optimization/python/core/api/quantization/keras/default_8bit/__init__.py diff --git a/tensorflow_model_optimization/python/core/api/BUILD b/tensorflow_model_optimization/python/core/api/BUILD index b4e5ad56e..e37f81474 100644 --- a/tensorflow_model_optimization/python/core/api/BUILD +++ b/tensorflow_model_optimization/python/core/api/BUILD @@ -10,6 +10,7 @@ py_library( "clustering/keras/__init__.py", "quantization/__init__.py", "quantization/keras/__init__.py", + "quantization/keras/default_8bit/__init__.py", "quantization/keras/quantizers/__init__.py", "sparsity/__init__.py", "sparsity/keras/__init__.py", diff --git a/tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py b/tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py index 1864e70d2..964e99946 100644 --- a/tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py +++ b/tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py @@ -17,6 +17,7 @@ # submodules from tensorflow_model_optimization.python.core.api.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.api.quantization.keras import default_8bit # quantize all layers with default quantization implementation. from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_model @@ -33,4 +34,9 @@ # Deserialize quantized model for Keras h5 format. from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_scope +# Quantization Scheme classes. +from tensorflow_model_optimization.python.core.quantization.keras.quantize_scheme import QuantizeScheme +from tensorflow_model_optimization.python.core.quantization.keras.quantize_layout_transform import QuantizeLayoutTransform +from tensorflow_model_optimization.python.core.quantization.keras.quantize_registry import QuantizeRegistry + # pylint: enable=g-bad-import-order diff --git a/tensorflow_model_optimization/python/core/api/quantization/keras/default_8bit/__init__.py b/tensorflow_model_optimization/python/core/api/quantization/keras/default_8bit/__init__.py new file mode 100644 index 000000000..b8b2a45d2 --- /dev/null +++ b/tensorflow_model_optimization/python/core/api/quantization/keras/default_8bit/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module containing 8bit default quantization scheme.""" +# pylint: disable=g-bad-import-order + +# The 8bit default quantization scheme classes. +from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_scheme import Default8BitQuantizeScheme +from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_layout_transform import Default8BitQuantizeLayoutTransform +from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_registry import Default8BitQuantizeRegistry + +# pylint: enable=g-bad-import-order diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py index 2921c6621..0aa3b2ff7 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py @@ -27,7 +27,7 @@ keras = tf.keras -class QuantizeLayoutTransform( +class Default8BitQuantizeLayoutTransform( quantize_layout_transform.QuantizeLayoutTransform): """Default model transformations.""" diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py index ccf830f97..d11733ad1 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py @@ -69,7 +69,8 @@ def _get_rnn_cells(self, rnn_layer): return [rnn_layer.cell] -class QuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper): +class Default8BitQuantizeRegistry( + quantize_registry.QuantizeRegistry, _RNNHelper): """QuantizationRegistry for built-in Keras classes for default 8-bit scheme.""" # TODO(tfmot): expand layers test in quantize_functional_test.py diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py index 803476f0c..fec0d3c70 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py @@ -79,8 +79,8 @@ class QuantizeRegistryTest( def setUp(self): super(QuantizeRegistryTest, self).setUp() - self.quantize_registry = default_8bit_quantize_registry.QuantizeRegistry( - ) + self.quantize_registry = default_8bit_quantize_registry.\ + Default8BitQuantizeRegistry() class CustomLayer(l.Layer): pass diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_scheme.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_scheme.py index 5b324f3bb..1cda6bebf 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_scheme.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_scheme.py @@ -22,8 +22,9 @@ class Default8BitQuantizeScheme(quantize_scheme.QuantizeScheme): def get_layout_transformer(self): - return default_8bit_quantize_layout_transform.QuantizeLayoutTransform() + return default_8bit_quantize_layout_transform.\ + Default8BitQuantizeLayoutTransform() def get_quantize_registry(self): - return default_8bit_quantize_registry.QuantizeRegistry() + return default_8bit_quantize_registry.Default8BitQuantizeRegistry() diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py index b0c8e9d08..2b40d3dfe 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py @@ -584,8 +584,8 @@ def replacement(self, match_layer): concat_layer_node = match_layer feeding_layer_nodes = match_layer.input_layers - default_registry = default_8bit_quantize_registry.QuantizeRegistry( - ) + default_registry = default_8bit_quantize_registry.\ + Default8BitQuantizeRegistry() feed_quantize_configs = [] for feed_layer_node in feeding_layer_nodes: diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py index 169e1a7c4..7628d9be0 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py @@ -29,7 +29,7 @@ QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation QuantizeWrapper = quantize_wrapper.QuantizeWrapper -QuantizeRegistry = default_8bit_quantize_registry.QuantizeRegistry +QuantizeRegistry = default_8bit_quantize_registry.Default8BitQuantizeRegistry keras = tf.keras layers = tf.keras.layers