Skip to content

Commit

Permalink
#18332: Remove cb_num buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Mar 5, 2025
1 parent b4a8570 commit adfebf2
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 25 deletions.
8 changes: 3 additions & 5 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,9 @@ def test_batch_norm_fp32(
@pytest.mark.parametrize(
"input_shapes",
[
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
*(torch.Size([n, c, 1024, 1024]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
torch.Size([3, 6, 4096, 4096]),
*(torch.Size([n, c, 32, 32]) for n, c in product([4, 5], [7, 8])),
*(torch.Size([n, c, 23, 23]) for n, c in product([4, 5], [7, 8])),
*(torch.Size([n, c, 64, 120]) for n, c in product([2, 3], [1, 2])),
],
)
@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,8 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
a_single_tile_size,
num_tiles_per_cb,
a_data_format); // to store 1/(sqrt(batch_var + eps))
auto [num_cb, num_cb_handle] = create_cb(
tt::CBIndex::c_8,
program,
all_device_cores,
a_single_tile_size,
num_tiles_per_cb,
a_data_format); // to store input - batch_mean
auto [temp_1_cb, temp_1_cb_handle] =
create_cb(tt::CBIndex::c_9, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
create_cb(tt::CBIndex::c_8, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);

auto a_is_dram = static_cast<uint32_t>(input_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto b_is_dram = static_cast<uint32_t>(batch_mean_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
Expand Down Expand Up @@ -273,7 +266,6 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
batch_var_tensor_cb,
eps_cb,
den_cb,
num_cb,
weight_tensor_cb,
temp_1_cb,
bias_tensor_cb}) {
Expand All @@ -290,7 +282,6 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
batch_var_tensor_cb,
eps_cb,
den_cb,
num_cb,
weight_tensor_cb,
temp_1_cb,
bias_tensor_cb};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ ALWI void batchnorm_bcast_tiles(
uint32_t cb_batch_var,
uint32_t cb_eps,
uint32_t cb_den,
uint32_t cb_num,
uint32_t cb_weight,
uint32_t cb_bias,
uint32_t cb_tmp_1,
Expand Down Expand Up @@ -142,10 +141,9 @@ void MAIN {
constexpr auto cb_batch_var = get_compile_time_arg_val(5); // batch_var
constexpr auto cb_eps = get_compile_time_arg_val(6); // eps
constexpr auto cb_den = get_compile_time_arg_val(7); // 1/(sqrt(batch_var + eps))
constexpr auto cb_num = get_compile_time_arg_val(8); // input - batch_mean
constexpr auto cb_weight = get_compile_time_arg_val(9); // weight tensor
constexpr auto cb_tmp_1 = get_compile_time_arg_val(10); // (input - batch_mean)/(sqrt(batch_var + eps))
constexpr auto cb_bias = get_compile_time_arg_val(11); // bias tensor
constexpr auto cb_weight = get_compile_time_arg_val(8); // weight tensor
constexpr auto cb_tmp_1 = get_compile_time_arg_val(9); // (input - batch_mean)/(sqrt(batch_var + eps))
constexpr auto cb_bias = get_compile_time_arg_val(10); // bias tensor

auto cb_bcast = cb_batch_mean;
auto cb_other = cb_input;
Expand All @@ -167,7 +165,6 @@ void MAIN {
cb_batch_var,
cb_eps,
cb_den,
cb_num,
cb_weight,
cb_bias,
cb_tmp_1,
Expand All @@ -184,7 +181,6 @@ void MAIN {
cb_batch_var,
cb_eps,
cb_den,
cb_num,
cb_weight,
cb_bias,
cb_tmp_1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ void MAIN {
constexpr auto cb_batch_var = get_compile_time_arg_val(5); // batch_var
constexpr auto cb_eps = get_compile_time_arg_val(6); // eps
constexpr auto cb_den = get_compile_time_arg_val(7); // 1/(sqrt(batch_var + eps))
constexpr auto cb_weight = get_compile_time_arg_val(9); // weight tensor
constexpr auto cb_tmp_1 = get_compile_time_arg_val(10); // (input - batch_mean)/(sqrt(batch_var + eps))
constexpr auto cb_bias = get_compile_time_arg_val(11); // bias tensor
constexpr auto cb_weight = get_compile_time_arg_val(8); // weight tensor
constexpr auto cb_tmp_1 = get_compile_time_arg_val(9); // (input - batch_mean)/(sqrt(batch_var + eps))
constexpr auto cb_bias = get_compile_time_arg_val(10); // bias tensor

auto cb_bcast = cb_batch_mean;
auto cb_other = cb_input;
Expand Down

0 comments on commit adfebf2

Please sign in to comment.