Skip to content

Commit

Permalink
Add public API classes for quantization scheme.
Browse files Browse the repository at this point in the history
tfmot.quantization.keras.QuantizeScheme
tfmot.quantization.keras.QuantizeRegistry
tfmot.quantization.keras.QuantizeLayoutTransform

tfmot.quantization.keras.default_8bit.Default8BitQuantizeScheme
tfmot.quantization.keras.default_8bit.Default8BitQuantizeRegistry
tfmot.quantization.keras.default_8bit.Default8BitQuantizeLayoutTransform

PiperOrigin-RevId: 343009136
  • Loading branch information
Xhark authored and tensorflower-gardener committed Nov 18, 2020
1 parent dae21f6 commit 3a0c22d
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 9 deletions.
1 change: 1 addition & 0 deletions tensorflow_model_optimization/python/core/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ py_library(
"clustering/keras/__init__.py",
"quantization/__init__.py",
"quantization/keras/__init__.py",
"quantization/keras/default_8bit/__init__.py",
"quantization/keras/quantizers/__init__.py",
"sparsity/__init__.py",
"sparsity/keras/__init__.py",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# submodules
from tensorflow_model_optimization.python.core.api.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.api.quantization.keras import default_8bit

# quantize all layers with default quantization implementation.
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_model
Expand All @@ -33,4 +34,9 @@
# Deserialize quantized model for Keras h5 format.
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_scope

# Quantization Scheme classes.
from tensorflow_model_optimization.python.core.quantization.keras.quantize_scheme import QuantizeScheme
from tensorflow_model_optimization.python.core.quantization.keras.quantize_layout_transform import QuantizeLayoutTransform
from tensorflow_model_optimization.python.core.quantization.keras.quantize_registry import QuantizeRegistry

# pylint: enable=g-bad-import-order
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module containing 8bit default quantization scheme."""
# pylint: disable=g-bad-import-order

# The 8bit default quantization scheme classes.
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_scheme import Default8BitQuantizeScheme
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_layout_transform import Default8BitQuantizeLayoutTransform
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_registry import Default8BitQuantizeRegistry

# pylint: enable=g-bad-import-order
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
keras = tf.keras


class QuantizeLayoutTransform(
class Default8BitQuantizeLayoutTransform(
quantize_layout_transform.QuantizeLayoutTransform):
"""Default model transformations."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def _get_rnn_cells(self, rnn_layer):
return [rnn_layer.cell]


class QuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):
class Default8BitQuantizeRegistry(
quantize_registry.QuantizeRegistry, _RNNHelper):
"""QuantizationRegistry for built-in Keras classes for default 8-bit scheme."""

# TODO(tfmot): expand layers test in quantize_functional_test.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ class QuantizeRegistryTest(

def setUp(self):
super(QuantizeRegistryTest, self).setUp()
self.quantize_registry = default_8bit_quantize_registry.QuantizeRegistry(
)
self.quantize_registry = default_8bit_quantize_registry.\
Default8BitQuantizeRegistry()

class CustomLayer(l.Layer):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
class Default8BitQuantizeScheme(quantize_scheme.QuantizeScheme):

def get_layout_transformer(self):
return default_8bit_quantize_layout_transform.QuantizeLayoutTransform()
return default_8bit_quantize_layout_transform.\
Default8BitQuantizeLayoutTransform()

def get_quantize_registry(self):
return default_8bit_quantize_registry.QuantizeRegistry()
return default_8bit_quantize_registry.Default8BitQuantizeRegistry()

Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,8 @@ def replacement(self, match_layer):
concat_layer_node = match_layer
feeding_layer_nodes = match_layer.input_layers

default_registry = default_8bit_quantize_registry.QuantizeRegistry(
)
default_registry = default_8bit_quantize_registry.\
Default8BitQuantizeRegistry()

feed_quantize_configs = []
for feed_layer_node in feeding_layer_nodes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation
QuantizeWrapper = quantize_wrapper.QuantizeWrapper
QuantizeRegistry = default_8bit_quantize_registry.QuantizeRegistry
QuantizeRegistry = default_8bit_quantize_registry.Default8BitQuantizeRegistry

keras = tf.keras
layers = tf.keras.layers
Expand Down

0 comments on commit 3a0c22d

Please sign in to comment.