File tree 2 files changed +12
-4
lines changed
2 files changed +12
-4
lines changed Original file line number Diff line number Diff line change 17
17
18
18
import jax .numpy as jnp
19
19
20
+ import scico .numpy as snp
21
+
20
22
from ._blockarray import BlockArray
21
23
22
24
@@ -83,9 +85,7 @@ def mapped(*args, **kwargs):
83
85
84
86
map_arg_val = bound_args .arguments .pop (map_arg_name )
85
87
86
- if not isinstance (map_arg_val , tuple ) or not all (
87
- isinstance (x , tuple ) for x in map_arg_val
88
- ): # not nested tuple
88
+ if not snp .util .is_nested (map_arg_val ): # not nested tuple
89
89
return func (* args , ** kwargs ) # no mapping
90
90
91
91
# map
Original file line number Diff line number Diff line change @@ -248,9 +248,17 @@ def test_ufunc_conj():
248
248
def test_create_zeros ():
249
249
A = snp .zeros (2 )
250
250
assert np .all (A == 0 )
251
+ assert isinstance (A , jax .Array )
252
+
253
+ A = snp .zeros ((2 ,))
254
+ assert isinstance (A , jax .Array )
251
255
252
256
A = snp .zeros (((2 ,), (2 ,)))
253
257
assert all (snp .all (A == 0 ))
258
+ assert isinstance (A , snp .BlockArray )
259
+
260
+ A = snp .zeros (())
261
+ assert isinstance (A , jax .Array ) # from issue 499
254
262
255
263
256
264
def test_create_ones ():
@@ -261,7 +269,7 @@ def test_create_ones():
261
269
assert all (snp .all (A == 1 ))
262
270
263
271
264
- def test_create_zeros ():
272
+ def test_create_empty ():
265
273
A = snp .empty (2 )
266
274
assert np .all (A == 0 )
267
275
You can’t perform that action at this time.
0 commit comments