Skip to content

Commit

Permalink
Update _lax_keep_unit_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 14, 2025
1 parent 0654e0f commit 87ffd83
Showing 1 changed file with 87 additions and 7 deletions.
94 changes: 87 additions & 7 deletions brainunit/lax/_lax_keep_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,23 @@ def test_dynamic_update_slice(self, shape, indices, update_shape):
start_index_map=(2,), operand_batching_dims=(0, 1),
start_indices_batching_dims=(1, 0)),
(1, 1, 3))
]] if sys.version_info >= (3, 10) else [
dict(shape=shape, idxs=idxs, dnums=dnums, slice_sizes=slice_sizes)
for shape, idxs, dnums, slice_sizes in [
((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]],
)
@unittest.skipIf(sys.version_info < (3, 10), "JAX now do not support the python version below 3.10")
def test_gather(self, shape, idxs, dnums, slice_sizes):
rand_idxs = bst.random.randint(0., high=max(shape), size=idxs.shape)
array = bst.random.random(shape)
Expand Down Expand Up @@ -338,7 +352,21 @@ def test_lax_keep_unit_math_binary(self, value, unit):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [
dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]]
)
def test_scatter(self, arg_shape, idxs, update_shape, dnums):
array = bst.random.random(arg_shape)
Expand Down Expand Up @@ -380,9 +408,23 @@ def test_scatter(self, arg_shape, idxs, update_shape, dnums):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [
dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]],
mode=["clip", "fill", None],
op=['scatter_add', 'scatter_sub'],
op=['scatter_add', 'scatter_sub'] if sys.version_info >= (3, 10) else ['scatter_add'],
)
def test_scatter_add_sub(self, arg_shape, idxs, update_shape, dnums, mode, op):
ulax_op = getattr(ulax, op)
Expand Down Expand Up @@ -431,7 +473,19 @@ def test_scatter_mul(self):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]]
)
def test_scatter_min(self, arg_shape, idxs, update_shape, dnums):
array = bst.random.random(arg_shape)
Expand Down Expand Up @@ -473,7 +527,20 @@ def test_scatter_min(self, arg_shape, idxs, update_shape, dnums):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [
dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]]
)
def test_scatter_max(self, arg_shape, idxs, update_shape, dnums):
array = bst.random.random(arg_shape)
Expand Down Expand Up @@ -515,7 +582,20 @@ def test_scatter_max(self, arg_shape, idxs, update_shape, dnums):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [
dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]]
)
def test_scatter_apply(self, arg_shape, idxs, update_shape, dnums):
array = bst.random.random(arg_shape)
Expand Down

0 comments on commit 87ffd83

Please sign in to comment.