Port All-Gather Matmul to be Async, Use Fabric #18740

Open · 2 tasks

SeanNijjar opened this issue Mar 6, 2025 · 0 comments
Assignees
SeanNijjar

Labels
llm_t3000, P1, perf (for issues tracking performance problems/improvements)

Comments

SeanNijjar (Contributor) commented Mar 6, 2025

  • Port program factory
  • Thread synchronization from all-gather to matmul
    • This is a good opportunity to leverage the output async mode and redirect the output sync to the matmul (a sketch follows the dependency note below)

Depends on #18730
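
For orientation, here is a minimal sketch of the unfused reference path the fused op has to reproduce, using only standard ttnn host APIs (ttnn.all_gather followed by ttnn.matmul); the point of the async/fabric port is to fuse these two steps so the matmul can start consuming gathered data instead of waiting on a full-tensor sync. The function name and argument list are illustrative, not taken from the test file:

    import ttnn

    def unfused_reference(ag_input, weights, dim, num_links, mem_config_ag, mem_config_mm):
        # Reference path: gather the activation across the ring, then matmul the
        # gathered activation against the per-device weights. The ported op fuses
        # and overlaps these two steps rather than synchronizing in between.
        gathered = ttnn.all_gather(
            ag_input,
            dim,
            num_links=num_links,
            memory_config=mem_config_ag,
        )
        return ttnn.matmul(gathered, weights, memory_config=mem_config_mm)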

Reference Test File:

tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py

  • All test cases should pass before this work-item is marked complete
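
The reference suite can be driven straight from the repo root; shown below as a pytest.main call to keep the examples in Python, though the equivalent shell pytest command works just as well:

    import pytest

    # Run only the reference file for this work item; -q keeps the output terse.
    pytest.main(["-q", "tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py"])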

Cases with matmul 2D:

    "num_devices, num_links, ag_output_shape, dim, layout, matmul_output_dim, max_in0_block_w, matmul_weights_dtype",
    [
        (
            8,
            1,
            [1, 1, 32, 16 * 32],
            3,
            ttnn.TILE_LAYOUT,
            1024,
            2,
            ttnn.bfloat16,
        ),
        (
            8,
            1,
            [1, 1, 128, 128 * 32],
            3,
            ttnn.TILE_LAYOUT,
            1024,
            16,
            ttnn.bfloat16,
        ),
        (
            8,
            1,
            [1, 1, 32, 1024 * 16],
            3,
            ttnn.TILE_LAYOUT,
            1024,
            16,  # NOTE: 64 for some reason gives lower perf
            ttnn.bfloat16,
        ),
        (
            8,
            1,
            [1, 1, 1024, 1024 * 32],
            3,
            ttnn.TILE_LAYOUT,
            1024,
            16,
            ttnn.bfloat16,
        ),
        (  # AllGather + Fused QKV Matmul llama 2k prefill
            8,
            1,
            [1, 1, 2048, 8192],
            3,
            ttnn.TILE_LAYOUT,
            1280,
            8,
            ttnn.bfloat16,
        ),
        (  # AllGather + FF1 Matmul llama 1k prefill
            8,
            1,
            [1, 1, 1024, 8192],
            3,
            ttnn.TILE_LAYOUT,
            4096,
            4,
            ttnn.bfloat16,
        ),
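
For reference, a hedged sketch of how the matmul_output_dim and max_in0_block_w columns above could be turned into a 2D mcast matmul program config. The helper name, the (8, 8) grid, and the subblock sizes are assumptions rather than values taken from the test file:

    import math
    import ttnn

    def make_matmul_2d_program_config(ag_output_shape, matmul_output_dim, max_in0_block_w, grid=(8, 8)):
        # Hypothetical helper: derives a 2D mcast matmul program config from the
        # parametrized values above. Grid and subblock sizes are assumptions.
        grid_x, grid_y = grid
        m_tiles = ag_output_shape[2] // 32   # activation height in tiles
        n_tiles = matmul_output_dim // 32    # per-device output width in tiles
        return ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
            compute_with_storage_grid_size=grid,
            in0_block_w=max_in0_block_w,     # bounded by the table's max_in0_block_w
            out_subblock_h=1,
            out_subblock_w=1,
            per_core_M=max(1, math.ceil(m_tiles / grid_y)),
            per_core_N=max(1, math.ceil(n_tiles / grid_x)),
            transpose_mcast=False,
            fused_activation=None,
        )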

Cases with matmul 1D:

    "num_devices, num_links, ag_output_shape, dim, layout, matmul_output_dim, max_in0_block_w, matmul_weights_dtype",
    [
        (
            8,
            1,
            [1, 1, 32, 16 * 32],
            3,
            ttnn.TILE_LAYOUT,
            1024,
            2,
            ttnn.bfloat16,
        ),
        (  # Llama decode FF1
            8,
            1,
            [1, 1, 32, 1024 * 8],
            3,
            ttnn.TILE_LAYOUT,
            4096,
            2,  # TODO: update
            ttnn.bfloat4_b,
        ),
        (  # Llama decode Fused QKV
            8,
            1,
            [1, 1, 32, 1024 * 8],
            3,
            ttnn.TILE_LAYOUT,
            1280,
            2,  # TODO: update
            ttnn.bfloat4_b,
        ),
    ],

Llama shapes:

    "num_devices, num_links, ag_output_shape, dim, layout, matmul_output_dim, max_in0_block_w, matmul_weights_dtype",
    [
        (  # Llama decode Selfout
            8,
            1,
            [1, 1, 32, 1024 * 8],
            3,
            ttnn.TILE_LAYOUT,
            1024,
            8,
            ttnn.bfloat8_b,
        ),
        (
            8,
            1,
            [1, 1, 32, 1024 * 8],
            3,
            ttnn.TILE_LAYOUT,
            1024,
            32,
            ttnn.bfloat8_b,
        ),
    ],
)
@pytest.mark.parametrize(
    "ag_input_dtype",
    [
        ttnn.bfloat16,
    ],
)
@pytest.mark.parametrize(
    "mem_config_input, mem_config_ag, mem_config_mm, mem_config_weights",
    [
        (
            ttnn.MemoryConfig(
                ttnn.TensorMemoryLayout.WIDTH_SHARDED,
                ttnn.BufferType.L1,
                ttnn.ShardSpec(
                    ttnn.CoreRangeSet(
                        {
                            ttnn.CoreRange(
                                ttnn.CoreCoord(0, 0),
                                ttnn.CoreCoord(7, 0),
                            ),
                        }
                    ),
                    [
                        32,  # shard_height
                        128,  # shard width
                    ],
                    ttnn.ShardOrientation.ROW_MAJOR,
                ),
            ),
            ttnn.MemoryConfig(
                ttnn.TensorMemoryLayout.WIDTH_SHARDED,
                ttnn.BufferType.L1,
                ttnn.ShardSpec(
                    ttnn.CoreRangeSet(
                        {
                            ttnn.CoreRange(
                                ttnn.CoreCoord(0, 0),
                                ttnn.CoreCoord(7, 0),
                            ),
                        }
                    ),
                    [
                        32,  # shard_height
                        8192 // 8,  # shard_width_hidden_dim_across_8_cores
                    ],
                    ttnn.ShardOrientation.ROW_MAJOR,
                ),
            ),
            ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1),
            ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        )
    ],
    ids=("llama_selfout",),
)
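
Finally, a rough sketch of the correctness check implied by "all test cases should pass", assuming the weights are width-sharded across the mesh so that per-device matmul outputs concatenate along the last dim; the mesh composer dim and tolerances are assumptions:

    import torch
    import ttnn

    def check_against_torch(tt_mm_out, torch_activation, torch_weights, mesh_device):
        # Golden: the all-gather replicates the full activation on every device,
        # so the expected result is a plain torch matmul.
        expected = torch.matmul(torch_activation.float(), torch_weights.float())

        # Assumption: each device holds a width shard of the matmul output, so
        # concatenating along the last dim reassembles the full result.
        actual = ttnn.to_torch(
            tt_mm_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=3)
        )

        # bfloat8_b / bfloat4_b weights quantize heavily, so use loose tolerances.
        assert torch.allclose(expected, actual.float(), rtol=1e-1, atol=1e-1)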
SeanNijjar added the llm_t3000, P1, and perf labels on Mar 6, 2025
SeanNijjar self-assigned this on Mar 6, 2025