diff --git a/lib/nx_signal/filters.ex b/lib/nx_signal/filters.ex index 23a435e..79591fe 100644 --- a/lib/nx_signal/filters.ex +++ b/lib/nx_signal/filters.ex @@ -5,7 +5,7 @@ defmodule NxSignal.Filters do import Nx.Defn @doc ~S""" - Performs a median filter on a rank 1 or rank 2 tensor. + Performs a median filter on a tensor. ## Options @@ -13,32 +13,16 @@ defmodule NxSignal.Filters do It must be compatible with the shape of the tensor. """ @doc type: :filters - deftransform median(t = %Nx.Tensor{shape: {length}}, opts) do + defn median(t, opts) do validate_median_opts!(t, opts) - {kernel_length} = opts[:kernel_shape] - - median(Nx.reshape(t, {1, length}), kernel_shape: {1, kernel_length}) - |> Nx.squeeze() - end - - deftransform median(t = %Nx.Tensor{shape: {_h, _w}}, opts) do - validate_median_opts!(t, opts) - median_n(t, opts) - end - - deftransform median(_t, _opts), - do: raise(ArgumentError, message: "tensor must be of rank 1 or 2") - - defn median_n(t, opts) do - {k0, k1} = opts[:kernel_shape] idx = - Nx.stack([Nx.iota(t.shape, axis: 0), Nx.iota(t.shape, axis: 1)], axis: -1) - |> Nx.reshape({:auto, 2}) + t + |> idx_tensor() |> Nx.vectorize(:elements) t - |> Nx.slice([idx[0], idx[1]], [k0, k1]) + |> Nx.slice(start_indices(t, idx), kernel_lengths(opts[:kernel_shape])) |> Nx.median() |> Nx.devectorize(keep_names: false) |> Nx.reshape(t.shape) @@ -52,4 +36,20 @@ defmodule NxSignal.Filters do raise ArgumentError, message: "kernel shape must be of the same rank as the tensor" end end + + deftransformp idx_tensor(t) do + t + |> Nx.axes() + |> Enum.map(&Nx.iota(t.shape, axis: &1)) + |> Nx.stack(axis: -1) + |> Nx.reshape({:auto, length(Nx.axes(t))}) + end + + deftransformp start_indices(t, idx_tensor) do + t + |> Nx.axes() + |> Enum.map(&idx_tensor[&1]) + end + + deftransformp kernel_lengths(kernel_shape), do: Tuple.to_list(kernel_shape) end diff --git a/test/nx_signal/filters_test.exs b/test/nx_signal/filters_test.exs index 122f83f..a439305 100644 --- a/test/nx_signal/filters_test.exs +++ b/test/nx_signal/filters_test.exs @@ -31,6 +31,71 @@ defmodule NxSignal.FiltersTest do assert NxSignal.Filters.median(t, opts) == expected end + test "performs n-dim median filter" do + t = + Nx.tensor([ + [ + [31, 11, 17, 13, 1], + [1, 3, 19, 23, 29], + [19, 5, 7, 37, 2] + ], + [ + [19, 5, 7, 37, 2], + [1, 3, 19, 23, 29], + [31, 11, 17, 13, 1] + ], + [ + [1, 3, 19, 23, 29], + [31, 11, 17, 13, 1], + [19, 5, 7, 37, 2] + ] + ]) + + k1 = {3, 3, 1} + k2 = {3, 3, 3} + + expected1 = + Nx.tensor([ + [ + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0] + ], + [ + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0] + ], + [ + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0] + ] + ]) + + expected2 = + Nx.tensor([ + [ + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0] + ], + [ + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0] + ], + [ + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0] + ] + ]) + + assert NxSignal.Filters.median(t, kernel_shape: k1) == expected1 + assert NxSignal.Filters.median(t, kernel_shape: k2) == expected2 + end + test "raises if kernel_shape is not compatible" do t1 = Nx.iota({10}) opts1 = [kernel_shape: {5, 5}] @@ -50,25 +115,5 @@ defmodule NxSignal.FiltersTest do fn -> NxSignal.Filters.median(t2, opts2) end ) end - - test "raises if tensor rank is not 1 or 2" do - t1 = Nx.tensor(1) - opts1 = [kernel_shape: {1}] - - assert_raise( - ArgumentError, - "tensor must be of rank 1 or 2", - fn -> NxSignal.Filters.median(t1, opts1) end - ) - - t2 = Nx.iota({5, 5, 5}) - opts2 = [kernel_shape: {3, 3, 3}] - - assert_raise( - ArgumentError, - "tensor must be of rank 1 or 2", - fn -> NxSignal.Filters.median(t2, opts2) end - ) - end end end