Skip to content

Commit

Permalink
#18332: Update Batch Norm test file
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Mar 5, 2025
1 parent a387234 commit 75ec194
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
36 changes: 17 additions & 19 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
@pytest.mark.parametrize(
"input_shapes",
[
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
*(torch.Size([n, c, 1024, 1024]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
torch.Size([2, 2, 4096, 4096]),
*(torch.Size([n, c, 32, 32]) for n, c in product([4, 5], [7, 8])),
*(torch.Size([n, c, 23, 23]) for n, c in product([4, 5], [7, 8])),
*(torch.Size([n, c, 64, 120]) for n, c in product([2, 3], [1, 2])),
],
)
@pytest.mark.parametrize(
Expand All @@ -36,9 +34,9 @@
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("training", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("momentum", [0.0, 0.1, 0.5])
def test_batch_norm_training_fp32(
@pytest.mark.parametrize("eps", [1.0, 0.0, 1e-05])
@pytest.mark.parametrize("momentum", [0.0, 0.1])
def test_batch_norm_tests_fp32(
input_shapes, check_mean, check_var, weight, bias, eps, device, momentum, training, testing_dtype="float32"
):
in_data, input_tensor = data_gen_with_range_batch_norm(
Expand Down Expand Up @@ -123,8 +121,8 @@ def test_batch_norm_training_fp32(


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("channel_size", [1, 2, 3, 4])
@pytest.mark.parametrize("eps", [1.0, 1e-05])
@pytest.mark.parametrize("channel_size", [1, 4])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_BN_fp32_full_value(device, channel_size, eps, weight, bias):
Expand Down Expand Up @@ -170,9 +168,9 @@ def test_BN_fp32_full_value(device, channel_size, eps, weight, bias):
@pytest.mark.parametrize(
"input_shapes",
[
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3], [1, 2, 3, 4])),
*(torch.Size([n, c, 32, 32]) for n, c in product([3, 4], [3, 4])),
*(torch.Size([n, c, 23, 23]) for n, c in product([3, 4], [3, 4])),
*(torch.Size([n, c, 64, 120]) for n, c in product([2, 3], [3, 4])),
],
)
@pytest.mark.parametrize(
Expand All @@ -186,7 +184,7 @@ def test_BN_fp32_full_value(device, channel_size, eps, weight, bias):
)
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("eps", [1.0, 0.0, 1e-05])
def test_batch_norm_fp32(
input_shapes, check_mean, check_var, weight, bias, eps, device, training=False, testing_dtype="float32"
):
Expand Down Expand Up @@ -245,9 +243,9 @@ def test_batch_norm_fp32(
@pytest.mark.parametrize(
"input_shapes",
[
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3], [1, 2, 3, 4])),
*(torch.Size([n, c, 32, 32]) for n, c in product([3, 4], [3, 4])),
*(torch.Size([n, c, 23, 23]) for n, c in product([3, 4], [3, 4])),
*(torch.Size([n, c, 64, 120]) for n, c in product([2, 3], [3, 4])),
],
)
@pytest.mark.parametrize(
Expand All @@ -265,8 +263,8 @@ def test_batch_norm_fp32(
)
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("momentum", [0.0, 0.1, 0.5])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34])
@pytest.mark.parametrize("momentum", [0.0, 0.5])
def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ ALWI void batchnorm_bcast_tiles(
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_batch_var, i, i * 2);
}
add_binary_tile_init();
copy_tile_to_dst_init_short_with_dt(cb_batch_var, cb_eps);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_eps, i, i * 2 + 1);

add_binary_tile_init();
add_binary_tile(i * 2, i * 2 + 1);
rsqrt_tile_init();
}
rsqrt_tile_init();
for (uint32_t i = 0; i < onetile; ++i) {
rsqrt_tile(i * 2);

pack_tile(i * 2, cb_den);
Expand All @@ -75,20 +77,20 @@ ALWI void batchnorm_bcast_tiles(
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_other, i, i * 2);
}
sub_binary_tile_init();
copy_tile_to_dst_init_short_with_dt(cb_other, cb_bcast);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_bcast, i, i * 2 + 1);
sub_binary_tile_init();
sub_binary_tile(i * 2, i * 2 + 1);
}
cb_pop_front(cb_other, onetile);

//(input - batch_mean)/(sqrt(batch_var + eps))
cb_reserve_back(cb_affine_or_out, onetile);
mul_binary_tile_init();
copy_tile_to_dst_init_short_with_dt(cb_bcast, cb_den);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_den, i, i * 2 + 1);
mul_binary_tile_init();
mul_binary_tile(i * 2, i * 2 + 1);

pack_tile(i * 2, cb_affine_or_out);
Expand All @@ -106,10 +108,10 @@ ALWI void batchnorm_bcast_tiles(
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_affine_or_out, i, i * 2);
}
mul_binary_tile_init();
copy_tile_to_dst_init_short_with_dt(cb_affine_or_out, cb_weight);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_weight, i, i * 2 + 1);
mul_binary_tile_init();
mul_binary_tile(i * 2, i * 2 + 1);

pack_tile(i * 2, cb_scaled_output);
Expand All @@ -129,10 +131,10 @@ ALWI void batchnorm_bcast_tiles(
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_tmp_1, i, i * 2);
}
add_binary_tile_init();
copy_tile_to_dst_init_short_with_dt(cb_tmp_1, cb_bias);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_bias, i, i * 2 + 1);
add_binary_tile_init();
add_binary_tile(i * 2, i * 2 + 1);

pack_tile(i * 2, cb_output_0);
Expand Down

0 comments on commit 75ec194

Please sign in to comment.