more efficient separable convolution (#355)
closes #321 

This PR adds a more efficient implementation of separable convolution, adapted from `opencv_contrib`. It is not yet ready for review.

Still needs:
- [x] support other boundary modes
- [x] tests
- [x] ~extension to nD~  extension to 3D
- [x] wrappers for existing filters to call this one instead of CuPy's

The key changes in this PR are the new files:
`skimage/filters/_separable_filtering.py`
`skimage/filters/tests/test_separable_filtering.py`

It is based on the approach taken in OpenCV-contrib's [row](https://github.com/opencv/opencv_contrib/blob/4.6.0/modules/cudafilters/src/cuda/row_filter.hpp) and [column](https://github.com/opencv/opencv_contrib/blob/4.6.0/modules/cudafilters/src/cuda/column_filter.hpp) filters, but also supports:
- 3D
- additional dtypes (e.g. complex64)
- all boundary modes from SciPy
- kernel sizes larger than 32 (the OpenCV implementation is restricted to kernel size <= 32)
- output casting behavior that matches SciPy rather than OpenCV conventions (a short usage sketch follows this list)
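A minimal usage sketch, based on the benchmark script added in this PR (the image contents and kernel values here are just placeholders): the vendored functions keep the `cupyx.scipy.ndimage` signatures and take an extra `algorithm='shared_memory'` keyword to select the new kernels.

```python
import cupy as cp

from cucim.skimage._vendored.ndimage import convolve1d

rng = cp.random.default_rng(123)
image = rng.standard_normal((3840, 2160), dtype=cp.float32)
kernel = cp.asarray([1, 4, 6, 4, 1], dtype=cp.float32) / 16

# Same call as cupyx.scipy.ndimage.convolve1d, plus the `algorithm` keyword
# that opts into the shared-memory separable filtering kernels.
out = convolve1d(image, kernel, axis=0, mode='reflect',
                 algorithm='shared_memory')
```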
 
A simpler version of the same approach has long been available in the CUDA samples' [convolutionSeparable.cu example](https://github.com/NVIDIA/cuda-samples/blob/master/Samples/2_Concepts_and_Techniques/convolutionSeparable/convolutionSeparable.cu). The basic idea (sketched in the code below) is:
1. The first stage loads the current image patch and its boundary values into shared memory.
2. After synchronization, the convolution is computed from the shared-memory array.
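For illustration, here is a heavily simplified sketch of that two-stage pattern (it is **not** the generated code in `_separable_filtering.py`): a float32 row correlation written as a `cupy.RawKernel`, hard-coded to a single axis and the 'nearest' boundary mode.

```python
import cupy as cp

_src = r'''
extern "C" __global__
void row_filter(const float* img, const float* kern, float* out,
                int w, int h, int ksize) {
    extern __shared__ float tile[];
    const int radius = ksize / 2;
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y;

    // Stage 1: cooperatively load this block's pixels plus a halo of `radius`
    // pixels on each side, clamping indices at the border ("nearest" mode).
    for (int i = threadIdx.x; i < blockDim.x + 2 * radius; i += blockDim.x) {
        int gx = blockIdx.x * blockDim.x + i - radius;
        gx = min(max(gx, 0), w - 1);
        tile[i] = img[y * w + gx];
    }
    __syncthreads();

    // Stage 2: each thread correlates its pixel using the shared-memory tile.
    if (x < w) {
        float acc = 0.0f;
        for (int k = 0; k < ksize; ++k) {
            acc += tile[threadIdx.x + k] * kern[k];
        }
        out[y * w + x] = acc;
    }
}
'''
row_filter = cp.RawKernel(_src, 'row_filter')

h, w, ksize, block = 2160, 3840, 7, 128
img = cp.random.default_rng(0).standard_normal((h, w), dtype=cp.float32)
kern = cp.ones(ksize, dtype=cp.float32) / ksize
out = cp.empty_like(img)
row_filter(((w + block - 1) // block, h), (block, 1),
           (img, kern, out, cp.int32(w), cp.int32(h), cp.int32(ksize)),
           shared_mem=(block + 2 * (ksize // 2)) * img.dtype.itemsize)
```

The kernels generated in `_separable_filtering.py` additionally handle other dtypes, all of SciPy's boundary modes, the `origin` offset, and the 2D/3D column passes, but they follow the same load/synchronize/filter structure.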

A lot of CuPy's ndimage code is vendored here for the following reasons:
- We need the `cupyx.scipy.ndimage` functions we call, such as `gaussian_filter`, `uniform_filter`, etc., to dispatch to the new convolution implementation when possible.
- The `_get_weights_dtype` utility was changed to promote 8- and 16-bit integers to float32 instead of float64 during convolutions.
- The `_get_output` utility was changed to allocate output arrays with `empty` rather than `zeros`, which is more efficient.
- The `_run_1d_filters` utility was improved so that it avoids an extra copy when the number of filters is even (a sketch of these helper changes follows this list).
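For illustration only (these are sketches, not CuPy's actual private helpers), the gist of the first two utility changes is roughly:

```python
import cupy as cp


def _example_weights_dtype(image_dtype):
    # 8- and 16-bit integers are exactly representable in float32, so there is
    # no need to run the convolution with float64 weights for those inputs.
    image_dtype = cp.dtype(image_dtype)
    if image_dtype.kind in 'bui' and image_dtype.itemsize <= 2:
        return cp.float32
    # other dtypes keep the usual promotion (e.g. float64 for 32/64-bit ints)
    return cp.result_type(image_dtype, cp.float32)


def _example_get_output(image, dtype=None):
    # cp.empty skips the device memset that cp.zeros issues; every output
    # element is overwritten by the filter, so zero-filling is wasted work.
    return cp.empty(image.shape, dtype=image.dtype if dtype is None else dtype)
```

The `_run_1d_filters` change is in the same spirit: when an even number of 1D passes is applied, the intermediate buffers can be ping-ponged so that the final pass writes directly into the requested output array rather than into a temporary that must then be copied.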

I will submit PRs for these changes to CuPy. The `_get_output` change in particular also affects other morphology and interpolation functions that we use in cuCIM, so I have vendored those here as well. For reference, these non-filtering changes give a modest performance improvement: grayscale `erosion` on a 4K image became 10% faster, and `resize` of a 4K image to HD became 20% faster.

Kernels built from generated code strings are not easy to review, so I have tried to add fairly comprehensive tests covering many kernel sizes, all boundary modes, various dtype combinations, etc.

Authors:
  - Gregory Lee (https://github.com/grlee77)

Approvers:
  - Gigon Bae (https://github.com/gigony)
  - https://github.com/jakirkham

URL: #355
grlee77 authored Aug 3, 2022
1 parent db0026b commit 553418b
Showing 37 changed files with 5,937 additions and 80 deletions.
139 changes: 139 additions & 0 deletions benchmarks/skimage/bench_convolve.py
@@ -0,0 +1,139 @@
"""
Benchmark locally modified ndimage functions vs. their CuPy counterparts
"""
import cupy as cp
import cupyx.scipy.ndimage as ndi
import pytest
from cupyx.profiler import benchmark

from cucim.skimage._vendored.ndimage import (
convolve1d, correlate1d, gaussian_filter, gaussian_filter1d,
gaussian_gradient_magnitude, gaussian_laplace, laplace, prewitt, sobel,
uniform_filter, uniform_filter1d,
)

d = cp.cuda.Device()


def _get_image(shape, dtype, seed=123):

rng = cp.random.default_rng(seed)
dtype = cp.dtype(dtype)
if dtype.kind == 'b':
image = rng.integers(0, 1, shape, dtype=cp.uint8).astype(bool)
elif dtype.kind in 'iu':
image = rng.integers(0, 128, shape, dtype=dtype)
elif dtype.kind in 'c':
real_dtype = cp.asarray([], dtype=dtype).real.dtype
image = rng.standard_normal(shape, dtype=real_dtype)
image = image + 1j * rng.standard_normal(shape, dtype=real_dtype)
else:
if dtype == cp.float16:
image = rng.standard_normal(shape).astype(dtype)
else:
image = rng.standard_normal(shape, dtype=dtype)
return image


def _compare_implementations(
    shape, kernel_size, axis, dtype, mode, cval=0.0, origin=0,
    output_dtype=None, kernel_dtype=None, output_preallocated=False,
    function=convolve1d, max_duration=1
):
    """Benchmark a 1D filter from cupyx.scipy.ndimage against the vendored
    shared-memory implementation of the same function."""
    dtype = cp.dtype(dtype)
    if kernel_dtype is None:
        kernel_dtype = dtype
    image = _get_image(shape, dtype)
    kernel = _get_image((kernel_size,), kernel_dtype)
    kwargs = dict(axis=axis, mode=mode, cval=cval, origin=origin)
    if output_dtype is not None:
        output_dtype = cp.dtype(output_dtype)
    function_ref = getattr(ndi, function.__name__)
    if output_preallocated:
        if output_dtype is None:
            output_dtype = image.dtype
        output1 = cp.empty(image.shape, dtype=output_dtype)
        output2 = cp.empty(image.shape, dtype=output_dtype)
        kwargs.update(dict(output=output1))
        perf1 = benchmark(function_ref, (image, kernel), kwargs=kwargs, n_warmup=10, n_repeat=10000, max_duration=max_duration)
        kwargs.update(dict(output=output2, algorithm='shared_memory'))
        perf2 = benchmark(function, (image, kernel), kwargs=kwargs, n_warmup=10, n_repeat=10000, max_duration=max_duration)
        return perf1, perf2
    kwargs.update(dict(output=output_dtype))
    perf1 = benchmark(function_ref, (image, kernel), kwargs=kwargs, n_warmup=10, n_repeat=10000, max_duration=max_duration)
    kwargs.update(dict(output=output_dtype, algorithm='shared_memory'))
    perf2 = benchmark(function, (image, kernel), kwargs=kwargs, n_warmup=10, n_repeat=10000, max_duration=max_duration)
    return perf1, perf2


def _compare_implementations_other(
    shape, dtype, mode, cval=0.0,
    output_dtype=None, kernel_dtype=None, output_preallocated=False,
    function=convolve1d, func_kwargs={}, max_duration=1,
):
    """Benchmark a full filter (e.g. gaussian_filter, sobel) against the
    vendored shared-memory implementation of the same function."""
    dtype = cp.dtype(dtype)
    image = _get_image(shape, dtype)
    kwargs = dict(mode=mode, cval=cval)
    if func_kwargs:
        kwargs.update(func_kwargs)
    if output_dtype is not None:
        output_dtype = cp.dtype(output_dtype)
    function_ref = getattr(ndi, function.__name__)
    if output_preallocated:
        if output_dtype is None:
            output_dtype = image.dtype
        output1 = cp.empty(image.shape, dtype=output_dtype)
        output2 = cp.empty(image.shape, dtype=output_dtype)
        kwargs.update(dict(output=output1))
        perf1 = benchmark(function_ref, (image,), kwargs=kwargs, n_warmup=10, n_repeat=10000, max_duration=max_duration)
        kwargs.update(dict(output=output2, algorithm='shared_memory'))
        perf2 = benchmark(function, (image,), kwargs=kwargs, n_warmup=10, n_repeat=10000, max_duration=max_duration)
        return perf1, perf2
    kwargs.update(dict(output=output_dtype))
    perf1 = benchmark(function_ref, (image,), kwargs=kwargs, n_warmup=10, n_repeat=10000, max_duration=max_duration)
    kwargs.update(dict(output=output_dtype, algorithm='shared_memory'))
    perf2 = benchmark(function, (image,), kwargs=kwargs, n_warmup=10, n_repeat=10000, max_duration=max_duration)
    return perf1, perf2


print("\n\n")
print("function | shape | dtype | mode | kernel size | preallocated | axis | dur (ms), CuPy | dur (ms), cuCIM | acceleration ")
print("---------|-------|-------|------|-------------|--------------|------|----------------|-----------------|--------------")
for function in [convolve1d]:
for shape in [(512, 512), (3840, 2160), (64, 64, 64), (256, 256, 256)]:
for dtype in [cp.float32, cp.uint8]:
for mode in ['nearest']:
for kernel_size in [3, 7, 11, 41]:
for output_preallocated in [False]: # , True]:
for axis in range(len(shape)):
output_dtype = dtype
perf1, perf2 = _compare_implementations(shape=shape, kernel_size=kernel_size, mode=mode, axis=axis, dtype=dtype, output_dtype=output_dtype, output_preallocated=output_preallocated, function=function)
t_elem = perf1.gpu_times * 1000.
t_shared = perf2.gpu_times * 1000.
print(f"{function.__name__} | {shape} | {cp.dtype(dtype).name} | {mode} | {kernel_size=} | prealloc={output_preallocated} | {axis=} | {t_elem.mean():0.3f} +/- {t_elem.std():0.3f} | {t_shared.mean():0.3f} +/- {t_shared.std():0.3f} | {t_elem.mean() / t_shared.mean():0.3f}")


print("function | kwargs | shape | dtype | mode | preallocated | dur (ms), CuPy | dur (ms), cuCIM | acceleration ")
print("---------|--------|-------|-------|------|--------------|----------------|-----------------|--------------")
for function, func_kwargs in [
# (gaussian_filter1d, dict(sigma=1.0, axis=0)),
# (gaussian_filter1d, dict(sigma=1.0, axis=-1)),
# (gaussian_filter1d, dict(sigma=4.0, axis=0)),
# (gaussian_filter1d, dict(sigma=4.0, axis=-1)),
(gaussian_filter, dict(sigma=1.0)),
(gaussian_filter, dict(sigma=4.0)),
(uniform_filter, dict(size=11)),
(prewitt, dict(axis=0)),
(sobel, dict(axis=0)),
(prewitt, dict(axis=-1)),
(sobel, dict(axis=-1)),
]:
for shape in [(512, 512), (3840, 2160), (64, 64, 64), (256, 256, 256)]:
for (dtype, output_dtype) in [(cp.float32, cp.float32), (cp.uint8, cp.float32)]:
for mode in ['nearest']:
for output_preallocated in [False, True]:
perf1, perf2 = _compare_implementations_other(shape=shape, mode=mode, dtype=dtype, output_dtype=output_dtype, output_preallocated=output_preallocated, function=function, func_kwargs=func_kwargs)
t_elem = perf1.gpu_times * 1000.
t_shared = perf2.gpu_times * 1000.
print(f"{function.__name__} | {func_kwargs} | {shape} | {cp.dtype(dtype).name} | {mode} | {output_preallocated} | {t_elem.mean():0.3f} +/- {t_elem.std():0.3f} | {t_shared.mean():0.3f} +/- {t_shared.std():0.3f} | {t_elem.mean() / t_shared.mean():0.3f}")

2 changes: 1 addition & 1 deletion python/cucim/setup.cfg
@@ -65,6 +65,6 @@ line_length = 80
known_first_party = cucim
default_section = THIRDPARTY
forced_separate = test_cucim
-skip = .tox,.eggs,ci/templates,build,dist,versioneer.py
+skip = .tox,.eggs,ci/templates,build,dist,versioneer.py,ndimage.py
sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
multi_line_output = GRID
3 changes: 2 additions & 1 deletion python/cucim/src/cucim/skimage/_shared/filters.py
@@ -7,7 +7,8 @@
from collections.abc import Iterable

import cupy as cp
-from cupyx.scipy import ndimage as ndi
+import cucim.skimage._vendored.ndimage as ndi


from .._shared import utils
from .._shared.utils import _supported_float_type, convert_to_float
12 changes: 12 additions & 0 deletions python/cucim/src/cucim/skimage/_vendored/_internal.py
@@ -1,3 +1,7 @@
import math
from functools import reduce
from operator import mul

import cupy
import numpy

@@ -61,3 +65,11 @@ def _normalize_axis_indices(axes, ndim): # NOQA
        res.append(axis)

    return tuple(sorted(res))


if hasattr(math, 'prod'):
    prod = math.prod
else:

    def prod(iterable, *, start=1):
        return reduce(mul, iterable, start)
