Skip to content

Commit

Permalink
Fix bug in "xy" mode for bilinear interpolation (#845)
Browse files Browse the repository at this point in the history
* Fix bug in "xy" mode for bilinear interpolation
* fixed error in interp implementation, and added test for non-square images.
* removed erroneous non-square test, and modified small_grid tests to be non-square.
  • Loading branch information
djl11 authored and seanpmorgan committed Jan 9, 2020
1 parent 06d686e commit 2b84e03
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
4 changes: 2 additions & 2 deletions tensorflow_addons/image/dense_image_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None):
index_order = [0, 1] if indexing == "ij" else [1, 0]
unstacked_query_points = tf.unstack(query_points, axis=2, num=2)

for dim in index_order:
for i, dim in enumerate(index_order):
with tf.name_scope("dim-" + str(dim)):
queries = unstacked_query_points[dim]

size_in_indexing_dimension = grid_shape[dim + 1]
size_in_indexing_dimension = grid_shape[i + 1]

# max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
# is still a valid index into the grid.
Expand Down
24 changes: 15 additions & 9 deletions tensorflow_addons/image/dense_image_warp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,28 @@
@test_utils.run_all_in_graph_and_eager_modes
class InterpolateBilinearTest(tf.test.TestCase):
def test_interpolate_small_grid_ij(self):
grid = tf.constant([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]],
shape=[1, 3, 3, 1])
query_points = tf.constant([[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]],
shape=[1, 4, 2])
expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
grid = tf.constant(
[[0., 1., 2.], [3., 4., 5.], [6., 7., 8.], [9., 10., 11.]],
shape=[1, 4, 3, 1])
query_points = tf.constant(
[[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5], [3., 2.]],
shape=[1, 5, 2])
expected_results = np.reshape(
np.array([0., 3., 6.5, 6., 11.]), [1, 5, 1])

interp = interpolate_bilinear(grid, query_points)

self.assertAllClose(expected_results, interp)

def test_interpolate_small_grid_xy(self):
grid = tf.constant([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]],
shape=[1, 3, 3, 1])
grid = tf.constant(
[[0., 1., 2.], [3., 4., 5.], [6., 7., 8.], [9., 10., 11.]],
shape=[1, 4, 3, 1])
query_points = tf.constant(
[[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2])
expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
[[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5], [2., 3.]],
shape=[1, 5, 2])
expected_results = np.reshape(
np.array([0., 3., 6.5, 6., 11.]), [1, 5, 1])

interp = interpolate_bilinear(grid, query_points, indexing="xy")

Expand Down

0 comments on commit 2b84e03

Please sign in to comment.