Skip to content

Commit

Permalink
test Scale through func property
Browse files Browse the repository at this point in the history
  • Loading branch information
nhatnm52 committed Mar 14, 2024
1 parent 111d70a commit eacdca7
Showing 1 changed file with 55 additions and 18 deletions.
73 changes: 55 additions & 18 deletions tests/test_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@
class TestScaler:
@pytest.fixture(
params=(
(1, 2, 1, 256, 256),
(3, 512, 512),
(256, 256),
[(1, 2, 1, 256, 256), True],
[(1, 2, 1, 256, 256), False],
[(3, 512, 512), True],
[(3, 512, 512), False],
[(256, 256), True],
[(256, 256), False],
),
ids=["5D", "3D", "2D"],
ids=["5D-directly", "5D-indirectly",
"3D-directly", "3D-indirectly",
"2D-directly", "2D-indirectly"
],
)
def shape(self, request):
def test_case(self, request):
return request.param

def create_data(self, shape, dtype=np.uint8, mean_val=10):
Expand All @@ -30,42 +36,68 @@ def check_downscaled(self, downscaled, shape, scale_factor=2):
sh // scale_factor for sh in expected_shape[-2:]
)

def test_nearest(self, shape):
def test_nearest(self, test_case):
shape, directly = test_case
data = self.create_data(shape)
scaler = Scaler()
downscaled = scaler.nearest(data)
if directly:
downscaled = scaler.nearest(data)
else:
scaler.method = "nearest"
downscaled = scaler.func(data)
self.check_downscaled(downscaled, shape)

# this fails because of wrong channel dimension; need to fix in follow-up PR
@pytest.mark.xfail
def test_gaussian(self, shape):
def test_gaussian(self, test_case):
shape, directly = test_case
data = self.create_data(shape)
scaler = Scaler()
downscaled = scaler.gaussian(data)
if directly:
downscaled = scaler.gaussian(data)
else:
scaler.method = "gaussian"
downscaled = scaler.func(data)
self.check_downscaled(downscaled, shape)

# this fails because of wrong channel dimension; need to fix in follow-up PR
@pytest.mark.xfail
def test_laplacian(self, shape):
def test_laplacian(self, test_case):
shape, directly = test_case
data = self.create_data(shape)
scaler = Scaler()
downscaled = scaler.laplacian(data)
if directly:
downscaled = scaler.laplacian(data)
else:
scaler.method = "laplacian"
downscaled = scaler.func(data)
self.check_downscaled(downscaled, shape)

def test_local_mean(self, shape):
def test_local_mean(self, test_case):
shape, directly = test_case
data = self.create_data(shape)
scaler = Scaler()
downscaled = scaler.local_mean(data)
if directly:
downscaled = scaler.local_mean(data)
else:
scaler.method = "local_mean"
downscaled = scaler.func(data)
self.check_downscaled(downscaled, shape)

@pytest.mark.skip(reason="This test does not terminate")
def test_zoom(self, shape):
def test_zoom(self, test_case):
shape, directly = test_case
data = self.create_data(shape)
scaler = Scaler()
downscaled = scaler.zoom(data)
if directly:
downscaled = scaler.zoom(data)
else:
scaler.method = "zoom"
downscaled = scaler.func(data)
self.check_downscaled(downscaled, shape)

def test_scale_dask(self, shape):
def test_scale_dask(self, test_case):
shape, directly = test_case
data = self.create_data(shape)
# chunk size gives odd-shaped chunks at the edges
# tests https://github.com/ome/ome-zarr-py/pull/244
Expand All @@ -75,8 +107,13 @@ def test_scale_dask(self, shape):
data_delayed = da.from_array(data, chunks=chunk_2d)

scaler = Scaler()
resized_data = scaler.resize_image(data)
resized_dask = scaler.resize_image(data_delayed)
if directly:
resized_data = scaler.resize_image(data)
resized_dask = scaler.resize_image(data_delayed)
else:
scaler.method = "resize_image"
resized_data = scaler.func(data)
resized_dask = scaler.func(data_delayed)

assert np.array_equal(resized_data, resized_dask)

Expand Down

0 comments on commit eacdca7

Please sign in to comment.