Skip to content

Commit

Permalink
#0: Use autoincrement stream register to signal work done
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaumanTT committed Mar 4, 2025
1 parent e03382a commit 127b86b
Show file tree
Hide file tree
Showing 30 changed files with 303 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ TEST(DeviceCommandTest, AddDispatchWait) {
calculator.add_dispatch_wait();

HostMemDeviceCommand command(calculator.write_offset_bytes());
command.add_dispatch_wait(0, 0, 0);
command.add_dispatch_wait(0, 0, 0, 0);
EXPECT_EQ(command.size_bytes(), command.write_offset_bytes());
}

Expand All @@ -24,7 +24,7 @@ TEST(DeviceCommandTest, AddDispatchWaitWithPrefetchStall) {
calculator.add_dispatch_wait_with_prefetch_stall();

HostMemDeviceCommand command(calculator.write_offset_bytes());
command.add_dispatch_wait_with_prefetch_stall(0, 0, 0);
command.add_dispatch_wait_with_prefetch_stall(0, 0, 0, 0);
EXPECT_EQ(command.size_bytes(), command.write_offset_bytes());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,9 +611,8 @@ void gen_wait_and_stall_cmd(IDevice* device, vector<uint32_t>& prefetch_cmds, ve

CQDispatchCmd wait;
wait.base.cmd_id = CQ_DISPATCH_CMD_WAIT;
wait.wait.barrier = true;
wait.wait.notify_prefetch = true;
wait.wait.wait = true;
wait.wait.flags = CQ_DISPATCH_CMD_WAIT_FLAG_BARRIER | CQ_DISPATCH_CMD_WAIT_FLAG_NOTIFY_PREFETCH |
CQ_DISPATCH_CMD_WAIT_FLAG_WAIT_MEMORY;
wait.wait.addr = dispatch_wait_addr_g;
wait.wait.count = 0;
add_bare_dispatcher_cmd(dispatch_cmds, wait);
Expand Down
30 changes: 25 additions & 5 deletions tt_metal/api/tt-metalium/command_queue_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ enum class CommandQueueDeviceAddrType : uint8_t {
COMPLETION_Q0_LAST_EVENT = 4,
COMPLETION_Q1_LAST_EVENT = 5,
DISPATCH_S_SYNC_SEM = 6,
DISPATCH_MESSAGE = 7,
UNRESERVED = 8
UNRESERVED = 7
};

enum class CommandQueueHostAddrType : uint8_t {
Expand Down Expand Up @@ -125,12 +124,35 @@ class DispatchMemMap {
tt::tt_metal::hal.get_alignment(tt::tt_metal::HalMemType::HOST);
}

uint32_t get_dispatch_message_offset(uint32_t index) const {
uint32_t get_sync_offset(uint32_t index) const {
TT_ASSERT(index < tt::tt_metal::DispatchSettings::DISPATCH_MESSAGE_ENTRIES);
uint32_t offset = index * hal.get_alignment(HalMemType::L1);
return offset;
}

// Start address of the dispatch message streams: the NOC overlay start address plus one
// stream register space for each stream that precedes the stream mapped to dispatch index 0.
uint32_t get_dispatch_message_addr_start() const {
    const auto overlay_start = tt::tt_metal::hal.get_noc_overlay_start_addr();
    const auto stream_reg_space = tt::tt_metal::hal.get_noc_stream_reg_space_size();
    const auto first_dispatch_stream = get_dispatch_stream_index(0);
    return overlay_start + stream_reg_space * first_dispatch_stream;
}

// Maps a dispatch message index to a stream index for the core type recorded in
// last_core_type: 48 + index on WORKER cores, 16 + index on ETH cores.
// Raises through TT_THROW for any other core type.
// NOTE(review): unlike get_sync_offset / get_dispatch_message_update_offset, index is not
// checked against DISPATCH_MESSAGE_ENTRIES here - confirm callers stay within the stream count.
uint32_t get_dispatch_stream_index(uint32_t index) const {
    if (last_core_type == CoreType::WORKER) {
        // There are 64 streams. CBs use entries 8-39.
        return 48u + index;
    }
    if (last_core_type == CoreType::ETH) {
        // There are 32 streams.
        return 16u + index;
    }
    // Message names this function (it previously named a non-existent
    // get_dispatch_starting_stream_index), so the failure can be traced to its source.
    TT_THROW("get_dispatch_stream_index not implemented for core type");
}

// Offset used to update the dispatch message for entry `index`: `index` stream register
// spaces, plus the space-available-update register index scaled by sizeof(uint32_t).
// Asserts that index is below DISPATCH_MESSAGE_ENTRIES.
uint32_t get_dispatch_message_update_offset(uint32_t index) const {
    TT_ASSERT(index < tt::tt_metal::DispatchSettings::DISPATCH_MESSAGE_ENTRIES);
    const auto stream_reg_space = tt::tt_metal::hal.get_noc_stream_reg_space_size();
    const auto update_reg_index =
        tt::tt_metal::hal.get_noc_stream_remote_dest_buf_space_available_update_reg_index();
    return stream_reg_space * index + update_reg_index * sizeof(uint32_t);
}

private:
DispatchMemMap() = default;

Expand Down Expand Up @@ -169,8 +191,6 @@ class DispatchMemMap {
device_cq_addr_sizes_[dev_addr_idx] = settings.prefetch_q_pcie_rd_ptr_size_;
} else if (dev_addr_type == CommandQueueDeviceAddrType::DISPATCH_S_SYNC_SEM) {
device_cq_addr_sizes_[dev_addr_idx] = settings.dispatch_s_sync_sem_;
} else if (dev_addr_type == CommandQueueDeviceAddrType::DISPATCH_MESSAGE) {
device_cq_addr_sizes_[dev_addr_idx] = settings.dispatch_message_;
} else {
device_cq_addr_sizes_[dev_addr_idx] = settings.other_ptrs_size;
}
Expand Down
22 changes: 15 additions & 7 deletions tt_metal/api/tt-metalium/cq_commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,21 @@ struct CQDispatchWritePackedLargeCmd {
uint16_t write_offset_index;
} __attribute__((packed));

// Bit flags for a dispatch wait command; individual flags can be OR-ed together.
// No flag set.
constexpr uint32_t CQ_DISPATCH_CMD_WAIT_FLAG_NONE = 0u;
// Issue a write barrier.
constexpr uint32_t CQ_DISPATCH_CMD_WAIT_FLAG_BARRIER = 1u << 0;
// Increment the prefetch semaphore.
constexpr uint32_t CQ_DISPATCH_CMD_WAIT_FLAG_NOTIFY_PREFETCH = 1u << 1;
// Wait for a count value in memory.
constexpr uint32_t CQ_DISPATCH_CMD_WAIT_FLAG_WAIT_MEMORY = 1u << 2;
// Wait for a count value on a stream.
constexpr uint32_t CQ_DISPATCH_CMD_WAIT_FLAG_WAIT_STREAM = 1u << 3;
// Clear a count value on a stream.
constexpr uint32_t CQ_DISPATCH_CMD_WAIT_FLAG_CLEAR_STREAM = 1u << 4;

struct CQDispatchWaitCmd {
uint8_t barrier; // if true, issue write barrier
uint8_t notify_prefetch; // if true, inc prefetch sem
uint8_t clear_count; // if true, reset count to 0
uint8_t wait; // if true, wait on count value below
uint8_t pad1;
uint16_t pad2;
uint8_t flags; // see above
uint16_t stream; // stream to read/write
uint32_t addr; // address to read
uint32_t count; // wait while address is < count
} __attribute__((packed));
Expand Down Expand Up @@ -270,7 +278,7 @@ struct CQDispatchGoSignalMcastCmd {
uint8_t num_unicast_txns;
uint8_t noc_data_start_index;
uint32_t wait_count;
uint32_t wait_addr;
uint32_t wait_stream; // Index of the stream to wait on
} __attribute__((packed));

struct CQDispatchNotifySlaveGoSignalCmd {
Expand Down
23 changes: 7 additions & 16 deletions tt_metal/api/tt-metalium/device_command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,7 @@ class DeviceCommand {
vector_memcpy_aligned<uint32_t> cmd_vector() const { return this->cmd_region_vector; }

void add_dispatch_wait(
uint8_t barrier,
uint32_t address,
uint32_t count,
uint8_t clear_count = 0,
bool notify_prefetch = false,
bool do_wait = true,
uint8_t dispatcher_type = 0) {
uint32_t flags, uint32_t address, uint32_t stream, uint32_t count, uint8_t dispatcher_type = 0) {
auto initialize_wait_cmds = [&](CQPrefetchCmd* relay_wait, CQDispatchCmd* wait_cmd) {
relay_wait->base.cmd_id = CQ_PREFETCH_CMD_RELAY_INLINE;
relay_wait->relay_inline.dispatcher_type = dispatcher_type;
Expand All @@ -105,12 +99,10 @@ class DeviceCommand {
tt::align(sizeof(CQDispatchCmd) + sizeof(CQPrefetchCmd), this->pcie_alignment);

wait_cmd->base.cmd_id = CQ_DISPATCH_CMD_WAIT;
wait_cmd->wait.barrier = barrier;
wait_cmd->wait.notify_prefetch = notify_prefetch;
wait_cmd->wait.wait = do_wait;
wait_cmd->wait.flags = flags;
wait_cmd->wait.addr = address;
wait_cmd->wait.count = count;
wait_cmd->wait.clear_count = clear_count;
wait_cmd->wait.stream = stream;
};
CQPrefetchCmd* relay_wait_dst = this->reserve_space<CQPrefetchCmd*>(sizeof(CQPrefetchCmd));
CQDispatchCmd* wait_cmd_dst = this->reserve_space<CQDispatchCmd*>(sizeof(CQDispatchCmd));
Expand All @@ -127,9 +119,8 @@ class DeviceCommand {
this->cmd_write_offsetB = tt::align(this->cmd_write_offsetB, this->pcie_alignment);
}

void add_dispatch_wait_with_prefetch_stall(
uint8_t barrier, uint32_t address, uint32_t count, uint8_t clear_count = 0, bool do_wait = true) {
this->add_dispatch_wait(barrier, address, count, clear_count, true, do_wait);
void add_dispatch_wait_with_prefetch_stall(uint32_t flags, uint32_t address, uint32_t stream, uint32_t count) {
this->add_dispatch_wait(flags | CQ_DISPATCH_CMD_WAIT_FLAG_NOTIFY_PREFETCH, address, stream, count);
uint32_t increment_sizeB = tt::align(sizeof(CQPrefetchCmd), this->pcie_alignment);
auto initialize_stall_cmd = [&](CQPrefetchCmd* stall_cmd) {
*stall_cmd = {};
Expand Down Expand Up @@ -280,7 +271,7 @@ class DeviceCommand {
void add_dispatch_go_signal_mcast(
uint32_t wait_count,
uint32_t go_signal,
uint32_t wait_addr,
uint32_t wait_stream,
uint8_t num_mcast_txns,
uint8_t num_unicast_txns,
uint8_t noc_data_start_index,
Expand Down Expand Up @@ -308,7 +299,7 @@ class DeviceCommand {
mcast_cmd->mcast.num_mcast_txns = num_mcast_txns;
mcast_cmd->mcast.num_unicast_txns = num_unicast_txns;
mcast_cmd->mcast.noc_data_start_index = noc_data_start_index;
mcast_cmd->mcast.wait_addr = wait_addr;
mcast_cmd->mcast.wait_stream = wait_stream;
};
CQDispatchCmd* mcast_cmd_dst = this->reserve_space<CQDispatchCmd*>(sizeof(CQDispatchCmd));

Expand Down
8 changes: 3 additions & 5 deletions tt_metal/api/tt-metalium/dispatch_settings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ class DispatchSettings {
uint32_t prefetch_q_rd_ptr_size_{0}; // configured with alignment
uint32_t prefetch_q_pcie_rd_ptr_size_; // configured with alignment
uint32_t dispatch_s_sync_sem_; // configured with alignment
uint32_t dispatch_message_; // configured with alignment
uint32_t other_ptrs_size; // configured with alignment

// cq_prefetch
Expand All @@ -185,9 +184,9 @@ class DispatchSettings {
bool operator==(const DispatchSettings& other) const {
return num_hw_cqs_ == other.num_hw_cqs_ && prefetch_q_rd_ptr_size_ == other.prefetch_q_rd_ptr_size_ &&
prefetch_q_pcie_rd_ptr_size_ == other.prefetch_q_pcie_rd_ptr_size_ &&
dispatch_s_sync_sem_ == other.dispatch_s_sync_sem_ && dispatch_message_ == other.dispatch_message_ &&
other_ptrs_size == other.other_ptrs_size && prefetch_q_entries_ == other.prefetch_q_entries_ &&
prefetch_q_size_ == other.prefetch_q_size_ && prefetch_max_cmd_size_ == other.prefetch_max_cmd_size_ &&
dispatch_s_sync_sem_ == other.dispatch_s_sync_sem_ && other_ptrs_size == other.other_ptrs_size &&
prefetch_q_entries_ == other.prefetch_q_entries_ && prefetch_q_size_ == other.prefetch_q_size_ &&
prefetch_max_cmd_size_ == other.prefetch_max_cmd_size_ &&
prefetch_cmddat_q_size_ == other.prefetch_cmddat_q_size_ &&
prefetch_scratch_db_size_ == other.prefetch_scratch_db_size_ &&
prefetch_d_buffer_size_ == other.prefetch_d_buffer_size_ &&
Expand Down Expand Up @@ -275,7 +274,6 @@ class DispatchSettings {
this->prefetch_q_rd_ptr_size_ = sizeof(prefetch_q_ptr_type);
this->prefetch_q_pcie_rd_ptr_size_ = l1_alignment - sizeof(prefetch_q_ptr_type);
this->dispatch_s_sync_sem_ = DISPATCH_MESSAGE_ENTRIES * l1_alignment;
this->dispatch_message_ = DISPATCH_MESSAGE_ENTRIES * l1_alignment;
this->other_ptrs_size = l1_alignment;

return *this;
Expand Down
8 changes: 8 additions & 0 deletions tt_metal/api/tt-metalium/hal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ class Hal {
uint32_t noc_stream_reg_space_size_;
uint32_t noc_stream_remote_dest_buf_size_reg_index_;
uint32_t noc_stream_remote_dest_buf_start_reg_index_;
uint32_t noc_stream_remote_dest_buf_space_available_reg_index_;
uint32_t noc_stream_remote_dest_buf_space_available_update_reg_index_;
bool coordinate_virtualization_enabled_;
uint32_t virtual_worker_start_x_;
uint32_t virtual_worker_start_y_;
Expand Down Expand Up @@ -200,6 +202,12 @@ class Hal {
// Read-only accessor: returns the value held in noc_stream_remote_dest_buf_start_reg_index_.
uint32_t get_noc_stream_remote_dest_buf_start_reg_index() const {
return noc_stream_remote_dest_buf_start_reg_index_;
}
// Read-only accessor: returns the value held in
// noc_stream_remote_dest_buf_space_available_reg_index_.
uint32_t get_noc_stream_remote_dest_buf_space_available_reg_index() const {
return noc_stream_remote_dest_buf_space_available_reg_index_;
}
// Read-only accessor: returns the value held in
// noc_stream_remote_dest_buf_space_available_update_reg_index_.
// The result is a register index, not a byte offset; DispatchMemMap's
// get_dispatch_message_update_offset multiplies it by sizeof(uint32_t).
uint32_t get_noc_stream_remote_dest_buf_space_available_update_reg_index() const {
return noc_stream_remote_dest_buf_space_available_update_reg_index_;
}

float get_eps() const { return eps_; }
float get_nan() const { return nan_; }
Expand Down
9 changes: 2 additions & 7 deletions tt_metal/distributed/mesh_workload_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ void write_go_signal(
run_program_go_signal.master_x = dispatch_core.x;
run_program_go_signal.master_y = dispatch_core.y;
run_program_go_signal.dispatch_message_offset =
(uint8_t)DispatchMemMap::get(dispatch_core_type).get_dispatch_message_offset(sub_device_index);

uint32_t dispatch_message_addr =
DispatchMemMap::get(dispatch_core_type)
.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE) +
DispatchMemMap::get(dispatch_core_type).get_dispatch_message_offset(sub_device_index);
(uint8_t)DispatchMemMap::get(dispatch_core_type).get_dispatch_message_update_offset(sub_device_index);

// When running with dispatch_s enabled:
// - dispatch_d must notify dispatch_s that a go signal can be sent
Expand All @@ -65,7 +60,7 @@ void write_go_signal(
go_signal_cmd_sequence.add_dispatch_go_signal_mcast(
expected_num_workers_completed,
*reinterpret_cast<uint32_t*>(&run_program_go_signal),
dispatch_message_addr,
DispatchMemMap::get(dispatch_core_type).get_dispatch_stream_index(sub_device_index),
send_mcast ? device->num_noc_mcast_txns(sub_device_id) : 0,
send_unicasts ? ((num_unicast_txns > 0) ? num_unicast_txns : device->num_noc_unicast_txns(sub_device_id)) : 0,
device->noc_data_start_index(sub_device_id, send_mcast, send_unicasts), /* noc_data_start_idx */
Expand Down
32 changes: 22 additions & 10 deletions tt_metal/hw/firmware/src/brisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
#include "circular_buffer_init.h"
#include "dataflow_api.h"
#include "dev_mem_map.h"
#include "noc_overlay_parameters.h"

#include "debug/watcher_common.h"
#include "debug/waypoint.h"
#include "debug/dprint.h"
#include "debug/stack_usage.h"

#include "debug/ring_buffer.h"
// clang-format on

uint8_t noc_index;
Expand Down Expand Up @@ -342,6 +344,8 @@ inline void finish_ncrisc_copy_and_run(dispatch_core_processor_masks enables) {
}
#endif
}
// Low word of the wall clock, sampled in main() right after wait_ncrisc_trisc().
// brisck.cc declares this extern and, in its DEBUG_NULL_KERNELS path, pushes the wall-clock
// delta since this value to the watcher ring buffer.
// __attribute__((used)) asks the compiler to emit the definition even when it sees no
// reference to it in this translation unit.
uint32_t last_iteration_start __attribute__((used)) = 0;

inline void start_ncrisc_kernel_run(dispatch_core_processor_masks enables) {
#ifdef NCRISC_FIRMWARE_KERNEL_SPLIT
Expand Down Expand Up @@ -439,18 +443,20 @@ int main() {
NOC_X(mailboxes->go_message.master_x),
NOC_Y(mailboxes->go_message.master_y),
DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset);
// WATCHER_RING_BUFFER_PUSH(DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset);
mailboxes->go_message.signal = RUN_MSG_DONE;
// Notify dispatcher that this has been done
DEBUG_SANITIZE_NOC_ADDR(noc_index, dispatch_addr, 4);
noc_fast_atomic_increment(
noc_fast_write_dw_inline<DM_DEDICATED_NOC>(
noc_index,
NCRISC_AT_CMD_BUF,
1 << REMOTE_DEST_BUF_WORDS_FREE_INC,
dispatch_addr,
0xF, // byte-enable
NOC_UNICAST_WRITE_VC,
1,
31 /*wrap*/,
false /*linked*/,
post_atomic_increments /*posted*/);
false, // mcast
true // posted
);
}
}

Expand Down Expand Up @@ -525,6 +531,8 @@ int main() {
int index = static_cast<std::underlying_type<TensixProcessorTypes>::type>(TensixProcessorTypes::DM0);
void (*kernel_address)(uint32_t) = (void (*)(uint32_t))
(kernel_config_base + launch_msg_address->kernel_config.kernel_text_offset[index]);
// WATCHER_RING_BUFFER_PUSH(
// 0x08000000 | (memory_read(RISCV_DEBUG_REG_WALL_CLOCK_L) - last_iteration_start));
(*kernel_address)((uint32_t)kernel_address);
RECORD_STACK_USAGE();
} else {
Expand All @@ -550,6 +558,8 @@ int main() {
WAYPOINT("D");

wait_ncrisc_trisc();
uint32_t before_wait = memory_read(RISCV_DEBUG_REG_WALL_CLOCK_L);
last_iteration_start = before_wait;

trigger_sync_register_init();

Expand Down Expand Up @@ -592,15 +602,17 @@ int main() {
// messages in the ring buffer. Must be executed before the atomic increment, as after that the launch
// message is no longer owned by us.
CLEAR_PREVIOUS_LAUNCH_MESSAGE_ENTRY_FOR_WATCHER();
noc_fast_atomic_increment(
// WATCHER_RING_BUFFER_PUSH(DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset);
noc_fast_write_dw_inline<DM_DEDICATED_NOC>(
noc_index,
NCRISC_AT_CMD_BUF,
1 << REMOTE_DEST_BUF_WORDS_FREE_INC,
dispatch_addr,
0xF, // byte-enable
NOC_UNICAST_WRITE_VC,
1,
31 /*wrap*/,
false /*linked*/,
post_atomic_increments /*posted*/);
false, // mcast
true // posted
);
mailboxes->launch_msg_rd_ptr = (launch_msg_rd_ptr + 1) & (launch_msg_buffer_num_entries - 1);
}
}
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/hw/firmware/src/brisck.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@
#if defined ALIGN_LOCAL_CBS_TO_REMOTE_CBS
#include "remote_circular_buffer_api.h"
#endif
#include "debug/ring_buffer.h"

extern uint32_t last_iteration_start;

void kernel_launch(uint32_t kernel_base_addr) {
#if defined(DEBUG_NULL_KERNELS) && !defined(DISPATCH_KERNEL)
wait_for_go_message();
WATCHER_RING_BUFFER_PUSH(0x08000000 | (memory_read(RISCV_DEBUG_REG_WALL_CLOCK_L) - last_iteration_start));
#ifdef KERNEL_RUN_TIME
uint64_t end_time = c_tensix_core::read_wall_clock() + KERNEL_RUN_TIME;
while (c_tensix_core::read_wall_clock() < end_time);
Expand Down
11 changes: 6 additions & 5 deletions tt_metal/hw/firmware/src/idle_erisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,16 @@ int main() {
DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset);
DEBUG_SANITIZE_NOC_ADDR(noc_index, dispatch_addr, 4);
CLEAR_PREVIOUS_LAUNCH_MESSAGE_ENTRY_FOR_WATCHER();
noc_fast_atomic_increment(
noc_fast_write_dw_inline<DM_DEDICATED_NOC>(
noc_index,
NCRISC_AT_CMD_BUF,
1 << REMOTE_DEST_BUF_WORDS_FREE_INC,
dispatch_addr,
0xF, // byte-enable
NOC_UNICAST_WRITE_VC,
1,
31 /*wrap*/,
false /*linked*/,
true /*posted*/);
false, // mcast
true // posted
);
mailboxes->launch_msg_rd_ptr = (launch_msg_rd_ptr + 1) & (launch_msg_buffer_num_entries - 1);
}
}
Expand Down
Loading

0 comments on commit 127b86b

Please sign in to comment.