From 75ec19459e64af2bdc33cf5a525f82f03fa52821 Mon Sep 17 00:00:00 2001
From: VirdhatchaniKN
Date: Wed, 5 Mar 2025 01:59:22 +0000
Subject: [PATCH] #18332: Update Batch Norm test file

---
 .../unit_tests/operations/test_batch_norm.py | 36 +++++++++----------
 .../compute/batch_norm_sfpu_kernel.cpp       | 14 ++++----
 2 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py
index 536de0790c1..94720cd2caa 100644
--- a/tests/ttnn/unit_tests/operations/test_batch_norm.py
+++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -17,11 +17,9 @@
 @pytest.mark.parametrize(
     "input_shapes",
     [
-        *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
-        *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
-        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
-        *(torch.Size([n, c, 1024, 1024]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
-        torch.Size([2, 2, 4096, 4096]),
+        *(torch.Size([n, c, 32, 32]) for n, c in product([4, 5], [7, 8])),
+        *(torch.Size([n, c, 23, 23]) for n, c in product([4, 5], [7, 8])),
+        *(torch.Size([n, c, 64, 120]) for n, c in product([2, 3], [1, 2])),
     ],
 )
 @pytest.mark.parametrize(
@@ -36,9 +34,9 @@
 @pytest.mark.parametrize("weight", [True, False])
 @pytest.mark.parametrize("training", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
-@pytest.mark.parametrize("momentum", [0.0, 0.1, 0.5])
-def test_batch_norm_training_fp32(
+@pytest.mark.parametrize("eps", [1.0, 0.0, 1e-05])
+@pytest.mark.parametrize("momentum", [0.0, 0.1])
+def test_batch_norm_tests_fp32(
     input_shapes, check_mean, check_var, weight, bias, eps, device, momentum, training, testing_dtype="float32"
 ):
     in_data, input_tensor = data_gen_with_range_batch_norm(
@@ -123,8 +121,8 @@ def test_batch_norm_training_fp32(
 
 
 @skip_for_grayskull("Unsupported dtype for Grayskull")
-@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
-@pytest.mark.parametrize("channel_size", [1, 2, 3, 4])
+@pytest.mark.parametrize("eps", [1.0, 1e-05])
+@pytest.mark.parametrize("channel_size", [1, 4])
 @pytest.mark.parametrize("weight", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
 def test_BN_fp32_full_value(device, channel_size, eps, weight, bias):
@@ -170,9 +168,9 @@ def test_BN_fp32_full_value(device, channel_size, eps, weight, bias):
 @pytest.mark.parametrize(
     "input_shapes",
     [
-        *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
-        *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
-        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3], [1, 2, 3, 4])),
+        *(torch.Size([n, c, 32, 32]) for n, c in product([3, 4], [3, 4])),
+        *(torch.Size([n, c, 23, 23]) for n, c in product([3, 4], [3, 4])),
+        *(torch.Size([n, c, 64, 120]) for n, c in product([2, 3], [3, 4])),
     ],
 )
 @pytest.mark.parametrize(
@@ -186,7 +184,7 @@ def test_BN_fp32_full_value(device, channel_size, eps, weight, bias):
 )
 @pytest.mark.parametrize("weight", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
+@pytest.mark.parametrize("eps", [1.0, 0.0, 1e-05])
 def test_batch_norm_fp32(
     input_shapes, check_mean, check_var, weight, bias, eps, device, training=False, testing_dtype="float32"
 ):
@@ -245,9 +243,9 @@ def test_batch_norm_fp32(
 @pytest.mark.parametrize(
     "input_shapes",
     [
-        *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
-        *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
-        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3], [1, 2, 3, 4])),
+        *(torch.Size([n, c, 32, 32]) for n, c in product([3, 4], [3, 4])),
+        *(torch.Size([n, c, 23, 23]) for n, c in product([3, 4], [3, 4])),
+        *(torch.Size([n, c, 64, 120]) for n, c in product([2, 3], [3, 4])),
     ],
 )
 @pytest.mark.parametrize(
@@ -265,8 +263,8 @@ def test_batch_norm_fp32(
 )
 @pytest.mark.parametrize("weight", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
-@pytest.mark.parametrize("momentum", [0.0, 0.1, 0.5])
+@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34])
+@pytest.mark.parametrize("momentum", [0.0, 0.5])
 def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
     in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
     mean_data, mean_tensor = (
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp
index dd28d7f9734..2991d98aed2 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp
@@ -42,13 +42,15 @@ ALWI void batchnorm_bcast_tiles(
     for (uint32_t i = 0; i < onetile; ++i) {
         copy_tile(cb_batch_var, i, i * 2);
     }
+    add_binary_tile_init();
     copy_tile_to_dst_init_short_with_dt(cb_batch_var, cb_eps);
     for (uint32_t i = 0; i < onetile; ++i) {
         copy_tile(cb_eps, i, i * 2 + 1);
 
-        add_binary_tile_init();
         add_binary_tile(i * 2, i * 2 + 1);
-        rsqrt_tile_init();
+    }
+    rsqrt_tile_init();
+    for (uint32_t i = 0; i < onetile; ++i) {
         rsqrt_tile(i * 2);
 
         pack_tile(i * 2, cb_den);
@@ -75,20 +77,20 @@ ALWI void batchnorm_bcast_tiles(
     for (uint32_t i = 0; i < onetile; ++i) {
         copy_tile(cb_other, i, i * 2);
     }
+    sub_binary_tile_init();
     copy_tile_to_dst_init_short_with_dt(cb_other, cb_bcast);
     for (uint32_t i = 0; i < onetile; ++i) {
         copy_tile(cb_bcast, i, i * 2 + 1);
-        sub_binary_tile_init();
         sub_binary_tile(i * 2, i * 2 + 1);
     }
     cb_pop_front(cb_other, onetile);
 
     //(input - batch_mean)/(sqrt(batch_var + eps))
     cb_reserve_back(cb_affine_or_out, onetile);
+    mul_binary_tile_init();
     copy_tile_to_dst_init_short_with_dt(cb_bcast, cb_den);
     for (uint32_t i = 0; i < onetile; ++i) {
         copy_tile(cb_den, i, i * 2 + 1);
-        mul_binary_tile_init();
         mul_binary_tile(i * 2, i * 2 + 1);
 
         pack_tile(i * 2, cb_affine_or_out);
@@ -106,10 +108,10 @@ ALWI void batchnorm_bcast_tiles(
     for (uint32_t i = 0; i < onetile; ++i) {
         copy_tile(cb_affine_or_out, i, i * 2);
     }
+    mul_binary_tile_init();
     copy_tile_to_dst_init_short_with_dt(cb_affine_or_out, cb_weight);
     for (uint32_t i = 0; i < onetile; ++i) {
         copy_tile(cb_weight, i, i * 2 + 1);
-        mul_binary_tile_init();
         mul_binary_tile(i * 2, i * 2 + 1);
 
         pack_tile(i * 2, cb_scaled_output);
@@ -129,10 +131,10 @@ ALWI void batchnorm_bcast_tiles(
     for (uint32_t i = 0; i < onetile; ++i) {
         copy_tile(cb_tmp_1, i, i * 2);
     }
+    add_binary_tile_init();
     copy_tile_to_dst_init_short_with_dt(cb_tmp_1, cb_bias);
     for (uint32_t i = 0; i < onetile; ++i) {
         copy_tile(cb_bias, i, i * 2 + 1);
-        add_binary_tile_init();
         add_binary_tile(i * 2, i * 2 + 1);
 
         pack_tile(i * 2, cb_output_0);
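
Note: a minimal PyTorch sketch (not part of the patch) of the computation the kernel hunks above implement, per the in-kernel comment "(input - batch_mean)/(sqrt(batch_var + eps))" followed by the optional weight multiply (cb_weight) and bias add (cb_bias); the function and tensor names below are illustrative only, not part of the ttnn API.

    import torch

    def batch_norm_reference(x, batch_mean, batch_var, eps, weight=None, bias=None):
        # (input - batch_mean) / sqrt(batch_var + eps), as in the kernel comment
        out = (x - batch_mean) / torch.sqrt(batch_var + eps)
        if weight is not None:
            out = out * weight  # corresponds to the cb_weight multiply in the kernel
        if bias is not None:
            out = out + bias  # corresponds to the cb_bias add in the kernel
        return out

With x of shape [N, C, H, W] and the statistics, weight, and bias broadcastable as [1, C, 1, 1], this mirrors the per-channel normalization the tests exercise.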