Skip to content

Commit

Permalink
add n-dim test
Browse files Browse the repository at this point in the history
  • Loading branch information
santiago-imelio committed May 26, 2024
1 parent 536717c commit 4c74ff3
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions test/nx_signal/filters_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,71 @@ defmodule NxSignal.FiltersTest do
assert NxSignal.Filters.median(t, opts) == expected
end

test "performs n-dim median filter" do
t =
Nx.tensor([
[
[31, 11, 17, 13, 1],
[1, 3, 19, 23, 29],
[19, 5, 7, 37, 2]
],
[
[19, 5, 7, 37, 2],
[1, 3, 19, 23, 29],
[31, 11, 17, 13, 1]
],
[
[1, 3, 19, 23, 29],
[31, 11, 17, 13, 1],
[19, 5, 7, 37, 2]
]
])

k1 = {3, 3, 1}
k2 = {3, 3, 3}

expected1 =
Nx.tensor([
[
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0]
],
[
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0]
],
[
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0],
[19.0, 5.0, 17.0, 23.0, 2.0]
]
])

expected2 =
Nx.tensor([
[
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0]
],
[
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0]
],
[
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0]
]
])

assert NxSignal.Filters.median(t, kernel_shape: k1) == expected1
assert NxSignal.Filters.median(t, kernel_shape: k2) == expected2
end

test "raises if kernel_shape is not compatible" do
t1 = Nx.iota({10})
opts1 = [kernel_shape: {5, 5}]
Expand Down

0 comments on commit 4c74ff3

Please sign in to comment.