From adfebf2a0f91fa8ca537c019a4da936e1db3c21a Mon Sep 17 00:00:00 2001 From: VirdhatchaniKN Date: Wed, 5 Mar 2025 09:38:32 +0000 Subject: [PATCH] #18332: Remove cb_num buffer --- tests/ttnn/unit_tests/operations/test_batch_norm.py | 8 +++----- .../batch_norm/device/batch_norm_program_factory.cpp | 11 +---------- .../device/kernels/compute/batch_norm_kernel.cpp | 10 +++------- .../device/kernels/compute/batch_norm_sfpu_kernel.cpp | 6 +++--- 4 files changed, 10 insertions(+), 25 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py index aa8aa1e08e4..1f4d4d60452 100644 --- a/tests/ttnn/unit_tests/operations/test_batch_norm.py +++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py @@ -243,11 +243,9 @@ def test_batch_norm_fp32( @pytest.mark.parametrize( "input_shapes", [ - *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])), - *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])), - *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])), - *(torch.Size([n, c, 1024, 1024]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])), - torch.Size([3, 6, 4096, 4096]), + *(torch.Size([n, c, 32, 32]) for n, c in product([4, 5], [7, 8])), + *(torch.Size([n, c, 23, 23]) for n, c in product([4, 5], [7, 8])), + *(torch.Size([n, c, 64, 120]) for n, c in product([2, 3], [1, 2])), ], ) @pytest.mark.parametrize( diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp index 96ddddc62d1..84e8a6faac0 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp @@ -204,15 +204,8 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch a_single_tile_size, num_tiles_per_cb, a_data_format); // to store 1/(sqrt(batch_var + eps)) - auto [num_cb, num_cb_handle] = create_cb( - tt::CBIndex::c_8, - program, - all_device_cores, - a_single_tile_size, - num_tiles_per_cb, - a_data_format); // to store input - batch_mean auto [temp_1_cb, temp_1_cb_handle] = - create_cb(tt::CBIndex::c_9, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format); + create_cb(tt::CBIndex::c_8, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format); auto a_is_dram = static_cast(input_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM); auto b_is_dram = static_cast(batch_mean_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM); @@ -273,7 +266,6 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch batch_var_tensor_cb, eps_cb, den_cb, - num_cb, weight_tensor_cb, temp_1_cb, bias_tensor_cb}) { @@ -290,7 +282,6 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch batch_var_tensor_cb, eps_cb, den_cb, - num_cb, weight_tensor_cb, temp_1_cb, bias_tensor_cb}; diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp index 9670212dbe2..c22c34ab2ae 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp @@ -17,7 +17,6 @@ ALWI void batchnorm_bcast_tiles( uint32_t cb_batch_var, uint32_t cb_eps, uint32_t cb_den, - uint32_t cb_num, uint32_t cb_weight, uint32_t cb_bias, uint32_t cb_tmp_1, @@ -142,10 +141,9 @@ void MAIN { constexpr auto cb_batch_var = get_compile_time_arg_val(5); // batch_var constexpr auto cb_eps = get_compile_time_arg_val(6); // eps constexpr auto cb_den = get_compile_time_arg_val(7); // 1/(sqrt(batch_var + eps)) - constexpr auto cb_num = get_compile_time_arg_val(8); // input - batch_mean - constexpr auto cb_weight = get_compile_time_arg_val(9); // weight tensor - constexpr auto cb_tmp_1 = get_compile_time_arg_val(10); // (input - batch_mean)/(sqrt(batch_var + eps)) - constexpr auto cb_bias = get_compile_time_arg_val(11); // bias tensor + constexpr auto cb_weight = get_compile_time_arg_val(8); // weight tensor + constexpr auto cb_tmp_1 = get_compile_time_arg_val(9); // (input - batch_mean)/(sqrt(batch_var + eps)) + constexpr auto cb_bias = get_compile_time_arg_val(10); // bias tensor auto cb_bcast = cb_batch_mean; auto cb_other = cb_input; @@ -167,7 +165,6 @@ void MAIN { cb_batch_var, cb_eps, cb_den, - cb_num, cb_weight, cb_bias, cb_tmp_1, @@ -184,7 +181,6 @@ void MAIN { cb_batch_var, cb_eps, cb_den, - cb_num, cb_weight, cb_bias, cb_tmp_1, diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp index 2991d98aed2..525fa1a7b0d 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp @@ -173,9 +173,9 @@ void MAIN { constexpr auto cb_batch_var = get_compile_time_arg_val(5); // batch_var constexpr auto cb_eps = get_compile_time_arg_val(6); // eps constexpr auto cb_den = get_compile_time_arg_val(7); // 1/(sqrt(batch_var + eps)) - constexpr auto cb_weight = get_compile_time_arg_val(9); // weight tensor - constexpr auto cb_tmp_1 = get_compile_time_arg_val(10); // (input - batch_mean)/(sqrt(batch_var + eps)) - constexpr auto cb_bias = get_compile_time_arg_val(11); // bias tensor + constexpr auto cb_weight = get_compile_time_arg_val(8); // weight tensor + constexpr auto cb_tmp_1 = get_compile_time_arg_val(9); // (input - batch_mean)/(sqrt(batch_var + eps)) + constexpr auto cb_bias = get_compile_time_arg_val(10); // bias tensor auto cb_bcast = cb_batch_mean; auto cb_other = cb_input;