diff --git a/test/nx_signal/filters_test.exs b/test/nx_signal/filters_test.exs index 8c53867..a439305 100644 --- a/test/nx_signal/filters_test.exs +++ b/test/nx_signal/filters_test.exs @@ -31,6 +31,71 @@ defmodule NxSignal.FiltersTest do assert NxSignal.Filters.median(t, opts) == expected end + test "performs n-dim median filter" do + t = + Nx.tensor([ + [ + [31, 11, 17, 13, 1], + [1, 3, 19, 23, 29], + [19, 5, 7, 37, 2] + ], + [ + [19, 5, 7, 37, 2], + [1, 3, 19, 23, 29], + [31, 11, 17, 13, 1] + ], + [ + [1, 3, 19, 23, 29], + [31, 11, 17, 13, 1], + [19, 5, 7, 37, 2] + ] + ]) + + k1 = {3, 3, 1} + k2 = {3, 3, 3} + + expected1 = + Nx.tensor([ + [ + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0] + ], + [ + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0] + ], + [ + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0], + [19.0, 5.0, 17.0, 23.0, 2.0] + ] + ]) + + expected2 = + Nx.tensor([ + [ + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0] + ], + [ + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0] + ], + [ + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0], + [11.0, 13.0, 17.0, 17.0, 17.0] + ] + ]) + + assert NxSignal.Filters.median(t, kernel_shape: k1) == expected1 + assert NxSignal.Filters.median(t, kernel_shape: k2) == expected2 + end + test "raises if kernel_shape is not compatible" do t1 = Nx.iota({10}) opts1 = [kernel_shape: {5, 5}]