Skip to content

Commit

Permalink
Generalize median filter to N dims (#21)
Browse files Browse the repository at this point in the history
* generalize median filter to N dims

* run mix format

* add n-dim test
  • Loading branch information
santiago-imelio authored Aug 16, 2024
1 parent d0b7df4 commit 3088fd4
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 41 deletions.
42 changes: 21 additions & 21 deletions lib/nx_signal/filters.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,24 @@ defmodule NxSignal.Filters do
import Nx.Defn

@doc ~S"""
Performs a median filter on a rank 1 or rank 2 tensor.
Performs a median filter on a tensor.
## Options
* `:kernel_shape` - the shape of the sliding window.
It must be compatible with the shape of the tensor.
"""
@doc type: :filters
deftransform median(t = %Nx.Tensor{shape: {length}}, opts) do
defn median(t, opts) do
validate_median_opts!(t, opts)
{kernel_length} = opts[:kernel_shape]

median(Nx.reshape(t, {1, length}), kernel_shape: {1, kernel_length})
|> Nx.squeeze()
end

deftransform median(t = %Nx.Tensor{shape: {_h, _w}}, opts) do
validate_median_opts!(t, opts)
median_n(t, opts)
end

deftransform median(_t, _opts),
do: raise(ArgumentError, message: "tensor must be of rank 1 or 2")

defn median_n(t, opts) do
{k0, k1} = opts[:kernel_shape]

idx =
Nx.stack([Nx.iota(t.shape, axis: 0), Nx.iota(t.shape, axis: 1)], axis: -1)
|> Nx.reshape({:auto, 2})
t
|> idx_tensor()
|> Nx.vectorize(:elements)

t
|> Nx.slice([idx[0], idx[1]], [k0, k1])
|> Nx.slice(start_indices(t, idx), kernel_lengths(opts[:kernel_shape]))
|> Nx.median()
|> Nx.devectorize(keep_names: false)
|> Nx.reshape(t.shape)
Expand All @@ -52,4 +36,20 @@ defmodule NxSignal.Filters do
raise ArgumentError, message: "kernel shape must be of the same rank as the tensor"
end
end

deftransformp idx_tensor(t) do
t
|> Nx.axes()
|> Enum.map(&Nx.iota(t.shape, axis: &1))
|> Nx.stack(axis: -1)
|> Nx.reshape({:auto, length(Nx.axes(t))})
end

deftransformp start_indices(t, idx_tensor) do
t
|> Nx.axes()
|> Enum.map(&idx_tensor[&1])
end

deftransformp kernel_lengths(kernel_shape), do: Tuple.to_list(kernel_shape)
end
85 changes: 65 additions & 20 deletions test/nx_signal/filters_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,71 @@ defmodule NxSignal.FiltersTest do
assert NxSignal.Filters.median(t, opts) == expected
end

test "performs n-dim median filter" do
t =
Nx.tensor([
[
[31, 11, 17, 13, 1],
[1, 3, 19, 23, 29],
[19, 5, 7, 37, 2]
],
[
[19, 5, 7, 37, 2],
[1, 3, 19, 23, 29],
[31, 11, 17, 13, 1]
],
[
[1, 3, 19, 23, 29],
[31, 11, 17, 13, 1],
[19, 5, 7, 37, 2]
]
])

k1 = {3, 3, 1}
k2 = {3, 3, 3}

expected1 =
Nx.tensor([
[
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0]
],
[
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0]
],
[
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0]
]
])

expected2 =
Nx.tensor([
[
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0]
],
[
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0]
],
[
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0]
]
])

assert NxSignal.Filters.median(t, kernel_shape: k1) == expected1
assert NxSignal.Filters.median(t, kernel_shape: k2) == expected2
end

test "raises if kernel_shape is not compatible" do
t1 = Nx.iota({10})
opts1 = [kernel_shape: {5, 5}]
Expand All @@ -50,25 +115,5 @@ defmodule NxSignal.FiltersTest do
fn -> NxSignal.Filters.median(t2, opts2) end
)
end

test "raises if tensor rank is not 1 or 2" do
t1 = Nx.tensor(1)
opts1 = [kernel_shape: {1}]

assert_raise(
ArgumentError,
"tensor must be of rank 1 or 2",
fn -> NxSignal.Filters.median(t1, opts1) end
)

t2 = Nx.iota({5, 5, 5})
opts2 = [kernel_shape: {3, 3, 3}]

assert_raise(
ArgumentError,
"tensor must be of rank 1 or 2",
fn -> NxSignal.Filters.median(t2, opts2) end
)
end
end
end

0 comments on commit 3088fd4

Please sign in to comment.