Skip to content

Commit

Permalink
[JAX] add support for gather/scatter batching dims following the new …
Browse files Browse the repository at this point in the history
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
  • Loading branch information
tomnatan30 authored and Google-ML-Automation committed Sep 25, 2024
1 parent 1fe0c5d commit 1d35b27
Show file tree
Hide file tree
Showing 10 changed files with 832 additions and 236 deletions.
4 changes: 4 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ py_library(
srcs = ["_src/internal_test_util/test_harnesses.py"],
visibility = [":internal"] + jax_internal_test_harnesses_visibility,
deps = [
":ad_util",
":config",
":jax",
":test_util",
"//jax/_src/lib",
] + py_deps("numpy"),
)

Expand Down
104 changes: 74 additions & 30 deletions jax/_src/internal_test_util/test_harnesses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,18 @@ def _make_broadcast_in_dim_harness(name,
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,),
start_index_map=(0, 1)), (1, 3), True),
((2, 5), np.array([[[0], [2]], [[1], [1]]]),
lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(1,),
start_index_map=(1,), operand_batching_dims=(0,),
start_indices_batching_dims=(0,)),
(1, 1), True),
((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
lax.GatherDimensionNumbers(
offset_dims=(2,), collapsed_slice_dims=(),
start_index_map=(2,), operand_batching_dims=(0, 1),
start_indices_batching_dims=(1, 0)),
(1, 1, 3), True)
]:
dtype = np.float32
for enable_xla in ([True] if needs_xla else [True, False]):
Expand Down Expand Up @@ -1276,15 +1288,16 @@ def _make_scatter_harness(name,
update_shape=(2,),
mode=lax.GatherScatterMode.FILL_OR_DROP,
dtype=np.float32,
dimension_numbers=((), (0,), (0,)),
dimension_numbers=lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,)),
enable_and_disable_xla=False):
dimension_numbers = lax.ScatterDimensionNumbers(*dimension_numbers)
xla_options = [True, False] if enable_and_disable_xla else [True]

for enable_xla in xla_options:
define(
f_lax.__name__,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_updatewindowdims={dimension_numbers.update_window_dims}_insertedwindowdims={dimension_numbers.inserted_window_dims}_scatterdimstooperanddims={dimension_numbers.scatter_dims_to_operand_dims}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}"
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_{dimension_numbers=}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}"
.replace(" ", ""),
partial(
f_lax,
Expand Down Expand Up @@ -1328,8 +1341,19 @@ def _make_scatter_harness(name,

# Validate shapes, dimension numbers and scatter indices. All are in bounds.
for shape, scatter_indices, update_shape, dimension_numbers in [
((10,), [[0], [0], [0]], (3, 2), ((1,), (), (0,))),
((10, 5), [[0], [2], [1]], (3, 3), ((1,), (0,), (0,)))
((10,), [[0], [0], [0]], (3, 2),
lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), [[0], [2], [1]], (3, 3),
lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((2, 3, 10), [[[0], [1]], [[2], [3]], [[4], [5]]], (3, 2, 3),
lax.ScatterDimensionNumbers(
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]:
_make_scatter_harness(
"shapes_and_dimension_numbers",
Expand Down Expand Up @@ -1358,36 +1382,51 @@ def _make_scatter_harness(name,
_make_scatter_harness("modes_in_bounds",
f_lax=f_lax,
mode=mode)
_make_scatter_harness("modes_out_of_bounds", mode=mode,
shape=(1, 5),
f_lax=f_lax,
scatter_indices=np.array([10]),
update_shape=(1,),
dimension_numbers=((0,), (1,), (1,)),
enable_and_disable_xla=True)
_make_scatter_harness(
"modes_out_of_bounds",
mode=mode,
shape=(1, 5),
f_lax=f_lax,
scatter_indices=np.array([10]),
update_shape=(1,),
dimension_numbers=lax.ScatterDimensionNumbers((0,), (1,), (1,)),
enable_and_disable_xla=True,
)

# Validate no XLA scatters
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex) - set(jtu.dtypes.boolean):
for f_lax in [
lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min, lax.scatter
]:
for shape, scatter_indices, update_shape, dimension_numbers in [
((1,), [0], (), ((), (0,), (0,))), # zero case
((1, 1), [0], (1,), ((0,), (0,), (0,))),
((1, 1, 1), [0], (1, 1), ((0, 1), (0,), (0,))),
((1, 50, 3), [32], (1, 3), ((0, 1), (1,), (1,))),
((1, 2, 3), [1], (1, 3), ((0, 1), (1,), (1,))), # slice 2nd dim
((1, 2, 3), [0], (2, 3), ((0, 1), (0,), (0,))), # slice 1st dim
((1, 2, 3), [1, 2], (1,), ((0,), (1, 2), (1, 2))), # 2nd and 3rd
((4, 2, 3), [3, 2], (2,), ((0,), (0, 2), (0, 2))), # 1st and 3rd
((4, 2, 3, 5), [0, 4], (4, 3), ((0, 1), (1, 3), (1, 3))), # 2nd and 4th
((1,), [0], (),
lax.ScatterDimensionNumbers((), (0,), (0,))), # zero case
((1, 1), [0], (1,),
lax.ScatterDimensionNumbers((0,), (0,), (0,))),
((1, 1, 1), [0], (1, 1),
lax.ScatterDimensionNumbers((0, 1), (0,), (0,))),
((1, 50, 3), [32], (1, 3),
lax.ScatterDimensionNumbers((0, 1), (1,), (1,))),
((1, 2, 3), [1], (1, 3),
lax.ScatterDimensionNumbers((0, 1), (1,), (1,))), # slice 2nd dim
((1, 2, 3), [0], (2, 3),
lax.ScatterDimensionNumbers((0, 1), (0,), (0,))), # slice 1st dim
((1, 2, 3), [1, 2], (1,),
lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))), # 2nd and 3rd
((4, 2, 3), [3, 2], (2,),
lax.ScatterDimensionNumbers((0,), (0, 2), (0, 2))), # 1st and 3rd
((4, 2, 3, 5), [0, 4], (4, 3),
lax.ScatterDimensionNumbers((0, 1), (1, 3), (1, 3))), # 2nd and 4th
((5, 6, 7), [[0, 1], [2, 3]], (2, 7),
((1,), (0, 1), (0, 1))), # .at[((3,4),(5,5))] shapes
lax.ScatterDimensionNumbers((1,), (0, 1), (0, 1))),
# .at[((3,4),(5,5))] shapes
((5, 6, 7), [[[0], [1]], [[2], [3]]], (5, 2, 2, 7),
((0, 3), (1,), (1,))), # .at[:, ((3,4),(5,5))] shapes
lax.ScatterDimensionNumbers((0, 3), (1,), (1,))),
# .at[:, ((3,4),(5,5))] shapes
((5, 6, 7), [[[0, 1], [2, 3]], [[4, 0], [1, 2]]], (5, 2, 2),
((0,), (1, 2), (1, 2))), # .at[:, ((3,4),(5,5)), 3] shapes
((1, 125), [0], (1,), ((0,), (1,), (1,))),
lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))),
# .at[:, ((3,4),(5,5)), 3] shapes
((1, 125), [0], (1,), lax.ScatterDimensionNumbers((0,), (1,), (1,))),
]:
for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS,
lax.GatherScatterMode.FILL_OR_DROP):
Expand All @@ -1410,11 +1449,16 @@ def _make_scatter_harness(name,
lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min
]:
for shape, scatter_indices, update_shape, dimension_numbers in [
((1,), [[0],[0]], (2,), ((), (0,), (0,))), # .at[((0,0),)]
((3,), [[1],[0],[1]], (3,), ((), (0,), (0,))), # .at[((1,0,1),)]
((2, 3), [[[2],[2],[2]]], (2, 1, 3), ((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)]
((3, 5, 40), [[1],[1]], (3, 5, 2), ((0, 1), (2,), (2,))),
((3, 5, 4), [[1],[1]], (3, 2, 4), ((0, 2), (1,), (1,))),
((1,), [[0],[0]], (2,),
lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((0,0),)]
((3,), [[1],[0],[1]], (3,),
lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((1,0,1),)]
((2, 3), [[[2],[2],[2]]], (2, 1, 3),
lax.ScatterDimensionNumbers((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)]
((3, 5, 40), [[1],[1]], (3, 5, 2),
lax.ScatterDimensionNumbers((0, 1), (2,), (2,))),
((3, 5, 4), [[1],[1]], (3, 2, 4),
lax.ScatterDimensionNumbers((0, 2), (1,), (1,))),
]:
for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS,
lax.GatherScatterMode.FILL_OR_DROP):
Expand Down
17 changes: 7 additions & 10 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4654,18 +4654,15 @@ def _top_k_jvp(primals, tangents, *, k):
idx_shape = k_idxs.shape
rank = len(idx_shape)
gather_index_shape = idx_shape + (1,)
gather_indices = []
for i in range(rank-1):
_iota = iota(k_idxs.dtype, idx_shape[i])
_iota = broadcast_in_dim(_iota, gather_index_shape, (i,))
gather_indices.append(_iota)
gather_indices.append(reshape(k_idxs, gather_index_shape))
gather_indices = concatenate(gather_indices, dimension=rank)
gather_indices = reshape(k_idxs, gather_index_shape)
slice_sizes = (1,) * rank
dnums = slicing.GatherDimensionNumbers(
offset_dims=(),
collapsed_slice_dims=tuple(range(rank)),
start_index_map=tuple(range(rank)))
offset_dims=(),
collapsed_slice_dims=(rank - 1,),
operand_batching_dims=tuple(range(rank - 1)),
start_indices_batching_dims=tuple(range(rank - 1)),
start_index_map=(rank - 1,),
)
tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes)
return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1]))

Expand Down
3 changes: 2 additions & 1 deletion jax/_src/lax/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,8 @@ def _pgather_impl(src, idx, *, axes):
dnums = slicing.GatherDimensionNumbers(
offset_dims=offset_dims,
collapsed_slice_dims=(0,),
start_index_map=(0,))
start_index_map=(0,),
)
return slicing.gather(src_one_axis_front, idx, dimension_numbers=dnums,
slice_sizes=tuple(slice_sizes))

Expand Down
Loading

0 comments on commit 1d35b27

Please sign in to comment.