Add initial support for Llama3-8B on BH #18697
base: main
Conversation
Force-pushed from e3c9847 to e006a3c
Approved with minor comments
```diff
@@ -118,7 +122,11 @@ def count_devices(output):

     # Count total PCI devices (ignoring N/A)
     total_pci_devices = len(
         [line for line in available_boards.split("\n") if ("Wormhole" or "Grayskull" or "Blackhole") in line]
```
While you're here, want to remove Grayskull?
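Beyond dropping Grayskull, note that `("Wormhole" or "Grayskull" or "Blackhole") in line` does not check all three names: the parenthesized `or` chain evaluates to its first truthy operand, `"Wormhole"`, so the filter never matches Blackhole lines. A minimal sketch of the intended check using `any()` (helper name and board strings are illustrative, not from the PR):

```python
# ("Wormhole" or "Grayskull" or "Blackhole") == "Wormhole", so the original
# list comprehension only counted Wormhole lines. any() tests each name.
ARCH_NAMES = ("Wormhole", "Blackhole")  # Grayskull removed per the review comment


def count_pci_devices(available_boards: str) -> int:
    """Count lines that mention any supported architecture (N/A lines don't match)."""
    return len(
        [
            line
            for line in available_boards.split("\n")
            if any(arch in line for arch in ARCH_NAMES)
        ]
    )


# Illustrative input: one Wormhole board, one Blackhole board, one N/A entry.
boards = "0: Wormhole n150\n1: Blackhole p100\n2: N/A"
print(count_pci_devices(boards))  # 2
```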
```diff
@@ -81,9 +81,10 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
         pc_2 = self.model_config["DECODE_MLP_W2_PRG_CONFIG"]
         pc_3 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"]
     else:  # Update the program configs for prefill
-        if seq_len >= 1024:
+        prefill_len_cutoff = 512 if self.args.arch_name == "blackhole" else 1024
+        if seq_len >= prefill_len_cutoff:
             # Reshape input to fit on device and parallelize computation
```
What do you think about making `prefill_len_cutoff` an attribute of `model_config`, so this calculation isn't duplicated between the MLP and `model_config`?
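The suggestion above could look something like this. A hypothetical sketch (class and attribute names are assumed for illustration, not taken from the PR): compute the arch-dependent cutoff once when the model config is built, and have the MLP read the shared attribute instead of re-deriving it.

```python
class ModelConfig:
    """Hypothetical config object; holds the single source of truth for the cutoff."""

    def __init__(self, arch_name: str):
        self.arch_name = arch_name
        # Computed once here, mirroring the diff: 512 on Blackhole, else 1024.
        self.prefill_len_cutoff = 512 if arch_name == "blackhole" else 1024


class MLP:
    """Hypothetical MLP; no duplicated arch check, just reads the attribute."""

    def __init__(self, model_config: ModelConfig):
        self.model_config = model_config

    def forward(self, seq_len: int) -> str:
        if seq_len >= self.model_config.prefill_len_cutoff:
            return "reshaped-prefill"  # stand-in for the reshape path
        return "prefill"


cfg = ModelConfig("blackhole")
mlp = MLP(cfg)
print(mlp.forward(512))  # reshaped-prefill (cutoff is 512 on blackhole)
```

The design benefit is that a future arch (or tuning of the cutoff) only touches `ModelConfig`, not every consumer.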
```diff
@@ -404,15 +404,6 @@ int main() {
     uint8_t prev_noc_mode = DM_DEDICATED_NOC;
     trigger_sync_register_init();

-#if defined(ARCH_BLACKHOLE)
```
The PR with the proper fix for this is still missing some cases, so I think this revert can be included, since we want the demo changes in main.
Force-pushed from d896b1f to 5e2355a
…hole" This reverts commit 57ba436.
Force-pushed from eff47f0 to abf91cc
Issue: #18135

This PR adds initial support for Llama3-8B on BH.

Changes to `brisc.cc` will be dropped before I merge. They deal with a hang, but I think that will be fixed in a separate commit.

Main things that need review:

CI Tests