#17248: Layernorm now supports any tensor that fits in DRAM (not ROW_MAJOR) and added tests to reflect this
vsureshTT committed Mar 5, 2025
1 parent 67e13b2 commit 6bcbd81
Showing 5 changed files with 140 additions and 55 deletions.
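
For reference, the quantity the compute kernel below assembles pass by pass (see its E[x], Var(X) and final-value comments) is standard layer normalization over the last dimension of width n, with optional gamma and beta:

    E[x] = \frac{1}{n} \sum_{i=1}^{n} x_i, \qquad
    \mathrm{Var}(x) = \frac{1}{n} \sum_{i=1}^{n} \left(x_i - E[x]\right)^2, \qquad
    y_i = \frac{x_i - E[x]}{\sqrt{\mathrm{Var}(x) + \varepsilon}} \cdot \gamma_i + \beta_i
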
42 changes: 42 additions & 0 deletions tests/ttnn/unit_tests/operations/test_layer_norm.py
@@ -118,3 +118,45 @@ def test_layer_norm_with_tile_layout(device, h, w):
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor, 0.9998)


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [8192])
def test_large_layer_norm(device, h, w):
torch.manual_seed(0)

torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.nn.functional.layer_norm(torch_input_tensor, normalized_shape=[w])

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.layer_norm(input_tensor)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor, 0.9998)


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [8192])
def test_large_layer_norm_with_weight_and_bias(device, h, w):
torch.manual_seed(0)

torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_weight = torch.rand((w,), dtype=torch.bfloat16)
torch_bias = torch.rand((w,), dtype=torch.bfloat16)

torch_output_tensor = torch.nn.functional.layer_norm(
torch_input_tensor, normalized_shape=[w], weight=torch_weight, bias=torch_bias
)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device)
bias = ttnn.from_torch(torch_bias, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.layer_norm(input_tensor, weight=weight, bias=bias)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor, 0.999)
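
Both new tests use w = 8192. Assuming the usual 32-wide tiles, that is 256 tiles per row, which is above the 120/200-tile single-pass caps introduced in the program factory further down, so these tests exercise the new multi-pass path. A quick sanity check (the tile width and the caps are taken as given here):

    TILE_WIDTH = 32           # assumed tile width
    w = 8192
    wt = w // TILE_WIDTH      # 256 tiles per row
    assert wt > 200           # beyond the single-pass cap without gamma/beta
    assert wt > 120           # beyond the single-pass cap with gamma and beta
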
@@ -88,12 +88,11 @@ void MAIN {
for (uint32_t ncht = 0; ncht < NCHt; ncht++) {
constexpr int onetile = 1;
constexpr int dst0 = 0;
/*
* X + Y
*/
// E[x]
// aka ∑(x-E[x])
// cb_x in0 + in1 or in0
// Start of
// E[x]
// aka ∑(x)
// --------
// n
tile_regs_acquire();
tile_regs_wait();

@@ -125,10 +124,18 @@ void MAIN {
tile_regs_release();
cb_push_back(cb_ex, onetile);
cb_wait_front(cb_ex, onetile);
// End of
// E[x]
// aka ∑(x)
// --------
// n

cb_wait_front(cb_ex, onetile);
// Start of
// Var Calculation
// aka ∑(x-E[x])^2
// Var(X) = ∑(x-E[x])^2
// -----------
// n
for (uint32_t wt = 0; wt < Wt; wt += blk) {
reconfig_data_format(cb_in, cb_ex);
pack_reconfig_data_format(cb_ex2);
@@ -151,7 +158,6 @@ void MAIN {
}
cb_pop_front(cb_inb, blk);
#endif
//(x-E[x])^2
square_tile_init();
for (uint32_t j = 0; j < blk; j++) {
square_tile(j);
@@ -168,8 +174,6 @@ void MAIN {
tile_regs_wait();
if (wt > 0) {
cb_wait_front(cb_ex2, onetile);
UNPACK(DPRINT << "PREV" << ENDL());
UNPACK(print_full_tile(cb_ex2, 0));
copy_tile_init(cb_ex2);
copy_tile(cb_ex2, 0, dst0);
cb_pop_front(cb_ex2, onetile);
@@ -182,15 +186,25 @@ void MAIN {
}
cb_pop_front(cb_xmm, blk);
cb_reserve_back(cb_ex2, onetile);
reduce_revert_delta(cb_ex2);
tile_regs_commit();
pack_tile(dst0, cb_ex2);
reduce_revert_delta(cb_ex2);
cb_push_back(cb_ex2, onetile);
tile_regs_release();
}

tile_regs_acquire();
tile_regs_wait();
// End of
// Var Calculation
// Var(X) = ∑(x-E[x])^2
// -----------

// Start of
// Calculation
// 1
// cb_ex2pe = -------------
// √(Var(X) + ε)
cb_wait_front(cb_ex2, onetile);

reconfig_data_format(cb_ex2, cb_eps);
@@ -214,7 +228,7 @@ void MAIN {
cb_pop_front(cb_ex2, onetile);
cb_wait_front(cb_ex2pe, onetile);

// broadcasts the tile since only the column contains the important data
// broadcasts the tile since cb_ex2pe is a column vector that contains the important data
tile_regs_acquire();
tile_regs_wait();
reconfig_data_format_srca(cb_ex2pe);
@@ -226,19 +240,21 @@ void MAIN {
pack_tile(dst0, cb_ex2pe);
tile_regs_release();
cb_push_back(cb_ex2pe, onetile);
// End of
// Calculation
// 1
// cb_ex2pe = -------------
// √(Var(X) + ε)

// Final Val Calc

// DPRINT << "WT: " <<Wt << " \n\n" << ENDL();
cb_wait_front(cb_ex2pe, 1);
reconfig_data_format(cb_in, cb_ex);
pack_reconfig_data_format(cb_in, cb_out);
// PACK(print_full_tile(cb_ex2pe,0));
// Start of
// Final Val Calc
// x-E[X]
//(---------------*𝛄)+ß
// √(Var(X)+ε)
for (uint32_t wt = 0; wt < Wt; wt += blk) {
if (do_gamma) {
}
if (do_beta) {
}
tile_regs_acquire();
tile_regs_wait();
cb_reserve_back(cb_out, blk);
@@ -270,25 +286,20 @@ void MAIN {
}
cb_push_back(cb_xmm, blk);
tile_regs_release();
// UNPACK(DPRINT << "End b4 gamma\n\n" << ENDL());
if (do_gamma) {
// UNPACK(DPRINT << "BEFORE ANYTHING GAMMA\n\n" << ENDL());
tile_regs_acquire();
tile_regs_wait();
cb_wait_front(cb_gamma, blk);
// UNPACK(DPRINT << "BEFORE xmm GAMMA\n\n" << ENDL());
cb_wait_front(cb_xmm, blk);
// UNPACK(DPRINT << "BEFORE BROADCAST GAMMA\n\n" << ENDL());
unary_bcast_init<BroadcastType::ROW>(cb_gamma, cb_out);
for (uint32_t j = 0; j < blk; j++) {
unary_bcast<BroadcastType::ROW>(cb_gamma, j, j);
}
// UNPACK(DPRINT << "AFTER BROADCAST GAMMA\n\n" << ENDL());
cb_pop_front(cb_gamma, blk);
binary_dest_reuse_tiles_init<ELWMUL, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_xmm);
for (uint32_t j = 0; j < blk; j++) {
binary_dest_reuse_tiles<ELWMUL, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_xmm, j, j);
}
// UNPACK(DPRINT << "AFTER MATH GAMMA\n\n" << ENDL());
tile_regs_commit();
if (!do_beta) {
cb_xmm = cb_out;
@@ -301,19 +312,17 @@ void MAIN {
}
cb_push_back(cb_xmm, blk);
tile_regs_release();
// UNPACK(DPRINT << "PUSH AND PACK GAMMA\n\n" << ENDL());
}
// UNPACK(DPRINT << "End b4 BETA\n\n" << ENDL());
if (do_beta) {
tile_regs_acquire();
tile_regs_wait();
cb_wait_front(cb_beta, blk);
cb_wait_front(cb_xmm, blk);
// UNPACK(DPRINT << "End b4 BETA\n\n" << ENDL());
unary_bcast_init<BroadcastType::ROW>(cb_beta, cb_out);
for (uint32_t j = 0; j < blk; j++) {
unary_bcast<BroadcastType::ROW>(cb_beta, j, j);
}
cb_pop_front(cb_beta, blk);
binary_dest_reuse_tiles_init<ELWADD, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_xmm);
for (uint32_t j = 0; j < blk; j++) {
binary_dest_reuse_tiles<ELWADD, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_xmm, j, j);
@@ -326,14 +335,12 @@ void MAIN {
tile_regs_release();
cb_push_back(cb_out, blk);
}
// UNPACK(print_full_tile(cb_out,0));
// UNPACK(DPRINT << "-|-|"<<ENDL());
// UNPACK(DPRINT << "End b4 BETA\n\n" << ENDL());
// DPRINT << wt << "Before cb_out reserve!\n\n\n" << ENDL();
}
// End of
// Final Val Calc
// x-E[X]
//(---------------*𝛄)+ß
// √(Var(X)+ε)
} // NCHt loop
// cb_pop_front(cb_scaler, 1); // optional for correctness
// cb_pop_front(cb_eps, 1); // optional for correctness
// cb_pop_front(cb_col1, 1); // optional for correctness
}
} // namespace NAMESPACE
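
A rough NumPy reference model of what the restructured compute kernel above does per row, following its Start of / End of comments (an illustration only: blk-sized blocking, circular buffers and the FUSE_PRE_ADD path are omitted):

    import numpy as np

    def layernorm_row_reference(x, gamma=None, beta=None, eps=1e-5):
        n = x.shape[-1]
        # Pass 1: E[x] = sum(x) / n
        mean = x.sum(axis=-1, keepdims=True) / n
        # Pass 2: Var(X) = sum((x - E[x])^2) / n, re-reading x
        var = ((x - mean) ** 2).sum(axis=-1, keepdims=True) / n
        # cb_ex2pe = 1 / sqrt(Var(X) + eps), broadcast across the row
        inv_std = 1.0 / np.sqrt(var + eps)
        # Pass 3 (final value): ((x - E[x]) * inv_std * gamma) + beta
        y = (x - mean) * inv_std
        if gamma is not None:
            y = y * gamma
        if beta is not None:
            y = y + beta
        return y
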
@@ -24,7 +24,6 @@ void read_row_from_cb(
cb_push_back(cb_id, blk);
}
void kernel_main() {
DPRINT << "HIIII \n\n\n" << ENDL();
uint32_t src_addr = get_arg_val<uint32_t>(0);
uint32_t NCHt = get_arg_val<uint32_t>(1);
uint32_t Wt = get_arg_val<uint32_t>(2);
@@ -80,16 +79,16 @@ void kernel_main() {

// read a ublock of tiles from src to CB, and then push the ublock to unpacker
uint32_t offs = 0;
for (uint32_t ncht = 0; ncht < NCHt; ncht++) {
auto read_in0_and_in1[&]() {
for (uint32_t wt = 0; wt < Wt; wt += blk) {
read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk);
auto read_in0_and_in1 = [&]() {
for (uint32_t wt = 0; wt < Wt; wt += blk) {
read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk);
#ifdef FUSE_PRE_ADD
// TODO(AP): refactor the ifdefs
read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk);
// TODO(AP): refactor the ifdefs
read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk);
#endif
} // wt loop
};
} // wt loop
};
for (uint32_t ncht = 0; ncht < NCHt; ncht++) {
read_in0_and_in1();
read_in0_and_in1();
#if defined FUSE_GAMMA || defined FUSE_BETA
@@ -80,27 +80,43 @@ void kernel_main() {
// read a ublock of tiles from src to CB, and then push the ublock to unpacker
uint32_t offs = 0;
for (uint32_t ncht = 0; ncht < NCHt; ncht++) {
// Data for Calculating E[X]
for (uint32_t wt = 0; wt < Wt; wt += blk) {
read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk);
} // wt loop

#ifdef FUSE_PRE_ADD
for (uint32_t wt = 0; wt < Wt; wt += blk) {
read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk);
}
#endif

// Data for Calculating Variance
for (uint32_t wt = 0; wt < Wt; wt += blk) {
read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk);
#ifdef FUSE_PRE_ADD
read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk);
#endif
} // wt loop

// Data for calculating the final value
for (uint32_t wt = 0; wt < Wt; wt += blk) {
read_row_from_cb(cb_id_in0, src_a, src0_tile_bytes, offs + wt + tile_offset, blk);
#ifdef FUSE_PRE_ADD
read_row_from_cb(cb_id_in1, src_b, src1_tile_bytes, offs + wt + tile_offset, blk);
#endif
if (ncht == 0) {
#ifdef FUSE_GAMMA
{
read_row_from_cb(cb_id_gamma, addrg, gamma_tile_bytes, wt, blk);
}
#endif

#ifdef FUSE_BETA
{
read_row_from_cb(cb_id_beta, addrb, beta_tile_bytes, wt, blk);
}
#endif
}
} // wt loop
offs += Wt;
} // ncht loop
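
The reader above trades DRAM bandwidth for on-chip buffer space: instead of keeping a whole row resident, it re-reads the Wt tiles of each row once per compute pass, and gamma/beta are only fetched on the first row (the ncht == 0 guard). A simplified sketch of the issue order, mirroring the loops above (names are illustrative):

    def three_pass_read_schedule(NCHt, Wt, blk, fuse_pre_add=False, fuse_gamma=False, fuse_beta=False):
        """Yield (buffer, row, wt) tuples in the order the reads are issued."""
        for ncht in range(NCHt):
            for wt in range(0, Wt, blk):              # data for calculating E[X]
                yield ("in0", ncht, wt)
            if fuse_pre_add:
                for wt in range(0, Wt, blk):
                    yield ("in1", ncht, wt)
            for wt in range(0, Wt, blk):              # data for calculating the variance
                yield ("in0", ncht, wt)
                if fuse_pre_add:
                    yield ("in1", ncht, wt)
            for wt in range(0, Wt, blk):              # data for calculating the final value
                yield ("in0", ncht, wt)
                if fuse_pre_add:
                    yield ("in1", ncht, wt)
                if ncht == 0:
                    if fuse_gamma:
                        yield ("gamma", 0, wt)
                    if fuse_beta:
                        yield ("beta", 0, wt)
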
@@ -142,6 +142,15 @@ operation::ProgramWithCallbacks layernorm_multi_core(
// TODO(AP): this will not work for all Wts possibly, but should work for Wt=8, 12, 16, 32
// TODO(AP): can also add support for block_size=7 -> 63, 28
uint32_t WtB = tt::div_up(Wt, block_size) * block_size; // Wt padded to be divisible by block size
bool three_pass_needed = false;
if (gamma.has_value() and beta.has_value() and WtB > 120) {
        // In the case that the required space is larger than what can be handled by the single pass
three_pass_needed = true;
WtB = 120;
} else if (WtB > 200) {
three_pass_needed = true;
WtB = 200;
}
uint32_t in0_t = WtB; // cb_x for no pre-add variant, x=a+b for fused pre-add, extra space for some buffering
uint32_t in1_t = block_size * 2; // buffer for fused pre-add b tensor
uint32_t out0_t = block_size * 2;
@@ -162,9 +171,10 @@ operation::ProgramWithCallbacks layernorm_multi_core(
uint32_t in3_t = 2; // epsilon coming from reader
uint32_t im2_t = 2; //

TT_ASSERT(
W <= TILE_WIDTH * im0_t &&
"W exceeds the maximum supported size of tile buffer (kernel limitation right now).");
/* TT_ASSERT(
W <= TILE_WIDTH * im0_t &&
"W exceeds the maximum supported size of tile buffer (kernel limitation right now).");
*/
TT_ASSERT(
in0_t % block_size == 0 &&
"Size of buffer must be divisible by the size of block used by the reader and compute kernel.");
@@ -266,12 +276,19 @@ operation::ProgramWithCallbacks layernorm_multi_core(

auto use_row_major_kernel = (gamma.has_value() and gamma.value().get_layout() == Layout::ROW_MAJOR) or
(beta.has_value() and beta.value().get_layout() == Layout::ROW_MAJOR);

TT_ASSERT(!use_row_major_kernel or !three_pass_needed, "ROW_MAJOR layout not supported for tensors this large");
auto reader_kernel_path = use_row_major_kernel
? "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/"
"reader_unary_interleaved_ln_rm_gb.cpp"
: "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/"
"reader_unary_interleaved_ln.cpp";
reader_kernel_path = three_pass_needed ? "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/"
"reader_unary_interleaved_ln_three_pass.cpp"
: reader_kernel_path;
auto reader_kernels_id = CreateKernel(
program,
use_row_major_kernel ? "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/"
"reader_unary_interleaved_ln_rm_gb.cpp"
: "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/dataflow/"
"reader_unary_interleaved_ln_three_pass.cpp",
reader_kernel_path,
all_cores,
tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines));

@@ -287,7 +304,9 @@ operation::ProgramWithCallbacks layernorm_multi_core(
// grep
auto compute_kernels_id = CreateKernel(
program,
"ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_three_pass.cpp",
three_pass_needed
? "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_three_pass.cpp"
: "ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm.cpp",
all_cores,
tt::tt_metal::ComputeConfig{
.math_fidelity = math_fidelity,
Expand Down Expand Up @@ -374,7 +393,9 @@ operation::ProgramWithCallbacks layernorm_multi_core(
}

uint32_t curr_row = 0;
float winv = 1.0f / W; // bcast-w scaler
const auto logical_shape = a.get_logical_shape();
    uint32_t logical_W = logical_shape[-1];
float winv = 1.0f / logical_W; // bcast-w scaler
auto bfloat_winv_value = bfloat16(winv);
uint32_t packed_winv_value = pack_two_bfloat16_into_uint32({bfloat_winv_value, bfloat_winv_value});
union {
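
A compact restatement of the new host-side selection logic in the program factory (the 120/200 caps and kernel paths are copied from the diff; the function itself is illustrative, not part of the change): when gamma and beta are both present the per-row circular buffer is capped at 120 tiles, otherwise at 200 tiles, and anything wider falls back to the three-pass kernels.

    def select_layernorm_kernels(wt_b, has_gamma, has_beta, row_major_gamma_beta=False):
        three_pass_needed = False
        if has_gamma and has_beta and wt_b > 120:
            three_pass_needed, wt_b = True, 120   # cap CB space, recompute over three passes
        elif wt_b > 200:
            three_pass_needed, wt_b = True, 200
        assert not (row_major_gamma_beta and three_pass_needed), \
            "ROW_MAJOR layout not supported for tensors this large"
        compute = "layernorm_three_pass.cpp" if three_pass_needed else "layernorm.cpp"
        reader = ("reader_unary_interleaved_ln_three_pass.cpp" if three_pass_needed
                  else "reader_unary_interleaved_ln_rm_gb.cpp" if row_major_gamma_beta
                  else "reader_unary_interleaved_ln.cpp")
        return three_pass_needed, wt_b, compute, reader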