diff --git a/jax/BUILD b/jax/BUILD index c25d0004e772..d49e783e61d6 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -147,7 +147,11 @@ py_library( srcs = ["_src/internal_test_util/test_harnesses.py"], visibility = [":internal"] + jax_internal_test_harnesses_visibility, deps = [ + ":ad_util", + ":config", ":jax", + ":test_util", + "//jax/_src/lib", ] + py_deps("numpy"), ) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 2c94907568d9..31c3fec94536 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -1169,6 +1169,18 @@ def _make_broadcast_in_dim_harness(name, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3), True), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), + lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(1,), + start_index_map=(1,), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (1, 1), True), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + lax.GatherDimensionNumbers( + offset_dims=(2,), collapsed_slice_dims=(), + start_index_map=(2,), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (1, 1, 3), True) ]: dtype = np.float32 for enable_xla in ([True] if needs_xla else [True, False]): @@ -1276,15 +1288,16 @@ def _make_scatter_harness(name, update_shape=(2,), mode=lax.GatherScatterMode.FILL_OR_DROP, dtype=np.float32, - dimension_numbers=((), (0,), (0,)), + dimension_numbers=lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,)), enable_and_disable_xla=False): - dimension_numbers = lax.ScatterDimensionNumbers(*dimension_numbers) xla_options = [True, False] if enable_and_disable_xla else [True] for enable_xla in xla_options: define( f_lax.__name__, - f"{name}_shape={jtu.format_shape_dtype_string(shape, 
dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_updatewindowdims={dimension_numbers.update_window_dims}_insertedwindowdims={dimension_numbers.inserted_window_dims}_scatterdimstooperanddims={dimension_numbers.scatter_dims_to_operand_dims}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}" + f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_{dimension_numbers=}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}" .replace(" ", ""), partial( f_lax, @@ -1328,8 +1341,19 @@ def _make_scatter_harness(name, # Validate shapes, dimension numbers and scatter indices. All are in bounds. for shape, scatter_indices, update_shape, dimension_numbers in [ - ((10,), [[0], [0], [0]], (3, 2), ((1,), (), (0,))), - ((10, 5), [[0], [2], [1]], (3, 3), ((1,), (0,), (0,))) + ((10,), [[0], [0], [0]], (3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0,))), + ((10, 5), [[0], [2], [1]], (3, 3), + lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,))), + ((2, 3, 10), [[[0], [1]], [[2], [3]], [[4], [5]]], (3, 2, 3), + lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]: _make_scatter_harness( "shapes_and_dimension_numbers", @@ -1358,13 +1382,16 @@ def _make_scatter_harness(name, _make_scatter_harness("modes_in_bounds", f_lax=f_lax, mode=mode) - _make_scatter_harness("modes_out_of_bounds", mode=mode, - shape=(1, 5), - f_lax=f_lax, - scatter_indices=np.array([10]), - update_shape=(1,), - dimension_numbers=((0,), (1,), (1,)), - enable_and_disable_xla=True) + _make_scatter_harness( + "modes_out_of_bounds", + mode=mode, 
+ shape=(1, 5), + f_lax=f_lax, + scatter_indices=np.array([10]), + update_shape=(1,), + dimension_numbers=lax.ScatterDimensionNumbers((0,), (1,), (1,)), + enable_and_disable_xla=True, + ) # Validate no XLA scatters for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex) - set(jtu.dtypes.boolean): @@ -1372,22 +1399,34 @@ def _make_scatter_harness(name, lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min, lax.scatter ]: for shape, scatter_indices, update_shape, dimension_numbers in [ - ((1,), [0], (), ((), (0,), (0,))), # zero case - ((1, 1), [0], (1,), ((0,), (0,), (0,))), - ((1, 1, 1), [0], (1, 1), ((0, 1), (0,), (0,))), - ((1, 50, 3), [32], (1, 3), ((0, 1), (1,), (1,))), - ((1, 2, 3), [1], (1, 3), ((0, 1), (1,), (1,))), # slice 2nd dim - ((1, 2, 3), [0], (2, 3), ((0, 1), (0,), (0,))), # slice 1st dim - ((1, 2, 3), [1, 2], (1,), ((0,), (1, 2), (1, 2))), # 2nd and 3rd - ((4, 2, 3), [3, 2], (2,), ((0,), (0, 2), (0, 2))), # 1st and 3rd - ((4, 2, 3, 5), [0, 4], (4, 3), ((0, 1), (1, 3), (1, 3))), # 2nd and 4th + ((1,), [0], (), + lax.ScatterDimensionNumbers((), (0,), (0,))), # zero case + ((1, 1), [0], (1,), + lax.ScatterDimensionNumbers((0,), (0,), (0,))), + ((1, 1, 1), [0], (1, 1), + lax.ScatterDimensionNumbers((0, 1), (0,), (0,))), + ((1, 50, 3), [32], (1, 3), + lax.ScatterDimensionNumbers((0, 1), (1,), (1,))), + ((1, 2, 3), [1], (1, 3), + lax.ScatterDimensionNumbers((0, 1), (1,), (1,))), # slice 2nd dim + ((1, 2, 3), [0], (2, 3), + lax.ScatterDimensionNumbers((0, 1), (0,), (0,))), # slice 1st dim + ((1, 2, 3), [1, 2], (1,), + lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))), # 2nd and 3rd + ((4, 2, 3), [3, 2], (2,), + lax.ScatterDimensionNumbers((0,), (0, 2), (0, 2))), # 1st and 3rd + ((4, 2, 3, 5), [0, 4], (4, 3), + lax.ScatterDimensionNumbers((0, 1), (1, 3), (1, 3))), # 2nd and 4th ((5, 6, 7), [[0, 1], [2, 3]], (2, 7), - ((1,), (0, 1), (0, 1))), # .at[((3,4),(5,5))] shapes + lax.ScatterDimensionNumbers((1,), (0, 1), (0, 1))), + # 
.at[((3,4),(5,5))] shapes ((5, 6, 7), [[[0], [1]], [[2], [3]]], (5, 2, 2, 7), - ((0, 3), (1,), (1,))), # .at[:, ((3,4),(5,5))] shapes + lax.ScatterDimensionNumbers((0, 3), (1,), (1,))), + # .at[:, ((3,4),(5,5))] shapes ((5, 6, 7), [[[0, 1], [2, 3]], [[4, 0], [1, 2]]], (5, 2, 2), - ((0,), (1, 2), (1, 2))), # .at[:, ((3,4),(5,5)), 3] shapes - ((1, 125), [0], (1,), ((0,), (1,), (1,))), + lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))), + # .at[:, ((3,4),(5,5)), 3] shapes + ((1, 125), [0], (1,), lax.ScatterDimensionNumbers((0,), (1,), (1,))), ]: for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS, lax.GatherScatterMode.FILL_OR_DROP): @@ -1410,11 +1449,16 @@ def _make_scatter_harness(name, lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min ]: for shape, scatter_indices, update_shape, dimension_numbers in [ - ((1,), [[0],[0]], (2,), ((), (0,), (0,))), # .at[((0,0),)] - ((3,), [[1],[0],[1]], (3,), ((), (0,), (0,))), # .at[((1,0,1),)] - ((2, 3), [[[2],[2],[2]]], (2, 1, 3), ((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)] - ((3, 5, 40), [[1],[1]], (3, 5, 2), ((0, 1), (2,), (2,))), - ((3, 5, 4), [[1],[1]], (3, 2, 4), ((0, 2), (1,), (1,))), + ((1,), [[0],[0]], (2,), + lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((0,0),)] + ((3,), [[1],[0],[1]], (3,), + lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((1,0,1),)] + ((2, 3), [[[2],[2],[2]]], (2, 1, 3), + lax.ScatterDimensionNumbers((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)] + ((3, 5, 40), [[1],[1]], (3, 5, 2), + lax.ScatterDimensionNumbers((0, 1), (2,), (2,))), + ((3, 5, 4), [[1],[1]], (3, 2, 4), + lax.ScatterDimensionNumbers((0, 2), (1,), (1,))), ]: for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS, lax.GatherScatterMode.FILL_OR_DROP): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 7226ea25922b..e28d0857d624 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4654,18 +4654,15 @@ def _top_k_jvp(primals, tangents, *, k): idx_shape = k_idxs.shape rank = 
len(idx_shape) gather_index_shape = idx_shape + (1,) - gather_indices = [] - for i in range(rank-1): - _iota = iota(k_idxs.dtype, idx_shape[i]) - _iota = broadcast_in_dim(_iota, gather_index_shape, (i,)) - gather_indices.append(_iota) - gather_indices.append(reshape(k_idxs, gather_index_shape)) - gather_indices = concatenate(gather_indices, dimension=rank) + gather_indices = reshape(k_idxs, gather_index_shape) slice_sizes = (1,) * rank dnums = slicing.GatherDimensionNumbers( - offset_dims=(), - collapsed_slice_dims=tuple(range(rank)), - start_index_map=tuple(range(rank))) + offset_dims=(), + collapsed_slice_dims=(rank - 1,), + operand_batching_dims=tuple(range(rank - 1)), + start_indices_batching_dims=tuple(range(rank - 1)), + start_index_map=(rank - 1,), + ) tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes) return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1])) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index c9a07072ddc7..9d4614f344fb 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1500,7 +1500,8 @@ def _pgather_impl(src, idx, *, axes): dnums = slicing.GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=(0,), - start_index_map=(0,)) + start_index_map=(0,), + ) return slicing.gather(src_one_axis_front, idx, dimension_numbers=dnums, slice_sizes=tuple(slice_sizes)) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 60dfa0e1b3d2..372ebd1a8389 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -233,6 +233,16 @@ class GatherDimensionNumbers(NamedTuple): start_index_map: for each dimension in `start_indices`, gives the corresponding dimension in the `operand` that is to be sliced. Must be a tuple of integers with size equal to `start_indices.shape[-1]`. 
+ operand_batching_dims: the set of batching dimensions `i` in `operand` that + have `slice_sizes[i] == 1` and that should have a corresponding dimension + in both the `start_indices` (at the same index in + `start_indices_batching_dims`) and output of the gather. Must be a tuple + of integers in ascending order. + start_indices_batching_dims: the set of batching dimensions `i` in + `start_indices` that should have a corresponding dimension in both the + `operand` (at the same index in `operand_batching_dims`) and output of the + gather. Must be a tuple of integers (order is fixed based on + correspondence with `operand_batching_dims`). Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the @@ -241,6 +251,8 @@ class GatherDimensionNumbers(NamedTuple): offset_dims: tuple[int, ...] collapsed_slice_dims: tuple[int, ...] start_index_map: tuple[int, ...] + operand_batching_dims: tuple[int, ...] = () + start_indices_batching_dims: tuple[int, ...] = () class GatherScatterMode(enum.Enum): @@ -370,6 +382,17 @@ class ScatterDimensionNumbers(NamedTuple): scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives the corresponding dimension in `operand`. Must be a sequence of integers with size equal to `scatter_indices.shape[-1]`. + operand_batching_dims: the set of batching dimensions `i` in `operand` that + should have a corresponding dimension in both the `scatter_indices` (at + the same index in `scatter_indices_batching_dims`) and `updates`. Must be + a tuple of integers in ascending order. These are the mirror image of + `operand_batching_dims` in the case of `gather`. + scatter_indices_batching_dims: the set of batching dimensions `i` in + `scatter_indices` that should have a corresponding dimension in both the + `operand` (at the same index in `operand_batching_dims`) and output of the + gather. 
Must be a tuple of integers (order is fixed based on + correspondence with `operand_batching_dims`). These are the mirror image of + `start_indices_batching_dims` in the case of `gather`. Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the @@ -378,6 +401,8 @@ class ScatterDimensionNumbers(NamedTuple): update_window_dims: Sequence[int] inserted_window_dims: Sequence[int] scatter_dims_to_operand_dims: Sequence[int] + operand_batching_dims: Sequence[int] = () + scatter_indices_batching_dims: Sequence[int] = () def scatter_add( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, @@ -694,7 +719,8 @@ def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array: dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=tuple(axes), - start_index_map=tuple(axes)) + start_index_map=tuple(axes), + ) return gather(src, indices, dimension_numbers=dnums, slice_sizes=tuple(slice_sizes)) @@ -1256,8 +1282,11 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): dims = tuple(range(ndims)) start_indices, dyn_slice_sizes = util.split_list(start_indices_and_dyn, [ndims]) start_idx_bds, dyn_slice_size_bds = util.split_list(start_idx_and_dyn_bds, [ndims]) - dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(), - start_index_map=dims) + dnums = GatherDimensionNumbers( + offset_dims=dims, + collapsed_slice_dims=(), + start_index_map=dims, + ) index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds) return _gather_batching_rule( [operand, index, *dyn_slice_sizes], @@ -1396,9 +1425,11 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): update_shape = (np.shape(update) if update_bd is batching.not_mapped else tuple(np.delete(np.shape(update), update_bd))) dims = tuple(range(len(update_shape))) - dnums = ScatterDimensionNumbers(update_window_dims=dims, -
inserted_window_dims=(), - scatter_dims_to_operand_dims=dims) + dnums = ScatterDimensionNumbers( + update_window_dims=dims, + inserted_window_dims=(), + scatter_dims_to_operand_dims=dims, + ) index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd) return api.vmap( partial(scatter, dimension_numbers=dnums, @@ -1437,6 +1468,12 @@ def _is_sorted(dims, op_name, name): if dims[i] < dims[i - 1]: raise TypeError(f"{name} in {op_name} op must be sorted; got {dims}") +def _dims_in_range(dims, rank, op_name, name): + for dim in dims: + if dim < 0 or dim >= rank: + raise TypeError(f"Invalid {name} set in {op_name} op; valid range is " + f"[0, {rank}); got: {dim}.") + def _sorted_dims_in_range(dims, rank, op_name, name): if len(dims) == 0: return @@ -1453,6 +1490,11 @@ def _no_duplicate_dims(dims, op_name, name): if len(set(dims)) != len(dims): raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.") +def _disjoint_dims(dims1, dims2, op_name, name1, name2): + if not set(dims1).isdisjoint(set(dims2)): + raise TypeError(f"{name1} and {name2} in {op_name} op must be disjoint; " + f"got: {dims1} and {dims2}.") + def _gather_shape_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1466,6 +1508,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + operand_batching_dims = dimension_numbers.operand_batching_dims + start_indices_batching_dims = dimension_numbers.start_indices_batching_dims start_index_map = dimension_numbers.start_index_map # Note: in JAX, index_vector_dim is always computed as below, cf. 
the @@ -1521,6 +1565,50 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims") _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") + + _no_duplicate_dims(operand_batching_dims, "gather", "operand_batching_dims") + _is_sorted(operand_batching_dims, "gather", "operand_batching_dims") + _sorted_dims_in_range( + operand_batching_dims, _rank(operand), "gather", "operand_batching_dims" + ) + + _disjoint_dims(collapsed_slice_dims, operand_batching_dims, "gather", + "collapsed_slice_dims", "operand_batching_dims") + _disjoint_dims(start_index_map, operand_batching_dims, "gather", + "start_index_map", "operand_batching_dims") + + _no_duplicate_dims( + start_indices_batching_dims, "gather", "start_indices_batching_dims" + ) + _dims_in_range( + start_indices_batching_dims, + _rank(indices), + "gather", + "start_indices_batching_dims", + ) + if index_vector_dim in start_indices_batching_dims: + raise TypeError( + "Gather op cannot have the index vector dimension as a batching " + f"dimension; got {start_indices_batching_dims}." + ) + + if len(operand_batching_dims) != len(start_indices_batching_dims): + raise TypeError( + "Gather op requires equal numbers of operand_batching_dims and " + f"start_indices_batching_dims, got {operand_batching_dims} and" + f"{start_indices_batching_dims}." + ) + + operand_batch_shape = tuple(operand.shape[i] for i in operand_batching_dims) + indices_batch_shape = tuple( + indices.shape[i] for i in start_indices_batching_dims + ) + if not core.definitely_equal_shape(operand_batch_shape, indices_batch_shape): + raise TypeError( + "Gather op requires operand batching dimensions and indices batching " + f"dimensions to have the same shape, got {operand_batch_shape} and " + f"{indices_batch_shape}." 
+ ) # End ValidateGatherDimensions if _rank(operand) != len(slice_sizes): @@ -1528,12 +1616,17 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " f"input_shape.rank={_rank(operand)}") - if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): - raise TypeError(f"All components of the offset index in a gather op must " - f"either be a offset dimension or explicitly collapsed; " - f"got len(slice_sizes)={len(slice_sizes)}, " - f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" - f"{collapsed_slice_dims}.") + if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims) + len( + operand_batching_dims + ): + raise TypeError( + "All components of the offset index in a gather op must " + "either be a offset dimension or explicitly collapsed/batching; " + f"got len(slice_sizes)={len(slice_sizes)}, " + f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" + f"{collapsed_slice_dims}, operand_batching_dims=" + f"{operand_batching_dims}." + ) for i in range(len(slice_sizes)): slice_size = slice_sizes[i] @@ -1552,12 +1645,21 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, f"but bound is {bound} for index " f"{collapsed_slice_dims[i]} at position {i}.") + for i in range(len(operand_batching_dims)): + bound = slice_sizes[operand_batching_dims[i]] + if bound > 1: + raise TypeError(f"Gather op can only have operand batching dims with " + f"bound 0/1, but bound is {bound} for index " + f"{operand_batching_dims[i]} at position {i}." 
+ ) + return _gather_shape_computation(indices, dimension_numbers, slice_sizes) def _gather_shape_computation(indices, dimension_numbers, slice_sizes): offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + operand_batching_dims = dimension_numbers.operand_batching_dims output_shape_rank = len(offset_dims) + _rank(indices) - 1 index_vector_dim = _rank(indices) - 1 @@ -1572,8 +1674,11 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes): indices_shape_gen = iter(expanded_indices_shape) - slice_sizes_gen = (s for i, s in enumerate(slice_sizes) - if i not in collapsed_slice_dims) + slice_sizes_gen = ( + s + for i, s in enumerate(slice_sizes) + if i not in collapsed_slice_dims and i not in operand_batching_dims + ) ans = tuple(next(slice_sizes_gen) if i in offset_dims else next(indices_shape_gen) for i in range(output_shape_rank)) return ans @@ -1631,9 +1736,12 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers, else: zeros = lax.full(operand_shape, lax._zero(t)) scatter_dnums = ScatterDimensionNumbers( - update_window_dims=dimension_numbers.offset_dims, - inserted_window_dims=dimension_numbers.collapsed_slice_dims, - scatter_dims_to_operand_dims=dimension_numbers.start_index_map) + update_window_dims=dimension_numbers.offset_dims, + inserted_window_dims=dimension_numbers.collapsed_slice_dims, + scatter_dims_to_operand_dims=dimension_numbers.start_index_map, + operand_batching_dims=dimension_numbers.operand_batching_dims, + scatter_indices_batching_dims=dimension_numbers.start_indices_batching_dims, + ) out = scatter_add(zeros, indices, t, scatter_dnums, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, @@ -1652,11 +1760,17 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, slice_sizes = (operand.shape[0],) + slice_sizes offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims)) collapsed_slice_dims = tuple(np.add(1, 
dimension_numbers.collapsed_slice_dims)) + operand_batching_dims = tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map) + start_index_map=start_index_map, + operand_batching_dims=operand_batching_dims, + start_indices_batching_dims=dimension_numbers.start_indices_batching_dims, + ) if isinstance(operand_bdim, batching.RaggedAxis): ragged_slice_sizes = batching.bdim_as_shape(operand_bdim, slice_sizes) for orig, fabricated in zip( @@ -1687,10 +1801,16 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, elif operand_bdim is None and indices_bdim is not None: indices = batching.moveaxis(indices, indices_bdim, 0) offset_dims = tuple(1 + d for d in dimension_numbers.offset_dims) + start_indices_batching_dims = tuple( + np.add(1, dimension_numbers.start_indices_batching_dims) + ) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=dimension_numbers.collapsed_slice_dims, - start_index_map=dimension_numbers.start_index_map) + start_index_map=dimension_numbers.start_index_map, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=start_indices_batching_dims, + ) # If batching indexed accesses into the same array, the batched gather may # no longer have sorted or unique indices. return gather(operand, indices, dimension_numbers=dnums, @@ -1702,61 +1822,34 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, operand = batching.moveaxis(operand, operand_bdim, 0) indices = batching.moveaxis(indices, indices_bdim, 0) - # This slightly awkward special case is needed because the shape rule for - # gather does not allow size-1 slices out of a size-0 dimension, even if - # the number of slices is zero. 
Likely the best fix would be to change the - # definition of gather() so it can be batched without the construction of - # an explicit iota of size-1 slices. if core.definitely_equal(operand.shape[0], 0): - output_shape = _gather_shape_rule( - core.ShapedArray(operand.shape[1:], operand.dtype), - core.ShapedArray(indices.shape[1:], - dtypes.canonicalize_dtype(indices.dtype)), - dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, - unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, - mode=mode, fill_value=fill_value) - return lax.full((0,) + output_shape, lax._zero(operand)), 0 - - # Example: user code had indices shape (3, 4, 5), and we have to deal with - # indices shape (7, 3, 4, 5). We transform that to indices of shape - # (7, 3, 4, 6) where we concatenated an iota that counts along our batch - # dimension to the front of the ndindex. - index_dtype = _promote_dtype_for_size(indices.dtype, indices.shape[0]) - count_shape = list(indices.shape) - count_shape[-1] = 1 - counts = lax.broadcasted_iota(index_dtype, tuple(count_shape), 0) - indices = lax.concatenate([counts, indices.astype(index_dtype)], - len(count_shape) - 1) - - slice_sizes = (1,) + slice_sizes - collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) + slice_sizes = (0,) + slice_sizes + else: + slice_sizes = (1,) + slice_sizes + collapsed_slice_dims = tuple( + np.add(1, dimension_numbers.collapsed_slice_dims) + ) + operand_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) + start_indices_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.start_indices_batching_dims) + ) offset_dims = tuple(np.add(1, dimension_numbers.offset_dims)) - start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map)) + start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, - 
start_index_map=start_index_map) + start_index_map=start_index_map, + operand_batching_dims=operand_batching_dims, + start_indices_batching_dims=start_indices_batching_dims, + ) return gather(operand, indices, dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), 0 -def _promote_dtype_for_size(dtype, size): - if not dtypes.issubdtype(dtype, np.integer): - return dtype - # size may be a dynamic shape, in which case we return at least int32 - try: - size = int(size) - except: - return dtype if np.iinfo(dtype).bits >= 32 else np.dtype('int32') - if size <= np.iinfo(dtype).max: - return dtype - elif size <= np.iinfo(np.int32).max: - return np.dtype('int32') - else: - return dtypes.canonicalize_dtype(np.int64) - def _gather_pad_rule(in_avals, out_avals, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1821,8 +1914,10 @@ def _gather_lower(ctx, operand, indices, *, GatherScatterMode.CLIP), mode dnums = hlo.GatherDimensionNumbers.get( collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), - operand_batching_dims=[], - start_indices_batching_dims=[], + operand_batching_dims=list(dimension_numbers.operand_batching_dims), + start_indices_batching_dims=list( + dimension_numbers.start_indices_batching_dims + ), index_vector_dim=len(ctx.avals_in[1].shape) - 1, offset_dims=list(dimension_numbers.offset_dims), start_index_map=list(dimension_numbers.start_index_map), @@ -1872,6 +1967,8 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, update_window_dims = dimension_numbers.update_window_dims inserted_window_dims = dimension_numbers.inserted_window_dims + operand_batching_dims = dimension_numbers.operand_batching_dims + scatter_indices_batching_dims = dimension_numbers.scatter_indices_batching_dims scatter_dims_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims # Note: in 
JAX, index_vector_dim is always computed as below, cf. the # documentation of the ScatterDimensionNumbers class. @@ -1909,8 +2006,55 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, _sorted_dims_in_range(inserted_window_dims, _rank(operand), "scatter", "inserted_window_dims") + # Validate operand_batching_dims and scatter_indices_batching_dims + _is_sorted(operand_batching_dims, "scatter", "operand_batching_dims") + _no_duplicate_dims(operand_batching_dims, "scatter", "operand_batching_dims") + _sorted_dims_in_range( + operand_batching_dims, _rank(operand), "scatter", "operand_batching_dims" + ) + _disjoint_dims(inserted_window_dims, operand_batching_dims, "scatter", + "inserted_window_dims", "operand_batching_dims") + _disjoint_dims(scatter_dims_to_operand_dims, operand_batching_dims, "scatter", + "scatter_dims_to_operand_dims", "operand_batching_dims") + + _no_duplicate_dims( + scatter_indices_batching_dims, "scatter", "scatter_indices_batching_dims" + ) + _dims_in_range( + scatter_indices_batching_dims, + _rank(indices), + "scatter", + "scatter_indices_batching_dims", + ) + if index_vector_dim in scatter_indices_batching_dims: + raise TypeError( + "Scatter op cannot have the index vector dimension as a batching " + f"dimension; got {scatter_indices_batching_dims}.") + + if len(operand_batching_dims) != len(scatter_indices_batching_dims): + raise TypeError( + "Scatter op requires equal numbers of operand_batching_dims and " + f"scatter_indices_batching_dims, got {operand_batching_dims} and " + f"{scatter_indices_batching_dims}." 
+ ) + + operand_batch_shape = tuple(operand.shape[i] for i in operand_batching_dims) + indices_batch_shape = tuple( + indices.shape[i] for i in scatter_indices_batching_dims + ) + if not core.definitely_equal_shape(operand_batch_shape, indices_batch_shape): + raise TypeError( + "Scatter op requires operand batching dimensions and indices batching " + f"dimensions to have the same shape, got {operand_batch_shape} and " + f"{indices_batch_shape}." + ) + # Validate window_size - window_size = len(update_window_dims) + len(inserted_window_dims) + window_size = ( + len(update_window_dims) + + len(inserted_window_dims) + + len(operand_batching_dims) + ) if _rank(operand) != window_size: raise TypeError(f"Scatter op has window of size {window_size}; doesn't " f"match operand of rank {_rank(operand)}.") @@ -1933,8 +2077,14 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, _no_duplicate_dims(scatter_dims_to_operand_dims, "scatter", "scatter_dims_to_operand_dims") - max_update_slice_sizes = [operand.shape[i] for i in range(len(operand.shape)) - if not i in set(inserted_window_dims)] + max_update_slice_sizes = [ + operand.shape[i] + for i in range(len(operand.shape)) + if ( + i not in set(inserted_window_dims) + and i not in set(operand_batching_dims) + ) + ] for i in range(len(update_window_dims)): update_window_dim = update_window_dims[i] @@ -1968,7 +2118,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums): slice_sizes = [] pos = 0 for i in range(len(operand.shape)): - if i in dnums.inserted_window_dims: + if i in dnums.inserted_window_dims or i in dnums.operand_batching_dims: slice_sizes.append(1) else: slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) @@ -2029,13 +2179,19 @@ def _scatter_add_transpose_rule(t, operand, indices, updates, *, if ad.is_undefined_primal(updates): gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - 
collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if i in dimension_numbers.inserted_window_dims: + if ( + i in dimension_numbers.inserted_window_dims + or i in dimension_numbers.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2067,13 +2223,19 @@ def _scatter_mul_transpose_rule(t, operand, indices, updates, *, raise NotImplementedError( "scatter_mul gradients are only implemented if `unique_indices=True`") gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if i in dimension_numbers.inserted_window_dims: + if ( + i in dimension_numbers.inserted_window_dims + or i in dimension_numbers.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2095,40 +2257,52 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims) if ax is not 
None) operand = batching.bdim_at_front(operand, operand_bdim, size) - operand_bdim = 0 updates = batching.bdim_at_front(updates, updates_bdim, size) if indices_bdim is None: inserted_window_dims = tuple(np.add(1, dimension_numbers.inserted_window_dims)) update_window_dims = (0,) + tuple(np.add(1, dimension_numbers.update_window_dims)) + operand_batching_dims = tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) scatter_dims_to_operand_dims = tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims) + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, + operand_batching_dims=operand_batching_dims, + scatter_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) return scatter_op.bind( operand, indices, updates, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0 - # see the third case in _gather_batching_rule for comparison and comments indices = batching.bdim_at_front(indices, indices_bdim, size) - count_shape = list(indices.shape) - count_shape[-1] = 1 - counts = lax.broadcasted_iota(indices.dtype, tuple(count_shape), 0) - indices = lax.concatenate([counts, indices], len(count_shape) - 1) - update_window_dims = tuple(np.add(1, dimension_numbers.update_window_dims)) - inserted_window_dims = (0,) + tuple(np.add(1, dimension_numbers.inserted_window_dims)) - scatter_dims_to_operand_dims = (0,) + tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) + inserted_window_dims = tuple( + np.add(1, dimension_numbers.inserted_window_dims) + ) + operand_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) + scatter_indices_batching_dims = (0,) + tuple( + np.add(1, 
dimension_numbers.scatter_indices_batching_dims) + ) + scatter_dims_to_operand_dims = tuple( + np.add(1, dimension_numbers.scatter_dims_to_operand_dims) + ) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims) + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, + operand_batching_dims=operand_batching_dims, + scatter_indices_batching_dims=scatter_indices_batching_dims, + ) return scatter_op.bind( operand, indices, updates, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, @@ -2190,12 +2364,18 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, gather_dnums = GatherDimensionNumbers( offset_dims=scatter_dnums.update_window_dims, collapsed_slice_dims=scatter_dnums.inserted_window_dims, - start_index_map=scatter_dnums.scatter_dims_to_operand_dims) + start_index_map=scatter_dnums.scatter_dims_to_operand_dims, + operand_batching_dims=scatter_dnums.operand_batching_dims, + start_indices_batching_dims=scatter_dnums.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(operand.shape)): - if i in scatter_dnums.inserted_window_dims: + if ( + i in scatter_dnums.inserted_window_dims + or i in scatter_dnums.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[scatter_dnums.update_window_dims[pos]]) @@ -2323,7 +2503,6 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # of using scatter-add here is that we don't need a `scatter` transpose # rule. - # a) attach a positive ID to each update in `updates`, and perform a scatter # on the IDs. ids_shape = list(updates.shape) @@ -2344,13 +2523,16 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # b) compute the inverse gather that "undoes" the scatter on the id values. 
gather_dnums = GatherDimensionNumbers( - offset_dims=dnums.update_window_dims, - collapsed_slice_dims=dnums.inserted_window_dims, - start_index_map=dnums.scatter_dims_to_operand_dims) + offset_dims=dnums.update_window_dims, + collapsed_slice_dims=dnums.inserted_window_dims, + start_index_map=dnums.scatter_dims_to_operand_dims, + operand_batching_dims=dnums.operand_batching_dims, + start_indices_batching_dims=dnums.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(scattered_ids.shape)): - if i in dnums.inserted_window_dims: + if i in dnums.inserted_window_dims or i in dnums.operand_batching_dims: slice_sizes.append(1) else: slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) @@ -2405,13 +2587,19 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, if ad.is_undefined_primal(updates): gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if i in dimension_numbers.inserted_window_dims: + if ( + i in dimension_numbers.inserted_window_dims + or i in dimension_numbers.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2479,8 +2667,8 @@ def _scatter_lower(ctx, operand, indices, updates, *, scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), - input_batching_dims=[], - 
scatter_indices_batching_dims=[], + input_batching_dims=list(dnums.operand_batching_dims), + scatter_indices_batching_dims=list(dnums.scatter_indices_batching_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, ) @@ -2539,8 +2727,8 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates, scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), - input_batching_dims=[], - scatter_indices_batching_dims=[], + input_batching_dims=list(dnums.operand_batching_dims), + scatter_indices_batching_dims=list(dnums.scatter_indices_batching_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, ) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 559d17cd9514..ac3074f45934 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9824,6 +9824,8 @@ def replace(tup, val): offset_dims = [] start_index_map = [] collapsed_slice_dims = [] + operand_batching_dims = [] + start_indices_batching_dims = [] j = 0 for i in range(rank): if i == axis_int: @@ -9848,21 +9850,23 @@ def replace(tup, val): collapsed_slice_dims.append(i) j += 1 else: - # Otherwise, idx_shape[i] == arr_shape[i]. Use an iota index so - # corresponding elements of array and index are gathered. - # TODO(mattjj): next line needs updating for dynamic shapes - iota = lax.broadcasted_iota(index_dtype, gather_index_shape, j) - gather_indices.append(iota) - slice_sizes.append(1) - start_index_map.append(i) - collapsed_slice_dims.append(i) + # Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both + # array and index as batching so corresponding elements are gathered. 
+ if core.definitely_equal(arr_shape[i], 0): + slice_sizes.append(0) + else: + slice_sizes.append(1) + operand_batching_dims.append(i) + start_indices_batching_dims.append(j) j += 1 gather_indices_arr = lax.concatenate(gather_indices, dimension=j) dnums = lax.GatherDimensionNumbers( offset_dims=tuple(offset_dims), collapsed_slice_dims=tuple(collapsed_slice_dims), - start_index_map=tuple(start_index_map)) + start_index_map=tuple(start_index_map), + operand_batching_dims=tuple(operand_batching_dims), + start_indices_batching_dims=tuple(start_indices_batching_dims)) return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes), mode="fill" if mode is None else mode, fill_value=fill_value) diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 2bcfe96ad2f0..809df8195d54 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -122,7 +122,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, dnums = lax.ScatterDimensionNumbers( update_window_dims=indexer.dnums.offset_dims, inserted_window_dims=indexer.dnums.collapsed_slice_dims, - scatter_dims_to_operand_dims=indexer.dnums.start_index_map + scatter_dims_to_operand_dims=indexer.dnums.start_index_map, + operand_batching_dims=indexer.dnums.operand_batching_dims, + scatter_indices_batching_dims=indexer.dnums.start_indices_batching_dims, ) out = scatter_op( x, indexer.gather_indices, y, dnums, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 6d7f2c2a1e2c..525b163a6140 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2870,6 +2870,9 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers): proto.offset_dims.extend(dimension_numbers.offset_dims) proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims) proto.start_index_map.extend(dimension_numbers.start_index_map) + proto.operand_batching_dims.extend(dimension_numbers.operand_batching_dims) + 
proto.start_indices_batching_dims.extend( + dimension_numbers.start_indices_batching_dims) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto @@ -2981,6 +2984,9 @@ def _scatter_dimensions_proto(indices_shape, dimension_numbers): proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims) proto.scatter_dims_to_operand_dims.extend( dimension_numbers.scatter_dims_to_operand_dims) + proto.input_batching_dims.extend(dimension_numbers.operand_batching_dims) + proto.scatter_indices_batching_dims.extend( + dimension_numbers.scatter_indices_batching_dims) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto diff --git a/tests/lax_test.py b/tests/lax_test.py index d82b35c6b711..a2d4c939df55 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2512,6 +2512,18 @@ def testIndexTake(self, shape, dtype, idxs, axes): ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), + lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(1,), + start_index_map=(1,), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (1, 1)), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + lax.GatherDimensionNumbers( + offset_dims=(2,), collapsed_slice_dims=(), + start_index_map=(2,), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (1, 1, 3)) ]], dtype=lax_test_util.all_dtypes, ) @@ -2529,63 +2541,196 @@ def testGather(self, shape, dtype, idxs, dnums, slice_sizes): @parameterized.named_parameters( {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape, "indices_shape": indices_shape, - "dimension_numbers": lax.GatherDimensionNumbers( - offset_dims=offset_dims, - collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map), + "dimension_numbers": dimension_numbers, "slice_sizes": slice_sizes, 
"msg": msg} - for (testcase_name, operand_shape, indices_shape, offset_dims, - collapsed_slice_dims, start_index_map, slice_sizes, msg) in [ + for (testcase_name, operand_shape, indices_shape, dimension_numbers, + slice_sizes, msg) in [ ("NonAscendingWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 6, 8, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "offset_dims in gather op must be sorted"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 8, 7), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "offset_dims in gather op must be sorted"), ("RepeatedWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 6, 7, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "offset_dims in gather op must not repeat"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 7), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "offset_dims in gather op must not repeat"), ("WindowIndexOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 100, 101, 102), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "Offset dimension 2 in gather op is out of bounds"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 100, 101, 102), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "Offset dimension 2 in gather op is out of bounds"), ("WindowIndexBarelyOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 6, 7, 9), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "Offset dimension 4 in gather op is out of bounds"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 9), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "Offset dimension 4 in gather op is out of bounds"), ("MismatchingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (4,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), ("All components of the 
offset index in a gather op must either be a " - "offset dimension or explicitly collapsed")), + "offset dimension or explicitly collapsed/batching")), + ("MismatchingElidedWindowDimsV2", (10, 9, 8, 7, 6, 5), (10, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,), + start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (10, 9, 8, 7, 6, 5), + ("All components of the offset index in a gather op must either be a " + "offset dimension or explicitly collapsed/batching")), ("OutOfBoundsWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (0, 1, 2, 3, 19), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 19), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "Invalid collapsed_slice_dims set in gather op; valid range is"), ("RepeatedWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (0, 1, 2, 3, 3), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "collapsed_slice_dims in gather op must not repeat"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 3), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "collapsed_slice_dims in gather op must not repeat"), ("MismatchingGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (), (0, 1, 2, 3), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3)), + (10, 9, 8, 7, 6), ("Gather op has 4 elements in start_index_map and the bound of " "dimension index_vector_dim=4 of indices is 5. 
These two " "numbers must be equal.")), ("OutOfBoundsGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (), (0, 1, 2, 3, 7), (10, 9, 8, 7, 6), - "Invalid start_index_map"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 7)), + (10, 9, 8, 7, 6), "Invalid start_index_map"), ("RepeatedGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (), (0, 1, 2, 3, 3), (10, 9, 8, 7, 6), - "start_index_map in gather op must not repeat"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 3)), + (10, 9, 8, 7, 6), "start_index_map in gather op must not repeat"), ("NonAscendingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (2, 1), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 1), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "collapsed_slice_dims in gather op must be sorted"), ("WindowBoundsTooLarge", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7), (2,), (0, 1, 2, 3, 4), (10, 9, 8, 100, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(2,), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 100, 6), "Slice size at index 3 in gather op is out of range"), ("MismatchingNumberOfWindowBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7), "Gather op must have one slice size for every input dimension"), ("WindowBoundsNot1ForElidedDim", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7), (1,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(1,), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), ("Gather op can only 
collapse slice dims with bound 1, but bound " - "is 9 for index 1 at position 0.")) + "is 9 for index 1 at position 0.")), + ("RepeatedOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 4), operand_batching_dims=(2, 3, 3)), + (10, 9, 8, 7, 6), + "operand_batching_dims in gather op must not repeat"), + ("NonAscendingOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 4), operand_batching_dims=(3, 2)), + (10, 9, 8, 7, 6), + "operand_batching_dims in gather op must be sorted"), + ("OutOfBoundsOperandBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + operand_batching_dims=(0, 10)), + (10, 9, 8, 7, 6), + "Invalid operand_batching_dims set in gather op; valid range is"), + ("NonDisjointCollapsedAndBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1, 2), + start_index_map=(0, 1, 4), operand_batching_dims=(2, 3)), + (10, 9, 8, 7, 6), + ("collapsed_slice_dims and operand_batching_dims in gather op must be " + "disjoint")), + ("NonDisjointStartIndexMapAndBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 4), operand_batching_dims=(2, 3)), + (10, 9, 8, 7, 6), + ("start_index_map and operand_batching_dims in gather op must be " + "disjoint")), + ("WindowBoundsNot1ForBatchingDim", (10, 9, 8, 7, 6), (9, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(), + start_index_map=(0, 2, 3, 4), operand_batching_dims=(1,), + start_indices_batching_dims=(0,)), + (10, 9, 8, 7, 6), + ("Gather op can only have 
operand batching dims with bound 0/1, but " + "bound is 9 for index 1 at position 0.")), + ("RepeatedStartIndicesBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + start_indices_batching_dims=(0, 1, 0)), + (10, 9, 8, 7, 6), + "start_indices_batching_dims in gather op must not repeat"), + ("OutOfBoundsStartIndicesBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + start_indices_batching_dims=(0, 5)), + (10, 9, 8, 7, 6), + "Invalid start_indices_batching_dims set in gather op; valid range"), + ("IndexVectorDimInStartIndicesBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + start_indices_batching_dims=(0, 4)), + (10, 9, 8, 7, 6), + ("Gather op cannot have the index vector dimension as a batching " + "dimension")), + ("MismatchingNumberOfBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(1, 2), + start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,), + start_indices_batching_dims=(0, 1)), + (10, 9, 8, 7, 6), + ("Gather op requires equal numbers of operand_batching_dims and " + "start_indices_batching_dims")), + ("MismatchingBatchingDimSizes", (10, 9, 8, 7, 6), (10, 9, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 3, 4), + start_index_map=(2, 3, 4), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (10, 9, 8, 7, 6), + ("Gather op requires operand batching dimensions and indices batching " + "dimensions to have the same shape")) ] ) def testGatherShapeCheckingRule(self, operand_shape, indices_shape, dimension_numbers, slice_sizes, msg): + """ + + Args: + 
operand_shape: + indices_shape: + dimension_numbers: + slice_sizes: + msg: + """ operand = np.ones(operand_shape, dtype=np.int32) indices = np.ones(indices_shape, dtype=np.int32) @@ -2602,9 +2747,19 @@ def testGatherShapeCheckingRule(self, operand_shape, indices_shape, ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.inexact_dtypes, mode=["clip", "fill", None], @@ -2628,9 +2783,19 @@ def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, mode): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + 
scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2653,9 +2818,19 @@ def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2677,9 +2852,19 @@ def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), 
inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2701,9 +2886,19 @@ def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2721,84 +2916,207 @@ def testScatter(self, arg_shape, dtype, idxs, update_shape, dnums): # variations to account for the implicit setting of index_vector_dim in JAX. 
@parameterized.named_parameters( {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape, - "indices": indices, "update_shape": update_shape, - "dimension_numbers": lax.ScatterDimensionNumbers( - update_window_dims=update_window_dims, - inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims), + "indices_shape": indices_shape, "update_shape": update_shape, + "dimension_numbers": dimension_numbers, "msg": msg} - for (testcase_name, operand_shape, indices, update_shape, - update_window_dims, inserted_window_dims, - scatter_dims_to_operand_dims, msg) in [ - ("ScatterWithUpdatesBiggerThanInput", (64, 48), np.zeros((32, 1)), - (65, 32), (0,), (1,), (1,), "Bounds of the window dimensions"), - ("ScatterWithUpdatesBiggerThanInputV2", (64, 48), - np.zeros((32, 1)), (32, 49), (1,), (0,), (1,), + for (testcase_name, operand_shape, indices_shape, update_shape, + dimension_numbers, msg) in [ + ("ScatterWithUpdatesBiggerThanInput", (64, 48), (32, 1), (65, 32), + lax.ScatterDimensionNumbers( + update_window_dims=(0,), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,)), "Bounds of the window dimensions"), - ("ScatterWithUpdatesNotMatchingIndices", (64, 48), - np.zeros((32, 1)), (64, 31), (0,), (1,), (1,), + ("ScatterWithUpdatesBiggerThanInputV2", (64, 48), (32, 1), + (32, 49), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(1,)), + "Bounds of the window dimensions"), + ("ScatterWithUpdatesNotMatchingIndices", (64, 48), (32, 1), + (64, 31), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(1,)), "Bounds of the scatter dimensions"), - ("ScatterWithUpdatesNotMatchingIndicesV2", (64, 48), - np.zeros((32, 1)), (31, 48), (1,), (0,), (1,), + ("ScatterWithUpdatesNotMatchingIndicesV2", (64, 48), (32, 1), + (31, 48), lax.ScatterDimensionNumbers( + update_window_dims=(1,), 
inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(1,)), "Bounds of the scatter dimensions"), ("ScatterNdWithUpdatesBiggerThanInput", (64, 48), - np.zeros((10, 9, 8, 7, 1)), (10, 9, 8, 7, 65), (4,), (1,), - (0,), "Bounds of the window dimensions"), + (10, 9, 8, 7, 1), (10, 9, 8, 7, 65), + lax.ScatterDimensionNumbers( + update_window_dims=(4,), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,)), + "Bounds of the window dimensions"), ("ScatterNdWithUpdatesNotMatchingIndices", (64, 48), - np.zeros((10, 9, 8, 7, 1)), (9, 9, 8, 7, 64), (4,), (1,), (0,), + (10, 9, 8, 7, 1), (9, 9, 8, 7, 64), + lax.ScatterDimensionNumbers( + update_window_dims=(4,), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(0,)), "Bounds of the scatter dimensions"), - ("InvalidUpdates", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4, 1), - (4, 5, 6), (1, 2), (0, 1, 2, 3, 4), + ("InvalidUpdates", (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), + (10, 9, 8, 7, 3, 2, 4, 1), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "Updates tensor must be of rank 7; got 8."), - ("NonAscendingUpdateWindowDims", (6, 5, 4, 3, 2), - np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), - (4, 5, 6, 8, 7), (), (0, 1, 2, 3, 4), + ("NonAscendingUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), + (10, 9, 8, 7, 6, 5, 4, 3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6, 8, 7), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "update_window_dims in scatter op must be sorted"), - ("RepeatedUpdateWindowDims", (6, 5, 4, 3, 2), - np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), - (4, 5, 6, 7, 7), (), (0, 1, 2, 3, 4), + ("RepeatedUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), + (10, 9, 8, 7, 6, 5, 4, 3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6, 7, 7), inserted_window_dims=(), + 
scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "update_window_dims in scatter op must not repeat"), - ("OutOfBoundsUpdateWindowDims", (6, 5, 4, 3, 2), - np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), - (4, 5, 6, 7, 9), (), (0, 1, 2, 3, 4), + ("OutOfBoundsUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), + (10, 9, 8, 7, 6, 5, 4, 3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6, 7, 9), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "Invalid update_window_dims set in scatter op"), ("NonAscendingInsertedWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (2, 1), (0, 1, 2, 3, 4), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(2, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "inserted_window_dims in scatter op must be sorted"), ("RepeatedInsertedWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 1), (0, 1, 2, 3, 4), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "inserted_window_dims in scatter op must not repeat"), ("OutOfBoundsInsertedWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 5), (0, 1, 2, 3, 4), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 5), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "Invalid inserted_window_dims set in scatter op"), ("MismatchingScatterDimsToOperandDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 2), (0, 1, 2, 3), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + 
scatter_dims_to_operand_dims=(0, 1, 2, 3)), ("Scatter op has 4 elements in scatter_dims_to_operand_dims and " "the bound of dimension index_vector_dim=4 of indices " "is 5. These two numbers must be equal")), ("OutOfBoundsScatterDimsToOperandDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 2), (0, 1, 2, 3, 10), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 10)), "Invalid scatter_dims_to_operand_dims mapping"), ("RepeatedValuesInScatterDimsToOperandDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 2), (0, 1, 2, 2, 3), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 2, 3)), "scatter_dims_to_operand_dims in scatter op must not repeat"), ("InsufficientWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1,), (0, 1, 2, 3), + (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(0, 1, 2, 3)), ("Scatter op has window of size 4; doesn't match operand of " - "rank 5.")) + "rank 5.")), + ("InsufficientWindowDimsV2", (10, 49, 48, 47, 46, 45), + (10, 9, 8, 7, 3), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1, 2, 3), + operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,)), + ("Scatter op has window of size 5; doesn't match operand of " + "rank 6.")), + ("RepeatedOperandBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + 
scatter_dims_to_operand_dims=(0, 1, 4), + operand_batching_dims=(2, 3, 3)), + "operand_batching_dims in scatter op must not repeat"), + ("NonAscendingOperandBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 4), + operand_batching_dims=(3, 2)), + "operand_batching_dims in scatter op must be sorted"), + ("OutOfBoundsOperandBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + operand_batching_dims=(0, 10)), + ("Invalid operand_batching_dims set in scatter op; valid range " + "is")), + ("NonDisjointCollapsedAndBatchingDims", (50, 49, 48, 47, 46, 45), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 4), + operand_batching_dims=(1, 2)), + ("inserted_window_dims and operand_batching_dims in scatter op " + "must be disjoint")), + ("NonDisjointScatterDimsToOperandDimsAndBatchingDims", + (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), + (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 4), + operand_batching_dims=(2, 3)), + ("scatter_dims_to_operand_dims and operand_batching_dims in " + "scatter op must be disjoint")), + ("RepeatedScatterIndicesBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + scatter_indices_batching_dims=(0, 1, 0)), + "scatter_indices_batching_dims in scatter op must not repeat"), + ("OutOfBoundsScatterIndicesBatchingDims", (50, 49, 48, 47, 46), 
+ (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + scatter_indices_batching_dims=(0, 5)), + ("Invalid scatter_indices_batching_dims set in scatter op; " + "valid range")), + ("IndexVectorDimInScatterIndicesBatchingDims", + (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), + (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + scatter_indices_batching_dims=(0, 4)), + ("Scatter op cannot have the index vector dimension as a " + "batching dimension")), + ("MismatchingNumberOfBatchingDims", (50, 49, 48, 47, 46, 45), + (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(1, 2, 3, 4), + operand_batching_dims=(0,), + scatter_indices_batching_dims=(0, 1)), + ("Scatter op requires equal numbers of operand_batching_dims " + "and scatter_indices_batching_dims")), + ("MismatchingBatchingDimSizes", (10, 9, 48, 47, 46, 45), + (10, 9, 8, 7, 2), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(2,), + scatter_dims_to_operand_dims=(2, 3), + operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0)), + ("Scatter op requires operand batching dimensions and indices " + "batching dimensions to have the same shape")) ] ) - def testScatterShapeCheckingRule(self, operand_shape, indices, + def testScatterShapeCheckingRule(self, operand_shape, indices_shape, update_shape, dimension_numbers, msg): - + indices = np.zeros(indices_shape, dtype=np.int32) def f(x, y): operand = lax.broadcast(x, operand_shape) updates = lax.broadcast(y, update_shape) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 37a0011e7bd0..0f259bf490e6 100644 --- a/tests/lax_vmap_test.py +++ 
b/tests/lax_vmap_test.py @@ -566,6 +566,18 @@ def testFft(self, fft_ndims, shape, bdims): ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), + lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(1,), + start_index_map=(1,), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (1, 1)), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + lax.GatherDimensionNumbers( + offset_dims=(2,), collapsed_slice_dims=(), + start_index_map=(2,), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (1, 1, 3)) ] for bdims in lax_test_util.all_bdims(shape, idxs.shape)], dtype=lax_test_util.all_dtypes @@ -590,6 +602,16 @@ def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ] for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape, update_shape)], dtype=lax_test_util.float_dtypes @@ -613,6 +635,16 @@ def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + 
lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ] for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape)], dtype=lax_test_util.float_dtypes,