Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 14, 2025
1 parent 3103d18 commit 0654e0f
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 58 deletions.
2 changes: 1 addition & 1 deletion brainunit/lax/_lax_accept_unitless_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,4 @@ def test_lax_collapse(self, value):
result = bulax_fun(q1, value2)

with pytest.raises(bu.UnitMismatchError):
result = bulax_fun(q1, value2, unit_to_scale=bu.second)
result = bulax_fun(q1, value2, unit_to_scale=bu.second)
9 changes: 4 additions & 5 deletions brainunit/lax/_lax_array_creation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,23 @@
# ==============================================================================


import jax.numpy as jnp
import jax.lax as lax
import pytest
import jax.numpy as jnp
from absl.testing import parameterized

import brainunit as bu
import brainunit.lax as bulax
from brainunit import meter, second
from brainunit._base import assert_quantity

lax_array_creation_given_array = [
'zeros_like_array',
]
]

lax_array_creation_misc = [
'iota', 'broadcasted_iota',
]


class TestLaxArrayCreation(parameterized.TestCase):
def __init__(self, *args, **kwargs):
super(TestLaxArrayCreation, self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -95,4 +94,4 @@ def test_lax_array_creation_broadcasted_iota(self, shape, unit):

result = bulax_fun(float, shape, dimension, unit=unit)
expected = lax_fun(float, shape, dimension)
assert_quantity(result, expected, unit=unit)
assert_quantity(result, expected, unit=unit)
24 changes: 12 additions & 12 deletions brainunit/lax/_lax_change_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ def test_lax_change_unit_conv(self, shape, window_strides, padding):
assert_quantity(result, expected, unit=bulax_fun._unit_change_fun(bu.get_unit(q1), bu.get_unit(q2)))

@parameterized.product(
shapes = [
shapes=[
dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
for lhs_shape, rhs_shape in [
((b, 10, i), (k, i, j))
for b, i, j, k in itertools.product(
[2, 3], [2, 3], [2, 3], [3,]
)
]
((b, 10, i), (k, i, j))
for b, i, j, k in itertools.product(
[2, 3], [2, 3], [2, 3], [3, ]
)
]
],
strides=[(1,), (2,)],
padding=["VALID", "SAME"],
Expand Down Expand Up @@ -251,17 +251,17 @@ def test_lax_change_unit_conv_transpose(self, shapes, strides, padding):
assert_quantity(result, expected, unit=bulax_fun._unit_change_fun(bu.get_unit(q1), bu.get_unit(q2)))

@parameterized.product(
shapes = [
shapes=[
dict(
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
)
for lhs_shape, rhs_shape, dimension_numbers in [
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
]
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
]
],
)
def test_lax_change_unit_dot_general(self, shapes):
Expand All @@ -286,4 +286,4 @@ def test_lax_change_unit_dot_general(self, shapes):
q2 = rhs * meter
result = bulax_fun(q1, q2, dimension_numbers=dimension_numbers)
expected = lax_fun(jnp.array(lhs), jnp.array(rhs), dimension_numbers=dimension_numbers)
assert_quantity(result, expected, unit=bulax_fun._unit_change_fun(bu.get_unit(q1), bu.get_unit(q2)))
assert_quantity(result, expected, unit=bulax_fun._unit_change_fun(bu.get_unit(q1), bu.get_unit(q2)))
6 changes: 4 additions & 2 deletions brainunit/lax/_lax_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,13 +739,15 @@ def _fun_lax_scatter(
mode
) -> Union[Quantity, jax.Array]:
if isinstance(operand, Quantity) and isinstance(updates, Quantity):
assert has_same_unit(operand, updates), f'operand(unit:{operand.unit}) and updates(unit:{updates.unit}) do not have same unit'
assert has_same_unit(operand,
updates), f'operand(unit:{operand.unit}) and updates(unit:{updates.unit}) do not have same unit'
return maybe_decimal(Quantity(fun(operand.mantissa, scatter_indices, updates.mantissa, dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode), unit=operand.unit))
elif isinstance(operand, Quantity) or isinstance(updates, Quantity):
raise AssertionError(f'operand and updates should both be `Quantity` or Array, now we got {type(operand)} and {type(updates)}')
raise AssertionError(
f'operand and updates should both be `Quantity` or Array, now we got {type(operand)} and {type(updates)}')
else:
return fun(operand, scatter_indices, updates, dimension_numbers,
indices_are_sorted=indices_are_sorted,
Expand Down
33 changes: 13 additions & 20 deletions brainunit/lax/_lax_keep_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,19 @@
# limitations under the License.
# ==============================================================================
import sys
import unittest

import jax.lax
import jax.numpy as jnp
import brainstate as bst
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import pytest
from absl.testing import parameterized
import brainstate as bst
from jax._src import test_util as jtu
import unittest

import brainunit as u
import brainunit.lax as ulax
from brainunit import meter, second
from brainunit._base import assert_quantity
from brainunit.lax import gather

lax_array_manipulation = [
'slice', 'dynamic_slice', 'dynamic_update_slice', 'gather',
Expand Down Expand Up @@ -137,7 +134,7 @@ def test_dynamic_update_slice(self, shape, indices, update_shape):
result_q = ulax.dynamic_update_slice(array, start_indices=start_indices, update=update)
assert_quantity(result_q, expected, u.second)

@unittest.skipIf(sys.version_info < (3, 10), "JAX now do not support the python version below 3.10")

@parameterized.product(
[dict(shape=shape, idxs=idxs, dnums=dnums, slice_sizes=slice_sizes)
for shape, idxs, dnums, slice_sizes in [
Expand Down Expand Up @@ -167,6 +164,7 @@ def test_dynamic_update_slice(self, shape, indices, update_shape):
(1, 1, 3))
]],
)
@unittest.skipIf(sys.version_info < (3, 10), "JAX now do not support the python version below 3.10")
def test_gather(self, shape, idxs, dnums, slice_sizes):
rand_idxs = bst.random.randint(0., high=max(shape), size=idxs.shape)
array = bst.random.random(shape)
Expand Down Expand Up @@ -216,7 +214,6 @@ def test_slice_in_dim(self):
result_q = ulax.slice_in_dim(array, start_index=start_index, limit_index=limit_index)
assert_quantity(result_q, expected, u.second)


def test_index_in_dim(self):
# TODO: No test in JAX
...
Expand All @@ -232,7 +229,7 @@ def test_dynamic_index_in_dim(self):
def test_dynamic_update_slice_in_dim(self):
x = jnp.ones((6, 7), jnp.int32)
with self.assertRaises(TypeError):
ulax.dynamic_update_slice_in_dim(x, jnp.ones((2,7), jnp.int32),
ulax.dynamic_update_slice_in_dim(x, jnp.ones((2, 7), jnp.int32),
jnp.array([2, 2]), axis=0)

def test_dynamic_update_index_in_dim(self):
Expand Down Expand Up @@ -360,7 +357,6 @@ def test_scatter(self, arg_shape, idxs, update_shape, dnums):

assert_quantity(result_q, expected, u.second)


@parameterized.product(
[dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
Expand Down Expand Up @@ -454,7 +450,6 @@ def test_scatter_min(self, arg_shape, idxs, update_shape, dnums):

assert_quantity(result_q, expected, u.second)


@parameterized.product(
[dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
Expand Down Expand Up @@ -497,7 +492,6 @@ def test_scatter_max(self, arg_shape, idxs, update_shape, dnums):

assert_quantity(result_q, expected, u.second)


@parameterized.product(
[dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
Expand Down Expand Up @@ -536,7 +530,6 @@ def test_scatter_apply(self, arg_shape, idxs, update_shape, dnums):

assert_quantity(result_q, expected, u.second)


@parameterized.product(
[dict(shape=shape, pads=pads) for shape, pads in [
((0, 2), [(1, 2, 1), (0, 1, 0)]),
Expand All @@ -552,7 +545,7 @@ def test_scatter_apply(self, arg_shape, idxs, update_shape, dnums):
((5,), [(-1, -2, 2), ]),
((4, 2), [(-1, -2, 1), (1, 2, 2)])
]
],
],
)
def test_pad(self, shape, pads):
array = bst.random.random(shape)
Expand All @@ -575,10 +568,10 @@ class TestLaxKeepUnitNary(parameterized.TestCase):
@parameterized.product(
[dict(min_shape=min_shape, operand_shape=operand_shape, max_shape=max_shape)
for min_shape, operand_shape, max_shape in [
[(), (2, 3), ()],
[(2, 3), (2, 3), ()],
[(), (2, 3), (2, 3)],
[(2, 3), (2, 3), (2, 3)],
[(), (2, 3), ()],
[(2, 3), (2, 3), ()],
[(), (2, 3), (2, 3)],
[(2, 3), (2, 3), (2, 3)],
]],
)
def test_clamp(self, min_shape, operand_shape, max_shape):
Expand All @@ -599,7 +592,6 @@ def test_clamp(self, min_shape, operand_shape, max_shape):
assert_quantity(result_q, expected, u.second)



class TestLaxTypeConversion(parameterized.TestCase):

@parameterized.product(
Expand All @@ -618,11 +610,11 @@ def test_convert_element_type(self, input_type, dtype, value):
result_q = ulax_op(input_type(value) * u.second)
assert_quantity(result_q, expected, u.second)


def test_bitcast_convert_type(self):
# TODO: dtypes.bit_width need the source code of JAX
...


def compute_recall(result_neighbors, ground_truth_neighbors) -> float:
"""Computes the recall of an approximate nearest neighbor search.
Expand All @@ -648,6 +640,7 @@ def compute_recall(result_neighbors, ground_truth_neighbors) -> float:
for q, nn_per_q in enumerate(result_neighbors))
return hits / ground_truth_neighbors.size


class TestLaxKeepUnitReturnQuantityIndex(parameterized.TestCase):

@parameterized.product(
Expand Down
2 changes: 1 addition & 1 deletion brainunit/lax/_lax_linalg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Union, Callable, Any, Tuple, List
from typing import Union, Callable, Any

import jax
from jax import lax, Array
Expand Down
9 changes: 1 addition & 8 deletions brainunit/lax/_lax_linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,13 @@
# limitations under the License.
# ==============================================================================
import jax.numpy as jnp
import pytest
from absl.testing import parameterized
from jax import lax
import brainstate as bst

import brainunit as u
import brainunit.lax as ulax
from brainunit import second, meter, ms
from brainunit._base import assert_quantity


lax_linear_algebra_change_unit_unary = [
'cholesky',
]
Expand Down Expand Up @@ -60,6 +56,7 @@
'tridiagonal_solve',
]


class TestLaxLinalg(parameterized.TestCase):

def __init__(self, *args, **kwargs):
Expand All @@ -83,7 +80,6 @@ def test_eig(self):
assert_quantity(vl, vl_e)
assert_quantity(vr, vr_e)


def test_cholesky(self):
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

Expand Down Expand Up @@ -138,7 +134,6 @@ def test_qr(self):
assert_quantity(q, q_e)
assert_quantity(r, r_e, u.second)


def test_lu(self):
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

Expand Down Expand Up @@ -200,7 +195,6 @@ def test_svd(self):
assert_quantity(s, s_e, u.second)
assert_quantity(vh, vh_e)


def test_tridiagonal(self):
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

Expand Down Expand Up @@ -259,4 +253,3 @@ def test_tridiagonal_solve(self):
b = b * u.second
result_q = ulax.tridiagonal_solve(dl, d, du, b)
assert_quantity(result_q, expected, u.second)

2 changes: 1 addition & 1 deletion brainunit/lax/_lax_remove_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from .._base import Quantity
from .._misc import set_module_as
from ..math._fun_remove_unit import _fun_remove_unit_unary, _fun_logic_unary, _fun_logic_binary
from ..math._fun_remove_unit import _fun_remove_unit_unary, _fun_logic_binary

__all__ = [
# math funcs remove unit (unary)
Expand Down
6 changes: 3 additions & 3 deletions brainunit/lax/_lax_remove_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
# ==============================================================================


import jax.numpy as jnp
import jax.lax as lax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized

import brainunit as bu
import brainunit.lax as bulax
from brainunit import meter, second
from brainunit._base import assert_quantity
Expand All @@ -32,6 +31,7 @@
'eq', 'ne', 'ge', 'gt', 'le', 'lt',
]


class TestLaxRemoveUnit(parameterized.TestCase):
def __init__(self, *args, **kwargs):
super(TestLaxRemoveUnit, self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -85,4 +85,4 @@ def test_lax_remove_unit_logic_binary(self, value, unit):
result = bulax_fun(jnp.array(x1), q2)

with pytest.raises(AssertionError):
result = bulax_fun(q1, jnp.array(x2))
result = bulax_fun(q1, jnp.array(x2))
8 changes: 3 additions & 5 deletions brainunit/lax/_misc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@
# ==============================================================================


import jax.numpy as jnp
import jax.lax as lax
import pytest
import jax.numpy as jnp
from absl.testing import parameterized

import brainunit as u
import brainunit.lax as ulax
from brainunit import meter
from brainunit._base import assert_quantity

lax_misc = [
Expand All @@ -31,6 +28,7 @@
'broadcast_shapes',
]


class TestLaxMisc(parameterized.TestCase):
# def test_after_all(self):
# token1 = lax.create_token()
Expand Down Expand Up @@ -66,4 +64,4 @@ def test_broadcast_shapes(self):
expecteds = lax.broadcast_shapes(shape1, shape2)

for result, expected in zip(results, expecteds):
self.assertTrue(result == expected)
self.assertTrue(result == expected)

0 comments on commit 0654e0f

Please sign in to comment.