unify activations and tests (#551)
* clean up activation/test
* test general properties for activations
WindQAQ authored and seanpmorgan committed Oct 3, 2019
1 parent 8c94e2f commit 9e90311
Showing 17 changed files with 113 additions and 154 deletions.
13 changes: 13 additions & 0 deletions tensorflow_addons/activations/BUILD
@@ -19,6 +19,19 @@ py_library(
    srcs_version = "PY2AND3",
)

py_test(
    name = "activations_test",
    size = "small",
    srcs = [
        "activations_test.py",
    ],
    main = "activations_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":activations",
    ],
)

py_test(
    name = "sparsemax_test",
    size = "small",
1 change: 1 addition & 0 deletions tensorflow_addons/activations/README.md
@@ -35,6 +35,7 @@ must:
or `run_all_in_graph_and_eager_modes` (for TestCase subclass)
decorator.
* Add a `py_test` to this sub-package's BUILD file.
* Add activation name to [activations_test.py](https://github.com/tensorflow/addons/tree/master/tensorflow_addons/activations/activations_test.py) to test serialization.

#### Documentation Requirements
* Update the table of contents in this sub-package's README.
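
Roughly, the checklist above corresponds to a pattern like the following sketch; `softshrink` and its body are hypothetical placeholders (not part of this commit), and `keras_utils` is the helper the existing activations already use:

```python
import tensorflow as tf
from tensorflow_addons.utils import keras_utils  # helper already used by the existing activations


@keras_utils.register_keras_custom_object  # makes the name resolvable via tf.keras.activations.get
@tf.function
def softshrink(x, lower=-0.5, upper=0.5):
    """Hypothetical example activation: shifts x toward zero outside [lower, upper], zero inside."""
    x = tf.convert_to_tensor(x)
    return tf.where(x < lower, x - lower,
                    tf.where(x > upper, x - upper, tf.zeros_like(x)))


# Then: add a `py_test` for its test file to this BUILD file (as done above for
# activations_test.py) and append "softshrink" to ALL_ACTIVATIONS in
# activations_test.py so the shared serialization tests cover it.
```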
49 changes: 49 additions & 0 deletions tensorflow_addons/activations/activations_test.py
@@ -0,0 +1,49 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons import activations
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class ActivationsTest(tf.test.TestCase):

    ALL_ACTIVATIONS = [
        "gelu", "hardshrink", "lisht", "sparsemax", "tanhshrink"
    ]

    def test_serialization(self):
        for name in self.ALL_ACTIVATIONS:
            fn = tf.keras.activations.get(name)
            ref_fn = getattr(activations, name)
            self.assertEqual(fn, ref_fn)
            config = tf.keras.activations.serialize(fn)
            fn = tf.keras.activations.deserialize(config)
            self.assertEqual(fn, ref_fn)

    def test_serialization_with_layers(self):
        for name in self.ALL_ACTIVATIONS:
            layer = tf.keras.layers.Dense(
                3, activation=getattr(activations, name))
            config = tf.keras.layers.serialize(layer)
            deserialized_layer = tf.keras.layers.deserialize(config)
            self.assertEqual(deserialized_layer.__class__.__name__,
                             layer.__class__.__name__)
            self.assertEqual(deserialized_layer.activation.__name__, name)
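
A note on what these shared tests exercise: because each addon activation is registered as a Keras custom object, its string name resolves to the addon function and layers built with it survive a (de)serialization round trip. A minimal usage sketch (assumes `tensorflow` and `tensorflow_addons` are installed):

```python
import tensorflow as tf
from tensorflow_addons import activations  # importing registers each activation as a Keras custom object

# The registered string name resolves to the addon function itself...
assert tf.keras.activations.get("lisht") == activations.lisht

# ...and layers built with an addon activation survive a serialize/deserialize round trip.
layer = tf.keras.layers.Dense(3, activation=activations.lisht)
restored = tf.keras.layers.deserialize(tf.keras.layers.serialize(layer))
assert restored.activation.__name__ == "lisht"
```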
55 changes: 8 additions & 47 deletions tensorflow_addons/activations/gelu_test.py
@@ -19,58 +19,33 @@

from absl.testing import parameterized

import math

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import gelu
from tensorflow_addons.utils import test_utils


def _ref_gelu(x, approximate=True):
x = tf.convert_to_tensor(x)
if approximate:
pi = tf.cast(math.pi, x.dtype)
coeff = tf.cast(0.044715, x.dtype)
return 0.5 * x * (
1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
else:
return 0.5 * x * (
1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))


@test_utils.run_all_in_graph_and_eager_modes
class GeluTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_gelu(self, dtype):
x = np.random.rand(2, 3, 4).astype(dtype)
self.assertAllCloseAccordingToType(gelu(x), _ref_gelu(x))
self.assertAllCloseAccordingToType(gelu(x, False), _ref_gelu(x, False))
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant(
[-0.04540229, -0.158808, 0.0, 0.841192, 1.9545977], dtype=dtype)
self.assertAllCloseAccordingToType(gelu(x), expected_result)

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_gradients(self, dtype):
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)

for approximate in [True, False]:
with self.subTest(approximate=approximate):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y_ref = _ref_gelu(x, approximate)
y = gelu(x, approximate)
grad_ref = tape.gradient(y_ref, x)
grad = tape.gradient(y, x)
self.assertAllCloseAccordingToType(grad, grad_ref)
expected_result = tf.constant(
[-0.04550028, -0.15865526, 0.0, 0.8413447, 1.9544997], dtype=dtype)
self.assertAllCloseAccordingToType(gelu(x, False), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)

for approximate in [True, False]:
with self.subTest(approximate=approximate):
@@ -87,20 +62,6 @@ def test_unknown_shape(self):
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), gelu(x))

def test_serialization(self):
ref_fn = gelu
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)

def test_serialization_with_layers(self):
layer = tf.keras.layers.Dense(3, activation=gelu)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "gelu")


if __name__ == "__main__":
tf.test.main()
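
The hard-coded expectations above replace the old in-test `_ref_gelu`; they can be reproduced from the standard GELU formulas with nothing but the Python standard library (a quick sketch; the two variants correspond to `approximate=True` and `approximate=False`):

```python
import math


def gelu_tanh(x):
    # approximate=True: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_exact(x):
    # approximate=False: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for v in [-2.0, -1.0, 0.0, 1.0, 2.0]:
    print(v, gelu_tanh(v), gelu_exact(v))
# gelu_tanh reproduces the default (approximate=True) expectations:
#   [-0.04540229, -0.158808, 0.0, 0.841192, 1.9545977]
# gelu_exact reproduces the approximate=False expectations:
#   [-0.04550028, -0.15865526, 0.0, 0.8413447, 1.9544997]
```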
50 changes: 11 additions & 39 deletions tensorflow_addons/activations/hardshrink_test.py
@@ -25,11 +25,6 @@
from tensorflow_addons.utils import test_utils


def _ref_hardshrink(x, lower=-1.0, upper=1.0):
x = tf.convert_to_tensor(x)
return tf.where(tf.math.logical_or(x < lower, x > upper), x, 0.0)


@test_utils.run_all_in_graph_and_eager_modes
class HardshrinkTest(tf.test.TestCase, parameterized.TestCase):
def test_invalid(self):
@@ -42,34 +37,25 @@ def test_invalid(self):
("float32", np.float32),
("float64", np.float64))
def test_hardshrink(self, dtype):
x = (np.random.rand(2, 3, 4) * 2.0 - 1.0).astype(dtype)
self.assertAllCloseAccordingToType(hardshrink(x), _ref_hardshrink(x))
self.assertAllCloseAccordingToType(
hardshrink(x, -2.0, 2.0), _ref_hardshrink(x, -2.0, 2.0))
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype)
self.assertAllCloseAccordingToType(hardshrink(x), expected_result)

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_gradients(self, dtype):
x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)

with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y_ref = _ref_hardshrink(x)
y = hardshrink(x)
grad_ref = tape.gradient(y_ref, x)
grad = tape.gradient(y, x)
self.assertAllCloseAccordingToType(grad, grad_ref)
expected_result = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
self.assertAllCloseAccordingToType(
hardshrink(x, lower=-0.5, upper=0.5), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(
lambda x: hardshrink(x), [x])
# Hardshrink is not continuous at `lower` and `upper`.
# Avoid these two points to make gradients smooth.
x = tf.constant([-2.0, -1.5, 0.0, 1.5, 2.0], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

def test_unknown_shape(self):
@@ -80,20 +66,6 @@ def test_unknown_shape(self):
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), hardshrink(x))

def test_serialization(self):
ref_fn = hardshrink
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)

def test_serialization_with_layers(self):
layer = tf.keras.layers.Dense(3, activation=hardshrink)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "hardshrink")


if __name__ == "__main__":
tf.test.main()
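
Likewise for hardshrink: the new constants follow directly from the definition (pass `x` through outside `[lower, upper]`, zero it inside), and the theoretical-gradient test now samples points away from `lower`/`upper`, where the function is discontinuous. A small sketch mirroring the removed `_ref_hardshrink`:

```python
import tensorflow as tf


def hardshrink_ref(x, lower=-1.0, upper=1.0):
    # Pass x through where it lies outside [lower, upper]; zero it inside.
    x = tf.convert_to_tensor(x)
    return tf.where(tf.math.logical_or(x < lower, x > upper), x, tf.zeros_like(x))


print(hardshrink_ref(tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])).numpy())
# -> [-2.  0.  0.  0.  2.], the default-parameter expectation in test_hardshrink above
```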
13 changes: 0 additions & 13 deletions tensorflow_addons/activations/lisht_test.py
@@ -55,19 +55,6 @@ def test_unknown_shape(self):
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), lisht(x))

def test_serialization(self):
config = tf.keras.activations.serialize(lisht)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, lisht)

def test_serialization_with_layers(self):
layer = tf.keras.layers.Dense(3, activation=lisht)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "lisht")


if __name__ == "__main__":
tf.test.main()
12 changes: 5 additions & 7 deletions tensorflow_addons/activations/sparsemax.py
@@ -24,7 +24,7 @@

@keras_utils.register_keras_custom_object
@tf.function
def sparsemax(logits, axis=-1, name=None):
def sparsemax(logits, axis=-1):
"""Sparsemax activation function [1].
For each batch `i` and class `j` we have
@@ -35,7 +35,6 @@ def sparsemax(logits, axis=-1, name=None):
Args:
logits: Input tensor.
axis: Integer, axis along which the sparsemax operation is applied.
name: A name for the operation (optional).
Returns:
Tensor, output of sparsemax transformation. Has the same type and
shape as `logits`.
@@ -50,7 +49,7 @@ def sparsemax(logits, axis=-1, name=None):
is_last_axis = (axis == -1) or (axis == rank - 1)

if is_last_axis:
output = _compute_2d_sparsemax(logits, name=name)
output = _compute_2d_sparsemax(logits)
output.set_shape(shape)
return output

@@ -64,8 +63,7 @@ def sparsemax(logits, axis=-1, name=None):

# Do the actual softmax on its last dimension.
output = _compute_2d_sparsemax(logits)
output = _swap_axis(
output, axis_norm, tf.math.subtract(rank_op, 1), name=name)
output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1))

# Make shape inference work since transpose may erase its static shape.
output.set_shape(shape)
@@ -82,7 +80,7 @@ def _swap_axis(logits, dim_index, last_index, **kwargs):


@tf.function
def _compute_2d_sparsemax(logits, name=None):
def _compute_2d_sparsemax(logits):
"""Performs the sparsemax operation when axis=-1."""
shape_op = tf.shape(logits)
obs = tf.math.reduce_prod(shape_op[:-1])
@@ -134,5 +132,5 @@ def _compute_2d_sparsemax(logits, name=None):
logits.dtype)), p)

# Reshape back to original size
p_safe = tf.reshape(p_safe, shape_op, name=name)
p_safe = tf.reshape(p_safe, shape_op)
return p_safe
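
The only API change in this file is the dropped `name` argument; behavior is otherwise unchanged. A minimal usage sketch of the updated signature (assumes `tensorflow_addons` is installed; the input values are purely illustrative):

```python
import tensorflow as tf
from tensorflow_addons.activations import sparsemax

logits = tf.constant([[0.1, 1.5, 3.0],
                      [1.0, 1.0, 1.0]])
probs = sparsemax(logits)  # axis defaults to -1; `name=` is no longer accepted
print(probs.numpy())       # rows sum to 1; unlike softmax, some entries can be exactly 0
```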
14 changes: 0 additions & 14 deletions tensorflow_addons/activations/sparsemax_test.py
@@ -274,20 +274,6 @@ def test_gradient_against_estimate(self, dtype=None):
lambda logits: sparsemax(logits), [z], delta=1e-6)
self.assertAllCloseAccordingToType(jacob_sym, jacob_num)

def test_serialization(self, dtype=None):
ref_fn = sparsemax
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)

def test_serialization_with_layers(self, dtype=None):
layer = tf.keras.layers.Dense(3, activation=sparsemax)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "sparsemax")


if __name__ == '__main__':
tf.test.main()
36 changes: 13 additions & 23 deletions tensorflow_addons/activations/tanhshrink_test.py
@@ -25,37 +25,27 @@
from tensorflow_addons.utils import test_utils


def _ref_tanhshrink(x):
return x - tf.tanh(x)


@test_utils.run_all_in_graph_and_eager_modes
class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_tanhshrink(self, dtype):
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
self.assertAllCloseAccordingToType(tanhshrink(x), _ref_tanhshrink(x))
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant(
[-1.0359724, -0.23840582, 0.0, 0.23840582, 1.0359724], dtype=dtype)

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
self.assertAllCloseAccordingToType(tanhshrink(x), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_gradients(self, dtype):
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y_ref = _ref_tanhshrink(x)
y = tanhshrink(x)
grad_ref = tape.gradient(y_ref, x)
grad = tape.gradient(y, x)
self.assertAllCloseAccordingToType(grad, grad_ref)

def test_serialization(self):
ref_fn = tanhshrink
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(tanhshrink, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)


if __name__ == "__main__":
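
Same pattern as the other activation tests: the `x - tanh(x)` reference moved out of the file, the expected values are hard-coded, and an explicit theoretical-gradient check replaces the tape-based comparison. The expectations follow from the definition (a quick standard-library sketch):

```python
import math


def tanhshrink_ref(x):
    # tanhshrink(x) = x - tanh(x), mirroring the reference removed from this test
    return x - math.tanh(x)


print([round(tanhshrink_ref(v), 7) for v in [-2.0, -1.0, 0.0, 1.0, 2.0]])
# -> [-1.0359724, -0.2384058, 0.0, 0.2384058, 1.0359724]
```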
@@ -75,5 +75,5 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_GPU_KERNELS);

#endif // GOOGLE_CUDA

} // end namespace addons
} // namespace tensorflow
} // namespace addons
} // namespace tensorflow