Skip to content

Commit

Permalink
Clean up halo_gather.cpp.
Browse files (browse the repository at this point in the history)
Signed-off-by: Nilaykumar Patel <nkpatel@tenstorrent.com>
  • Loading branch information
nkpatel-tt committed Mar 6, 2025
1 parent c24f447 commit f1dc5bc
Showing 1 changed file with 17 additions and 31 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -30,9 +30,7 @@ template <
bool is_block_sharded,
bool is_width_sharded,
bool is_read,
bool is_col_major,
bool is_remote_config,
bool is_reader>
bool is_col_major>
void copy_sticks_async(
const tt_l1_ptr uint16_t* config_data,
const uint16_t my_noc_x,
Expand All @@ -41,14 +39,14 @@ void copy_sticks_async(
const uint32_t out_base_l1_addr) {
int i = 0;
int length = config_data[i + 2];

while (length) {
uint16_t noc_x = ((is_block_sharded && !is_col_major) || is_width_sharded) ? my_noc_x : config_data[i + 0];
uint16_t noc_y = ((is_block_sharded && is_col_major) || is_width_sharded) ? my_noc_y : config_data[i + 1];
length = config_data[i + 2];
i += 3;

const uint64_t base_addr = get_noc_addr(noc_x, noc_y, is_read ? in_base_l1_addr : out_base_l1_addr);

for (uint16_t j = 0; j < length; j += 3) {
uint16_t src_local_idx = config_data[i + j + 0];
uint16_t dst_local_idx = config_data[i + j + 1];
Expand Down Expand Up @@ -123,7 +121,7 @@ void kernel_main() {
}

uint32_t padding_config_l1_addr = get_read_ptr(padding_config_cb_id);
volatile tt_l1_ptr uint16_t* config_data = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(padding_config_l1_addr);
const tt_l1_ptr uint16_t* config_data = reinterpret_cast<const tt_l1_ptr uint16_t*>(padding_config_l1_addr);

const uint64_t padding_l1_addr = get_noc_addr(my_noc_x, my_noc_y, get_read_ptr(pad_cb_id));
const uint32_t dst_base_addr = out_base_l1_addr;
Expand All @@ -146,34 +144,22 @@ void kernel_main() {
}

cb_wait_front(in_cb_id, in_nsticks); // make sure untilized data is available

// copy data as per remote config
if constexpr (remote_config_cb_id) {
uint32_t config_data_l1_addr = get_read_ptr(remote_config_cb_id);
tt_l1_ptr uint16_t const* config_data = reinterpret_cast<tt_l1_ptr uint16_t const*>(config_data_l1_addr);
copy_sticks_async<
stick_nbytes,
input_aligned_page_size,
is_block_sharded,
is_width_sharded,
remote_read,
is_col_major,
true,
is_reader>(config_data, my_noc_x, my_noc_y, in_base_l1_addr, out_base_l1_addr);
}
uint32_t config_data_l1_addr = get_read_ptr(remote_config_cb_id);
config_data = reinterpret_cast<const tt_l1_ptr uint16_t*>(config_data_l1_addr);
copy_sticks_async<
stick_nbytes,
input_aligned_page_size,
is_block_sharded,
is_width_sharded,
remote_read,
is_col_major>(config_data, my_noc_x, my_noc_y, in_base_l1_addr, out_base_l1_addr);
// copy data as per local config
if constexpr (local_config_cb_id) {
uint32_t config_data_l1_addr = get_read_ptr(local_config_cb_id);
tt_l1_ptr uint16_t const* config_data = reinterpret_cast<tt_l1_ptr uint16_t const*>(config_data_l1_addr);
copy_sticks_async<
stick_nbytes,
input_aligned_page_size,
is_block_sharded,
is_width_sharded,
false,
is_col_major,
false,
is_reader>(config_data, my_noc_x, my_noc_y, in_base_l1_addr, out_base_l1_addr);
}
config_data_l1_addr = get_read_ptr(local_config_cb_id);
config_data = reinterpret_cast<const tt_l1_ptr uint16_t*>(config_data_l1_addr);
copy_sticks_async<stick_nbytes, input_aligned_page_size, is_block_sharded, is_width_sharded, false, is_col_major>(
config_data, my_noc_x, my_noc_y, in_base_l1_addr, out_base_l1_addr);

noc_async_read_barrier();
noc_async_write_barrier();
Expand Down

0 comments on commit f1dc5bc

Please sign in to comment.