diff --git a/tests/ttnn/unit_tests/operations/test_layer_norm.py b/tests/ttnn/unit_tests/operations/test_layer_norm.py index b803e8da4f8..026d2e104cc 100644 --- a/tests/ttnn/unit_tests/operations/test_layer_norm.py +++ b/tests/ttnn/unit_tests/operations/test_layer_norm.py @@ -118,3 +118,45 @@ def test_layer_norm_with_tile_layout(device, h, w): output_tensor = ttnn.to_torch(output_tensor) assert_with_pcc(torch_output_tensor, output_tensor, 0.9998) + + +@pytest.mark.parametrize("h", [32]) +@pytest.mark.parametrize("w", [8192]) +def test_large_layer_norm(device, h, w): + torch.manual_seed(0) + + torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16) + torch_output_tensor = torch.nn.functional.layer_norm(torch_input_tensor, normalized_shape=[w]) + + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.layer_norm(input_tensor) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = ttnn.from_device(output_tensor) + output_tensor = ttnn.to_torch(output_tensor) + + assert_with_pcc(torch_output_tensor, output_tensor, 0.9998) + + +@pytest.mark.parametrize("h", [32]) +@pytest.mark.parametrize("w", [8192]) +def test_large_layer_norm_with_weight_and_bias(device, h, w): + torch.manual_seed(0) + + torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16) + torch_weight = torch.rand((w,), dtype=torch.bfloat16) + torch_bias = torch.rand((w,), dtype=torch.bfloat16) + + torch_output_tensor = torch.nn.functional.layer_norm( + torch_input_tensor, normalized_shape=[w], weight=torch_weight, bias=torch_bias + ) + + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device) + bias = ttnn.from_torch(torch_bias, layout=ttnn.TILE_LAYOUT, device=device) + + output_tensor = ttnn.layer_norm(input_tensor, weight=weight, bias=bias) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = ttnn.from_device(output_tensor) + output_tensor = ttnn.to_torch(output_tensor) + + assert_with_pcc(torch_output_tensor, output_tensor, 0.999) diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_three_pass.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_three_pass.cpp index fe76c18e4b4..50e5e6a6c1d 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_three_pass.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_three_pass.cpp @@ -88,12 +88,11 @@ void MAIN { for (uint32_t ncht = 0; ncht < NCHt; ncht++) { constexpr int onetile = 1; constexpr int dst0 = 0; - /* - * X + Y - */ - // E[x] - // aka ∑(x-E[x]) - // cb_x in0 + in1 or in0 + // Start of + // E[x] + // aka ∑(x) + // -------- + // n tile_regs_acquire(); tile_regs_wait(); @@ -125,10 +124,18 @@ void MAIN { tile_regs_release(); cb_push_back(cb_ex, onetile); cb_wait_front(cb_ex, onetile); + // End of + // E[x] + // aka ∑(x) + // -------- + // n cb_wait_front(cb_ex, onetile); + // Start of // Var Calculation - // aka ∑(x-E[x])^2 + // Var(X) = ∑(x-E[x])^2 + // ----------- + // n for (uint32_t wt = 0; wt < Wt; wt += blk) { reconfig_data_format(cb_in, cb_ex); pack_reconfig_data_format(cb_ex2); @@ -151,7 +158,6 @@ void MAIN { } cb_pop_front(cb_inb, blk); #endif - //(x-E[x])^2 square_tile_init(); for (uint32_t j = 0; j < blk; j++) { square_tile(j); @@ -168,8 +174,6 @@ void MAIN { tile_regs_wait(); if (wt > 0) { cb_wait_front(cb_ex2, onetile); - UNPACK(DPRINT << "PREV" << ENDL()); - UNPACK(print_full_tile(cb_ex2, 0)); copy_tile_init(cb_ex2); copy_tile(cb_ex2, 0, dst0); cb_pop_front(cb_ex2, onetile); @@ -182,15 +186,25 @@ void MAIN { } cb_pop_front(cb_xmm, blk); cb_reserve_back(cb_ex2, onetile); + reduce_revert_delta(cb_ex2); tile_regs_commit(); pack_tile(dst0, cb_ex2); - reduce_revert_delta(cb_ex2); cb_push_back(cb_ex2, onetile); tile_regs_release(); } tile_regs_acquire(); tile_regs_wait(); + // End of + // Var Calculation + // Var(X) = ∑(x-E[x])^2 + // ----------- + + // Start of + // Calculation + // 1 + // cb_ex2pe = ------------- + // √(Var(X) + ε) cb_wait_front(cb_ex2, onetile); reconfig_data_format(cb_ex2, cb_eps); @@ -214,7 +228,7 @@ void MAIN { cb_pop_front(cb_ex2, onetile); cb_wait_front(cb_ex2pe, onetile); - // broadcasts the tile since only the column contains the important data + // broadcasts the tile since cb_ex2pe is a column vector that contains the important data tile_regs_acquire(); tile_regs_wait(); reconfig_data_format_srca(cb_ex2pe); @@ -226,19 +240,21 @@ void MAIN { pack_tile(dst0, cb_ex2pe); tile_regs_release(); cb_push_back(cb_ex2pe, onetile); + // End of + // Calculation + // 1 + // cb_ex2pe = ------------- + // √(Var(X) + ε) - // Final Val Calc - - // DPRINT << "WT: " <(cb_gamma, cb_out); for (uint32_t j = 0; j < blk; j++) { unary_bcast(cb_gamma, j, j); } - // UNPACK(DPRINT << "AFTER BROADCAST GAMMA\n\n" << ENDL()); + cb_pop_front(cb_gamma, blk); binary_dest_reuse_tiles_init(cb_xmm); for (uint32_t j = 0; j < blk; j++) { binary_dest_reuse_tiles(cb_xmm, j, j); } - // UNPACK(DPRINT << "AFTER MATH GAMMA\n\n" << ENDL()); tile_regs_commit(); if (!do_beta) { cb_xmm = cb_out; @@ -301,19 +312,17 @@ void MAIN { } cb_push_back(cb_xmm, blk); tile_regs_release(); - // UNPACK(DPRINT << "PUSH AND PACK GAMMA\n\n" << ENDL()); } - // UNPACK(DPRINT << "End b4 BETA\n\n" << ENDL()); if (do_beta) { tile_regs_acquire(); tile_regs_wait(); cb_wait_front(cb_beta, blk); cb_wait_front(cb_xmm, blk); - // UNPACK(DPRINT << "End b4 BETA\n\n" << ENDL()); unary_bcast_init(cb_beta, cb_out); for (uint32_t j = 0; j < blk; j++) { unary_bcast(cb_beta, j, j); } + cb_pop_front(cb_beta, blk); binary_dest_reuse_tiles_init(cb_xmm); for (uint32_t j = 0; j < blk; j++) { binary_dest_reuse_tiles(cb_xmm, j, j); @@ -326,14 +335,12 @@ void MAIN { tile_regs_release(); cb_push_back(cb_out, blk); } - // UNPACK(print_full_tile(cb_out,0)); - // UNPACK(DPRINT << "-|-|"<(0); uint32_t NCHt = get_arg_val(1); uint32_t Wt = get_arg_val(2); @@ -80,16 +79,16 @@ void kernel_main() { // read a ublock of tiles from src to CB, and then push the ublock to unpacker uint32_t offs = 0; - for (uint32_t ncht = 0; ncht < NCHt; ncht++) { - auto read_in0_and_in1[&]() { - for (uint32_t wt = 0; wt < Wt; wt += blk) { - read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk); + auto read_in0_and_in1 = [&]() { + for (uint32_t wt = 0; wt < Wt; wt += blk) { + read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk); #ifdef FUSE_PRE_ADD - // TODO(AP): refactor the ifdefs - read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk); + // TODO(AP): refactor the ifdefs + read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk); #endif - } // wt loop - }; + } // wt loop + }; + for (uint32_t ncht = 0; ncht < NCHt; ncht++) { read_in0_and_in1(); read_in0_and_in1(); #if defined FUSE_GAMMA || defined FUSE_BETA diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/reader_unary_interleaved_ln_three_pass.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/reader_unary_interleaved_ln_three_pass.cpp index 842d976971a..1a9b3ccd405 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/reader_unary_interleaved_ln_three_pass.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/reader_unary_interleaved_ln_three_pass.cpp @@ -80,27 +80,43 @@ void kernel_main() { // read a ublock of tiles from src to CB, and then push the ublock to unpacker uint32_t offs = 0; for (uint32_t ncht = 0; ncht < NCHt; ncht++) { + // Data for Calculating E[X] for (uint32_t wt = 0; wt < Wt; wt += blk) { read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk); } // wt loop - #ifdef FUSE_PRE_ADD for (uint32_t wt = 0; wt < Wt; wt += blk) { read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk); } #endif + // Data for Calculating Variance for (uint32_t wt = 0; wt < Wt; wt += blk) { read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk); #ifdef FUSE_PRE_ADD read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk); #endif } // wt loop + + // Data for calculating the final value for (uint32_t wt = 0; wt < Wt; wt += blk) { read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk); #ifdef FUSE_PRE_ADD read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk); #endif + if (ncht == 0) { +#ifdef FUSE_GAMMA + { + read_row_from_cb(cb_id_gamma, addrg, gamma_tile_bytes, wt, blk); + } +#endif + +#ifdef FUSE_BETA + { + read_row_from_cb(cb_id_beta, addrb, beta_tile_bytes, wt, blk); + } +#endif + } } // wt loop offs += Wt; } // ncht loop diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/multi_core/layernorm_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/multi_core/layernorm_op_multi_core.cpp index 2843f81ece4..8fc2a889799 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/multi_core/layernorm_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/multi_core/layernorm_op_multi_core.cpp @@ -142,6 +142,15 @@ operation::ProgramWithCallbacks layernorm_multi_core( // TODO(AP): this will not work for all Wts possibly, but should work for Wt=8, 12, 16, 32 // TODO(AP): can also add support for block_size=7 -> 63, 28 uint32_t WtB = tt::div_up(Wt, block_size) * block_size; // Wt padded to be divisible by block size + bool three_pass_needed = false; + if (gamma.has_value() and beta.has_value() and WtB > 120) { + // In the case that the required space is larger than what can be handeled by the single pass + three_pass_needed = true; + WtB = 120; + } else if (WtB > 200) { + three_pass_needed = true; + WtB = 200; + } uint32_t in0_t = WtB; // cb_x for no pre-add variant, x=a+b for fused pre-add, extra space for some buffering uint32_t in1_t = block_size * 2; // buffer for fused pre-add b tensor uint32_t out0_t = block_size * 2; @@ -162,9 +171,10 @@ operation::ProgramWithCallbacks layernorm_multi_core( uint32_t in3_t = 2; // epsilon coming from reader uint32_t im2_t = 2; // - TT_ASSERT( - W <= TILE_WIDTH * im0_t && - "W exceeds the maximum supported size of tile buffer (kernel limitation right now)."); + /* TT_ASSERT( + W <= TILE_WIDTH * im0_t && + "W exceeds the maximum supported size of tile buffer (kernel limitation right now)."); + */ TT_ASSERT( in0_t % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); @@ -266,12 +276,19 @@ operation::ProgramWithCallbacks layernorm_multi_core( auto use_row_major_kernel = (gamma.has_value() and gamma.value().get_layout() == Layout::ROW_MAJOR) or (beta.has_value() and beta.value().get_layout() == Layout::ROW_MAJOR); + + TT_ASSERT(!use_row_major_kernel or !three_pass_needed, "ROW_MAJOR layout not supported for tensors this large"); + auto reader_kernel_path = use_row_major_kernel + ? "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/" + "reader_unary_interleaved_ln_rm_gb.cpp" + : "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/" + "reader_unary_interleaved_ln.cpp"; + reader_kernel_path = three_pass_needed ? "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/" + "reader_unary_interleaved_ln_three_pass.cpp" + : reader_kernel_path; auto reader_kernels_id = CreateKernel( program, - use_row_major_kernel ? "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/" - "reader_unary_interleaved_ln_rm_gb.cpp" - : "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/" - "reader_unary_interleaved_ln_three_pass.cpp", + reader_kernel_path, all_cores, tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines)); @@ -287,7 +304,9 @@ operation::ProgramWithCallbacks layernorm_multi_core( // grep auto compute_kernels_id = CreateKernel( program, - "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_three_pass.cpp", + three_pass_needed + ? "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_three_pass.cpp" + : "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm.cpp", all_cores, tt::tt_metal::ComputeConfig{ .math_fidelity = math_fidelity, @@ -374,7 +393,9 @@ operation::ProgramWithCallbacks layernorm_multi_core( } uint32_t curr_row = 0; - float winv = 1.0f / W; // bcast-w scaler + const auto logical_shape = a.get_logical_shape(); + uint32_t logical_W = shape[-1]; + float winv = 1.0f / logical_W; // bcast-w scaler auto bfloat_winv_value = bfloat16(winv); uint32_t packed_winv_value = pack_two_bfloat16_into_uint32({bfloat_winv_value, bfloat_winv_value}); union {