Skip to content

Commit

Permalink
#16171: Preload kernels before receiving go message
Browse files Browse the repository at this point in the history
Add a flag to the enables that lets brisc.cc and erisc.cc start loading kernels
before receiving a go message. Fast dispatch ensures that the enables will
be set only after all necessary program data is sent to the core.

This allows preparation for the following kernel (including loading NCRISC
IRAM, setting up CBs, and initializing local memory) to happen in parallel with
the round-trip to the dispatcher_s to sync up with the other kernels and ensure
that they're all launched at the same time.
  • Loading branch information
jbaumanTT committed Dec 19, 2024
1 parent 71d7d94 commit 4398162
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 15 deletions.
18 changes: 13 additions & 5 deletions tt_metal/hw/firmware/src/brisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,14 @@ int main() {

WAYPOINT("GW");
uint8_t go_message_signal = RUN_MSG_DONE;
while ((go_message_signal = mailboxes->go_message.signal) != RUN_MSG_GO) {
// kernel_config.enables is last in the launch message, so the other data is
// valid by the time it's set. All multicast data from the dispatcher is
// written in order, so it will arrive in order. We also have a barrier
// before mcasting the launch message (as a hang workaround), which
// ensures that the unicast data will also have been received.
while (
((go_message_signal = mailboxes->go_message.signal) != RUN_MSG_GO) &&
!(mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.enables & DISPATCH_ENABLE_FLAG_PRELOAD)) {
// While the go signal for kernel execution is not sent, check if the worker was signalled
// to reset its launch message read pointer.
if (go_message_signal == RUN_MSG_RESET_READ_PTR) {
Expand Down Expand Up @@ -418,7 +425,9 @@ int main() {
DeviceZoneScopedMainN("BRISC-FW");
uint32_t launch_msg_rd_ptr = mailboxes->launch_msg_rd_ptr;
launch_msg_t* launch_msg_address = &(mailboxes->launch[launch_msg_rd_ptr]);
DeviceValidateProfiler(launch_msg_address->kernel_config.enables);
enum dispatch_core_processor_masks enables = (enum dispatch_core_processor_masks)(
launch_msg_address->kernel_config.enables & ~DISPATCH_ENABLE_FLAG_PRELOAD);
DeviceValidateProfiler(enables);
DeviceZoneSetCounter(launch_msg_address->kernel_config.host_assigned_id);
// Copies from L1 to IRAM on chips where NCRISC has IRAM
uint32_t kernel_config_base = firmware_config_init(mailboxes, ProgrammableCoreType::TENSIX, DISPATCH_CLASS_TENSIX_DM0);
Expand All @@ -432,8 +441,6 @@ int main() {
volatile tt_reg_ptr uint32_t* cfg_regs = core.cfg_regs_base(0);
cfg_regs[RISCV_IC_INVALIDATE_InvalidateAll_ADDR32] = RISCV_IC_BRISC_MASK | RISCV_IC_TRISC_ALL_MASK | RISCV_IC_NCRISC_MASK;

enum dispatch_core_processor_masks enables = (enum dispatch_core_processor_masks)launch_msg_address->kernel_config.enables;

run_triscs(enables);

noc_index = launch_msg_address->kernel_config.brisc_noc_id;
Expand Down Expand Up @@ -483,12 +490,13 @@ int main() {
#endif
// Brisc is responsible for issuing any noc cmds needed when initializing remote cbs
// So have brisc setup remote cb interfaces even when brisc is not in use
if (launch_msg_address->kernel_config.enables) {
if (enables) {
cb_l1_base =
(uint32_t tt_l1_ptr*)(kernel_config_base + launch_msg_address->kernel_config.remote_cb_offset);
uint32_t end_cb_index = launch_msg_address->kernel_config.min_remote_cb_start_index;
experimental::setup_remote_cb_interfaces<true>(cb_l1_base, end_cb_index);
}
wait_for_go_message();
}
WAYPOINT("D");

Expand Down
3 changes: 2 additions & 1 deletion tt_metal/hw/firmware/src/brisck.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
#endif

void kernel_launch(uint32_t kernel_base_addr) {

#if defined(DEBUG_NULL_KERNELS) && !defined(DISPATCH_KERNEL)
wait_for_go_message();
#ifdef KERNEL_RUN_TIME
uint64_t end_time = c_tensix_core::read_wall_clock() + KERNEL_RUN_TIME;
while (c_tensix_core::read_wall_clock() < end_time);
Expand All @@ -40,6 +40,7 @@ void kernel_launch(uint32_t kernel_base_addr) {
#ifdef ALIGN_LOCAL_CBS_TO_REMOTE_CBS
ALIGN_LOCAL_CBS_TO_REMOTE_CBS
#endif
wait_for_go_message();
{
DeviceZoneScopedMainChildN("BRISC-KERNEL");
kernel_main();
Expand Down
10 changes: 7 additions & 3 deletions tt_metal/hw/firmware/src/erisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,25 @@ void __attribute__((noinline)) Application(void) {
while (routing_info->routing_enabled) {
// FD: assume that no more host -> remote writes are pending
uint8_t go_message_signal = mailboxes->go_message.signal;
if (go_message_signal == RUN_MSG_GO) {
if ((go_message_signal == RUN_MSG_GO) ||
(mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.enables & DISPATCH_ENABLE_FLAG_PRELOAD)) {
// Only include this iteration in the device profile if the launch message is valid. This is because all workers get a go signal regardless of whether
// they're running a kernel or not. We don't want to profile "invalid" iterations.
DeviceZoneScopedMainN("ERISC-FW");
uint32_t launch_msg_rd_ptr = mailboxes->launch_msg_rd_ptr;
launch_msg_t* launch_msg_address = &(mailboxes->launch[launch_msg_rd_ptr]);
DeviceValidateProfiler(launch_msg_address->kernel_config.enables);
enum dispatch_core_processor_masks enables = (enum dispatch_core_processor_masks)(
launch_msg_address->kernel_config.enables & ~DISPATCH_ENABLE_FLAG_PRELOAD);
DeviceValidateProfiler(enables);
DeviceZoneSetCounter(launch_msg_address->kernel_config.host_assigned_id);
// Note that a core may get "GO" w/ enable false to keep its launch_msg's in sync
enum dispatch_core_processor_masks enables = (enum dispatch_core_processor_masks)launch_msg_address->kernel_config.enables;
if (enables & DISPATCH_CLASS_MASK_ETH_DM0) {
WAYPOINT("R");
firmware_config_init(mailboxes, ProgrammableCoreType::ACTIVE_ETH, DISPATCH_CLASS_ETH_DM0);
kernel_init(0);
WAYPOINT("D");
} else {
wait_for_go_message();
}
mailboxes->go_message.signal = RUN_MSG_DONE;

Expand Down
9 changes: 7 additions & 2 deletions tt_metal/hw/firmware/src/erisck.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@ extern "C" void wzerorange(uint32_t *start, uint32_t *end);
CBInterface cb_interface[NUM_CIRCULAR_BUFFERS];

extern "C" [[gnu::section(".start")]] void _start(uint32_t) {
DeviceZoneScopedMainChildN("ERISC-KERNEL");

// Clear bss, we write to rtos_context_switch_ptr just below.
extern uint32_t __ldm_bss_start[];
extern uint32_t __ldm_bss_end[];
wzerorange(__ldm_bss_start, __ldm_bss_end);

rtos_context_switch_ptr = (void (*)())RtosTable[0];
tt_l1_ptr mailboxes_t* const mailboxes = (tt_l1_ptr mailboxes_t*)(MEM_MAILBOX_BASE);

while (mailboxes->go_message.signal != RUN_MSG_GO) {
invalidate_l1_cache();
internal_::risc_context_switch();
}
DeviceZoneScopedMainChildN("ERISC-KERNEL");

kernel_main();
}
7 changes: 6 additions & 1 deletion tt_metal/hw/firmware/src/ncrisck.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "risc_common.h"
#include "tensix.h"
#include "tensix_types.h"
Expand All @@ -28,8 +30,9 @@ uint32_t noc_nonposted_atomics_acked[NUM_NOCS];
uint32_t noc_posted_writes_num_issued[NUM_NOCS];

void kernel_launch(uint32_t kernel_base_addr) {
DeviceZoneScopedMainChildN("NCRISC-KERNEL");
#if defined(DEBUG_NULL_KERNELS) && !defined(DISPATCH_KERNEL)
wait_for_go_message();
DeviceZoneScopedMainChildN("NCRISC-KERNEL");
#ifdef KERNEL_RUN_TIME
uint64_t end_time = c_tensix_core::read_wall_clock() + KERNEL_RUN_TIME;
while (c_tensix_core::read_wall_clock() < KERNEL_RUN_TIME);
Expand All @@ -46,6 +49,8 @@ void kernel_launch(uint32_t kernel_base_addr) {
#ifdef ALIGN_LOCAL_CBS_TO_REMOTE_CBS
ALIGN_LOCAL_CBS_TO_REMOTE_CBS
#endif
wait_for_go_message();
DeviceZoneScopedMainChildN("NCRISC-KERNEL");
kernel_main();
#ifdef UPDATE_REMOTE_CB_CONFIGS_IN_L1
UPDATE_REMOTE_CB_CONFIGS_IN_L1
Expand Down
8 changes: 5 additions & 3 deletions tt_metal/hw/firmware/src/trisck.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ volatile tt_reg_ptr uint * mailbox_base[4] = {
};
}

void kernel_launch(uint32_t kernel_base_addr)
{
DeviceZoneScopedMainChildN("TRISC-KERNEL");
void kernel_launch(uint32_t kernel_base_addr) {
#if defined(DEBUG_NULL_KERNELS) && !defined(DISPATCH_KERNEL)
wait_for_go_message();
DeviceZoneScopedMainChildN("TRISC-KERNEL");
#ifdef KERNEL_RUN_TIME
ckernel::wait(KERNEL_RUN_TIME);
#endif
Expand All @@ -57,6 +57,8 @@ void kernel_launch(uint32_t kernel_base_addr)
#if !defined(UCK_CHLKC_MATH) and defined ALIGN_LOCAL_CBS_TO_REMOTE_CBS
ALIGN_LOCAL_CBS_TO_REMOTE_CBS
#endif
wait_for_go_message();
DeviceZoneScopedMainChildN("TRISC-KERNEL");
run_kernel();
#if !defined(UCK_CHLKC_MATH) and defined UPDATE_REMOTE_CB_CONFIGS_IN_L1
UPDATE_REMOTE_CB_CONFIGS_IN_L1
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/hw/inc/dev_msgs.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ struct rta_offset_t {
// Maximums across all archs
constexpr auto NUM_PROGRAMMABLE_CORE_TYPES = 3u;
constexpr auto NUM_PROCESSORS_PER_CORE_TYPE = 5u;
// Flag bits carried in kernel_config_msg_t::enables together with the
// dispatch_core_processor_masks bits.
enum dispatchenable_flags : uint8_t {
    // Bit 7. When set, firmware may begin preparing the next kernel before the
    // go message has been received.
    DISPATCH_ENABLE_FLAG_PRELOAD = 0x80,
};

struct kernel_config_msg_t {
volatile uint16_t watcher_kernel_ids[DISPATCH_CLASS_MAX];
Expand All @@ -121,6 +124,7 @@ struct kernel_config_msg_t {
volatile uint8_t max_local_cb_end_index;
volatile uint8_t min_remote_cb_start_index;
volatile uint8_t exit_erisc_kernel;
// Bitwise OR of dispatchenable_flags and dispatch_core_processor_masks.
volatile uint8_t enables;
volatile uint8_t pad2[9];
} __attribute__((packed));
Expand Down
10 changes: 10 additions & 0 deletions tt_metal/hw/inc/firmware_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "dev_msgs.h"
#include "noc/noc_parameters.h"
#include "debug/dprint.h"
#include "risc_common.h"

extern uint16_t dram_bank_to_noc_xy[NUM_NOCS][NUM_DRAM_BANKS];
extern int32_t bank_to_dram_offset[NUM_DRAM_BANKS];
Expand Down Expand Up @@ -72,3 +73,12 @@ uint32_t firmware_config_init(

return kernel_config_base[core_type_index];
}

// Busy-waits until go_message.signal in the mailbox block at MEM_MAILBOX_BASE
// reads RUN_MSG_GO. Takes no arguments and returns nothing; it simply does not
// return until the GO value has been observed.
FORCE_INLINE
void wait_for_go_message() {
    tt_l1_ptr mailboxes_t* const mbox = (tt_l1_ptr mailboxes_t*)(MEM_MAILBOX_BASE);

    for (;;) {
        if (mbox->go_message.signal == RUN_MSG_GO) {
            return;
        }
        // Invalidate after each unsuccessful poll, before the mailbox is read
        // again (presumably so the next read is not served from a stale
        // cached copy -- confirm against invalidate_l1_cache()).
        invalidate_l1_cache();
    }
}
2 changes: 2 additions & 0 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,7 @@ void EnqueueProgramCommand::assemble_device_commands(
uint32_t programmable_core_index = hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX);
for (KernelGroup& kernel_group : program.get_kernel_groups(programmable_core_index)) {
kernel_group.launch_msg.kernel_config.mode = DISPATCH_MODE_DEV;
kernel_group.launch_msg.kernel_config.enables |= DISPATCH_ENABLE_FLAG_PRELOAD;
for (uint32_t i = 0; i < kernel_config_addrs.size(); i++) {
kernel_group.launch_msg.kernel_config.kernel_config_base[i] = kernel_config_addrs[i].addr;
}
Expand Down Expand Up @@ -1122,6 +1123,7 @@ void EnqueueProgramCommand::assemble_device_commands(
if (programmable_core_index != -1) {
for (KernelGroup& kernel_group : program.get_kernel_groups(programmable_core_index)) {
kernel_group.launch_msg.kernel_config.mode = DISPATCH_MODE_DEV;
kernel_group.launch_msg.kernel_config.enables |= DISPATCH_ENABLE_FLAG_PRELOAD;
for (uint32_t i = 0; i < kernel_config_addrs.size(); i++) {
kernel_group.launch_msg.kernel_config.kernel_config_base[i] = kernel_config_addrs[i].addr;
}
Expand Down

0 comments on commit 4398162

Please sign in to comment.