Add initial support for Llama3-8B on BH #18697

Open

mtairum wants to merge 18 commits into main from mtairum/llama3-blackhole

Conversation

mtairum (Contributor) commented Mar 6, 2025

Issue: #18135

This PR adds initial support for Llama3-8B on BH.

Changes to brisc.cc will be dropped before I merge. They work around a hang, but I think that will be fixed in a separate commit.

Main things that need review:

  • conftest (updated the mesh_device to work on BH)
  • Model tests on BH-Post-Commit (right now, only llama-8B should run) (testing to see if it's working as intended 😅 )

CI Tests

@mtairum mtairum force-pushed the mtairum/llama3-blackhole branch from e3c9847 to e006a3c on March 6, 2025 15:56
@mtairum mtairum changed the title Add support for Llama3-8B on BH Add initial support for Llama3-8B on BH Mar 6, 2025
@mtairum mtairum self-assigned this Mar 6, 2025
@mtairum mtairum marked this pull request as ready for review March 6, 2025 18:03
cglagovichTT (Contributor) left a comment

Approved with minor comments

@@ -118,7 +122,11 @@ def count_devices(output):

# Count total PCI devices (ignoring N/A)
total_pci_devices = len(
[line for line in available_boards.split("\n") if ("Wormhole" or "Grayskull" or "Blackhole") in line]
While you're here, want to remove Grayskull?
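
As a side note on the hunk above, the chained or inside the membership test only ever evaluates "Wormhole", so the other architecture names are never actually checked. A minimal sketch of an explicit per-name check (with Grayskull dropped as suggested; this is an illustration, not the code in the PR):

# Check each architecture name explicitly; the chained `or` form only tests for "Wormhole".
total_pci_devices = len(
    [line for line in available_boards.split("\n") if any(arch in line for arch in ("Wormhole", "Blackhole"))]
)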

@@ -81,9 +81,10 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
pc_2 = self.model_config["DECODE_MLP_W2_PRG_CONFIG"]
pc_3 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"]
else: # Update the program configs based for prefill
if seq_len >= 1024:
prefill_len_cutoff = 512 if self.args.arch_name == "blackhole" else 1024
if seq_len >= prefill_len_cutoff:
# Reshape input to fit on device and parallelize computation
What do you think about making prefill_len_cutoff an attribute of model_config, so this calculation isn't duplicated between the MLP and model_config?
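
For illustration, a minimal sketch of that suggestion (PREFILL_LEN_CUTOFF is a hypothetical key name, not an existing model_config entry; attribute names follow the diff above):

# Hypothetical: compute the cutoff once when the model config is built
# (512 for Blackhole, 1024 otherwise) instead of recomputing it per module.
self.model_config["PREFILL_LEN_CUTOFF"] = 512 if self.args.arch_name == "blackhole" else 1024

# The MLP prefill path (and model_config itself) can then read the shared value:
if seq_len >= self.model_config["PREFILL_LEN_CUTOFF"]:
    # Reshape input to fit on device and parallelize computation
    ...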

@@ -404,15 +404,6 @@ int main() {
uint8_t prev_noc_mode = DM_DEDICATED_NOC;
trigger_sync_register_init();


#if defined(ARCH_BLACKHOLE)
The PR with the proper fix for this is still missing some cases, so I think this revert can be included, since we want the demo changes in main.

cc @jbaumanTT @tt-aho

@mtairum mtairum force-pushed the mtairum/llama3-blackhole branch from d896b1f to 5e2355a on March 7, 2025 14:08
@mtairum mtairum force-pushed the mtairum/llama3-blackhole branch from eff47f0 to abf91cc on March 7, 2025 18:10