Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Gigon Bae <gigony@gmail.com>
  • Loading branch information
grlee77 and gigony authored Feb 26, 2025
1 parent c826ea4 commit 0b76a21
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions python/cucim/src/cucim/skimage/_shared/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def pdist_max_blockwise(
requirement. The memory used at runtime will be proportional to
``coords_per_block**2``.
A block size of >= 2000 is recommended to overhead poor GPU resource usage
A block size of >= 2000 is recommended to avoid poor GPU resource usage
and to reduce kernel launch overhead.
Parameters
----------
coords : np.ndarray (num_points, ndim)
coords : numpy.ndarray or cupy.ndarray of shape (num_points, ndim)
The coordinates to process.
metric : str, optional
Can be any metric supported by `scipy.spatial.distance.cdist`. The
Expand All @@ -54,7 +54,7 @@ def pdist_max_blockwise(
Internally, calls to cdist will be made with subsets of coords where
the subset size is (coords_per_block, ndim).
compute_argmax : bool, optional
If True, the value of the coordate indices corresponding to the maxima
If True, the value of the cooridate indices corresponding to the maxima
is returned as the second return Value. Otherwise that value will be
``None``.
cdist_kwargs = dict, optional
Expand Down Expand Up @@ -97,7 +97,7 @@ def pdist_max_blockwise(
if _distance_on_cpu:
warnings.warn(
"cuVS >= 25.02 or pylibraft < 24.12 must be installed to use "
"GPU-accelerated pairwaise distance computations. Falling back "
"GPU-accelerated pairwise distance computations. Falling back "
"to SciPy-based CPU implementation."
)
xp = np
Expand All @@ -106,6 +106,14 @@ def pdist_max_blockwise(
xp = cp
coords = cp.asarray(coords)

if not isinstance(coords, (np.ndarray, cp.ndarray)):
raise TypeError("coords must be a numpy or cupy array")

if coords.ndim != 2:
raise ValueError(
f"coords must be a 2-dimensional array, got shape {coords.shape}"
)

num_coords, _ = coords.shape
if num_coords == 0:
raise RuntimeError("No coordinates to process")
Expand Down

0 comments on commit 0b76a21

Please sign in to comment.