Depends on #18730
Reference Test File:
tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py

Cases with Matmul 2D:
"num_devices, num_links, ag_output_shape, dim, layout, matmul_output_dim, max_in0_block_w, matmul_weights_dtype", [ ( 8, 1, [1, 1, 32, 16 * 32], 3, ttnn.TILE_LAYOUT, 1024, 2, ttnn.bfloat16, ), ( 8, 1, [1, 1, 128, 128 * 32], 3, ttnn.TILE_LAYOUT, 1024, 16, ttnn.bfloat16, ), ( 8, 1, [1, 1, 32, 1024 * 16], 3, ttnn.TILE_LAYOUT, 1024, 16, # NOTE: 64 for some reason gives lower perf ttnn.bfloat16, ), ( 8, 1, [1, 1, 1024, 1024 * 32], 3, ttnn.TILE_LAYOUT, 1024, 16, ttnn.bfloat16, ), ( # AllGather + Fused QKV Matmul llama 2k prefill 8, 1, [1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT, 1280, 8, ttnn.bfloat16, ), ( # AllGather + FF1 Matmul llama 1k prefill 8, 1, [1, 1, 1024, 8192], 3, ttnn.TILE_LAYOUT, 4096, 4, ttnn.bfloat16, ),
"num_devices, num_links, ag_output_shape, dim, layout, matmul_output_dim, max_in0_block_w, matmul_weights_dtype", [ ( 8, 1, [1, 1, 32, 16 * 32], 3, ttnn.TILE_LAYOUT, 1024, 2, ttnn.bfloat16, ), ( # Llama decode FF1 8, 1, [1, 1, 32, 1024 * 8], 3, ttnn.TILE_LAYOUT, 4096, 2, # TODO: update ttnn.bfloat4_b, ), ( # Llama decode Fused QKV 8, 1, [1, 1, 32, 1024 * 8], 3, ttnn.TILE_LAYOUT, 1280, 2, # TODO: update ttnn.bfloat4_b, ), ],
"num_devices, num_links, ag_output_shape, dim, layout, matmul_output_dim, max_in0_block_w, matmul_weights_dtype", [ ( # Llama decode Selfout 8, 1, [1, 1, 32, 1024 * 8], 3, ttnn.TILE_LAYOUT, 1024, 8, ttnn.bfloat8_b, ), ( 8, 1, [1, 1, 32, 1024 * 8], 3, ttnn.TILE_LAYOUT, 1024, 32, ttnn.bfloat8_b, ), ], ) @pytest.mark.parametrize( "ag_input_dtype", [ ttnn.bfloat16, ], ) @pytest.mark.parametrize( "mem_config_input, mem_config_ag, mem_config_mm, mem_config_weights", [ ( ttnn.MemoryConfig( ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, ttnn.ShardSpec( ttnn.CoreRangeSet( { ttnn.CoreRange( ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0), ), } ), [ 32, # shard_height 128, # shard width ], ttnn.ShardOrientation.ROW_MAJOR, ), ), ttnn.MemoryConfig( ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, ttnn.ShardSpec( ttnn.CoreRangeSet( { ttnn.CoreRange( ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0), ), } ), [ 32, # shard_height 8192 // 8, # shard_width_hidden_dim_across_8_cores ], ttnn.ShardOrientation.ROW_MAJOR, ), ), ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1), ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), ) ], ids=("llama_selfout",), )