From d0290d3799afd4f118435c178c76a6fe6af86529 Mon Sep 17 00:00:00 2001 From: Alex Immer Date: Wed, 26 Jun 2024 10:19:48 -0700 Subject: [PATCH] Add option to use `jax.image.resize` in Interpolation decorator. PiperOrigin-RevId: 646993039 --- connectomics/volume/decorators.py | 28 +++++++++++++++++--------- connectomics/volume/decorators_test.py | 21 ++++++++++--------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/connectomics/volume/decorators.py b/connectomics/volume/decorators.py index 5c70cf3..ed17cea 100644 --- a/connectomics/volume/decorators.py +++ b/connectomics/volume/decorators.py @@ -788,22 +788,26 @@ class Interpolation(Decorator): def __init__(self, size: Sequence[int], - use_jax: bool = False, + backend: str = 'scipy_map_coordinates', context_spec: Optional[MutableJsonSpec] = None, **interpolation_args): """Interpolation. Args: size: New size of TensorStore. - use_jax: Whether to use scipy or jax version of `map_coordinates`. - Note that the jax version may only implement a subset of features, see - `jax.scipy.ndimage.map_coordinates` docs for current status. + backend: Backend to use for interpolation. One of 'scipy_map_coordinates', + 'jax_map_coordinates', or 'jax_resize'. Defaults to the first. context_spec: Spec for virtual chunked context overriding its defaults. - **interpolation_args: Passed to `scipy.ndimage.map_coordinates`. + **interpolation_args: Passed to `scipy.ndimage.map_coordinates`, + `jax.scipy.ndimage.map_coordinates` , or`jax.image.resize` depending + on `backend`, respectively. """ super().__init__(context_spec) self._size = size - self._use_jax = use_jax + backends = ('scipy_map_coordinates', 'jax_map_coordinates', 'jax_resize') + if backend not in backends: + raise ValueError(f'Unsupported backend: {backend} not in {backends}.') + self._backend = backend self._interpolation_args = interpolation_args def decorate(self, input_ts: ts.TensorStore) -> ts.TensorStore: @@ -820,7 +824,7 @@ def decorate(self, input_ts: ts.TensorStore) -> ts.TensorStore: f'currently supported, but `inclusive_min` is: {inclusive_min}.') resize_dim = [d for d, s in enumerate(self._size) if s != input_ts.shape[d]] - map_coordinates = (scipy.ndimage.map_coordinates if not self._use_jax + map_coordinates = (scipy.ndimage.map_coordinates if 'scipy' in self._backend else jax.scipy.ndimage.map_coordinates) def read_fn(domain: ts.IndexDomain, array: np.ndarray, @@ -842,8 +846,14 @@ def read_fn(domain: ts.IndexDomain, array: np.ndarray, else: slices.append(slice(0, data.shape[d])) - array[...] = map_coordinates(data, np.mgrid[slices], - **self._interpolation_args) + if self._backend == 'jax_resize': + sub_size = [s if d in resize_dim else data.shape[d] + for d, s in enumerate(self._size)] + array[...] = jax.image.resize(data, sub_size, + **self._interpolation_args) + else: + array[...] = map_coordinates(data, np.mgrid[slices], + **self._interpolation_args) json = input_ts.schema.to_json() json['domain']['exclusive_max'] = self._size diff --git a/connectomics/volume/decorators_test.py b/connectomics/volume/decorators_test.py index 485a59b..62358d5 100644 --- a/connectomics/volume/decorators_test.py +++ b/connectomics/volume/decorators_test.py @@ -340,19 +340,20 @@ def test_interpolation(self): }).result() data_ts[...] = data - for use_jax in (True, False): + expected_res = np.array([ + [[0., 0.], + [0., 0.]], + [[10.//2, 20.//2], + [30.//2, 40.//2]], + [[10., 20.], + [30., 40.]],]) + backends = ('scipy_map_coordinates', 'jax_map_coordinates', 'jax_resize') + for backend in backends: + kwargs = {'method': 'linear'} if backend == 'jax_resize' else {'order': 1} dec = decorators.Interpolation( - size=(3, 2, 2), order=1, use_jax=use_jax, + size=(3, 2, 2), backend=backend, **kwargs ) vc = dec.decorate(data_ts) - - expected_res = np.array([ - [[0., 0.], - [0., 0.]], - [[10.//2, 20.//2], - [30.//2, 40.//2]], - [[10., 20.], - [30., 40.]],]) np.testing.assert_equal(vc[...].read().result(), expected_res) def test_multiply(self):