Skip to content

Commit eebda28

Browse files
Correctly check for nested tuple in map_func_over_tuple_of_tuples
1 parent 13ef0a8 commit eebda28

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

scico/numpy/_wrappers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import jax.numpy as jnp
1919

20+
import scico.numpy as snp
21+
2022
from ._blockarray import BlockArray
2123

2224

@@ -83,9 +85,7 @@ def mapped(*args, **kwargs):
8385

8486
map_arg_val = bound_args.arguments.pop(map_arg_name)
8587

86-
if not isinstance(map_arg_val, tuple) or not all(
87-
isinstance(x, tuple) for x in map_arg_val
88-
): # not nested tuple
88+
if not snp.util.is_nested(map_arg_val): # not nested tuple
8989
return func(*args, **kwargs) # no mapping
9090

9191
# map

scico/test/numpy/test_numpy.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,17 @@ def test_ufunc_conj():
248248
def test_create_zeros():
249249
A = snp.zeros(2)
250250
assert np.all(A == 0)
251+
assert isinstance(A, jax.Array)
252+
253+
A = snp.zeros((2,))
254+
assert isinstance(A, jax.Array)
251255

252256
A = snp.zeros(((2,), (2,)))
253257
assert all(snp.all(A == 0))
258+
assert isinstance(A, snp.BlockArray)
259+
260+
A = snp.zeros(())
261+
assert isinstance(A, jax.Array) # from issue 499
254262

255263

256264
def test_create_ones():
@@ -261,7 +269,7 @@ def test_create_ones():
261269
assert all(snp.all(A == 1))
262270

263271

264-
def test_create_zeros():
272+
def test_create_empty():
265273
A = snp.empty(2)
266274
assert np.all(A == 0)
267275

0 commit comments

Comments
 (0)