Skip to content

Commit

Permalink
Add option to use jax.image.resize in Interpolation decorator.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646993039
  • Loading branch information
aleximmer authored and copybara-github committed Jun 26, 2024
1 parent 56b2e30 commit d0290d3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
28 changes: 19 additions & 9 deletions connectomics/volume/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,22 +788,26 @@ class Interpolation(Decorator):

def __init__(self,
size: Sequence[int],
use_jax: bool = False,
backend: str = 'scipy_map_coordinates',
context_spec: Optional[MutableJsonSpec] = None,
**interpolation_args):
"""Interpolation.
Args:
size: New size of TensorStore.
use_jax: Whether to use scipy or jax version of `map_coordinates`.
Note that the jax version may only implement a subset of features, see
`jax.scipy.ndimage.map_coordinates` docs for current status.
backend: Backend to use for interpolation. One of 'scipy_map_coordinates',
'jax_map_coordinates', or 'jax_resize'. Defaults to the first.
context_spec: Spec for virtual chunked context overriding its defaults.
**interpolation_args: Passed to `scipy.ndimage.map_coordinates`.
**interpolation_args: Passed to `scipy.ndimage.map_coordinates`,
`jax.scipy.ndimage.map_coordinates` , or`jax.image.resize` depending
on `backend`, respectively.
"""
super().__init__(context_spec)
self._size = size
self._use_jax = use_jax
backends = ('scipy_map_coordinates', 'jax_map_coordinates', 'jax_resize')
if backend not in backends:
raise ValueError(f'Unsupported backend: {backend} not in {backends}.')
self._backend = backend
self._interpolation_args = interpolation_args

def decorate(self, input_ts: ts.TensorStore) -> ts.TensorStore:
Expand All @@ -820,7 +824,7 @@ def decorate(self, input_ts: ts.TensorStore) -> ts.TensorStore:
f'currently supported, but `inclusive_min` is: {inclusive_min}.')

resize_dim = [d for d, s in enumerate(self._size) if s != input_ts.shape[d]]
map_coordinates = (scipy.ndimage.map_coordinates if not self._use_jax
map_coordinates = (scipy.ndimage.map_coordinates if 'scipy' in self._backend
else jax.scipy.ndimage.map_coordinates)

def read_fn(domain: ts.IndexDomain, array: np.ndarray,
Expand All @@ -842,8 +846,14 @@ def read_fn(domain: ts.IndexDomain, array: np.ndarray,
else:
slices.append(slice(0, data.shape[d]))

array[...] = map_coordinates(data, np.mgrid[slices],
**self._interpolation_args)
if self._backend == 'jax_resize':
sub_size = [s if d in resize_dim else data.shape[d]
for d, s in enumerate(self._size)]
array[...] = jax.image.resize(data, sub_size,
**self._interpolation_args)
else:
array[...] = map_coordinates(data, np.mgrid[slices],
**self._interpolation_args)

json = input_ts.schema.to_json()
json['domain']['exclusive_max'] = self._size
Expand Down
21 changes: 11 additions & 10 deletions connectomics/volume/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,19 +340,20 @@ def test_interpolation(self):
}).result()
data_ts[...] = data

for use_jax in (True, False):
expected_res = np.array([
[[0., 0.],
[0., 0.]],
[[10.//2, 20.//2],
[30.//2, 40.//2]],
[[10., 20.],
[30., 40.]],])
backends = ('scipy_map_coordinates', 'jax_map_coordinates', 'jax_resize')
for backend in backends:
kwargs = {'method': 'linear'} if backend == 'jax_resize' else {'order': 1}
dec = decorators.Interpolation(
size=(3, 2, 2), order=1, use_jax=use_jax,
size=(3, 2, 2), backend=backend, **kwargs
)
vc = dec.decorate(data_ts)

expected_res = np.array([
[[0., 0.],
[0., 0.]],
[[10.//2, 20.//2],
[30.//2, 40.//2]],
[[10., 20.],
[30., 40.]],])
np.testing.assert_equal(vc[...].read().result(), expected_res)

def test_multiply(self):
Expand Down

0 comments on commit d0290d3

Please sign in to comment.