diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py
index 4ea55b8865b..664b080e92c 100644
--- a/models/demos/llama3/tt/llama_mlp.py
+++ b/models/demos/llama3/tt/llama_mlp.py
@@ -90,6 +90,8 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
 
         # In decode mode (seqlen <= 32) do DRAM sharded matmuls
         # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4
+        out_memory_config = x.memory_config()
+        out_memory_config.shard_spec = None
         w1_out = ttnn.linear(
             x,
             self.w1,
@@ -101,7 +103,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_1 else None,
             dtype=ttnn.bfloat8_b if TG else ttnn.bfloat16,
             program_config=pc_1,
-            memory_config=x.memory_config(),
+            memory_config=out_memory_config,
         )
 
         w3_out = ttnn.linear(
@@ -115,7 +117,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_3 else None,
             dtype=ttnn.bfloat16,
             program_config=pc_3,
-            memory_config=x.memory_config(),
+            memory_config=out_memory_config,
         )
 
         ttnn.deallocate(x)
@@ -192,6 +194,8 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
         if mode == "decode":
             w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG)
 
+        w2_out_memory_config = w2_in.memory_config()
+        w2_out_memory_config.shard_spec = None
         w2_out = ttnn.linear(
             w2_in,
             self.w2,
@@ -201,7 +205,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             memory_config=(
                 (ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG)
                 if TG
-                else w2_in.memory_config()
+                else w2_out_memory_config
             ),
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None,
         )
diff --git a/tests/ttnn/unit_tests/operations/test_experimental.py b/tests/ttnn/unit_tests/operations/test_experimental.py
index a8e76c120fa..f30ea9f01bf 100644
--- a/tests/ttnn/unit_tests/operations/test_experimental.py
+++ b/tests/ttnn/unit_tests/operations/test_experimental.py
@@ -169,6 +169,11 @@ def test_ttnn_matmul_dram_sharded(device, m_size, k_size, n_size):
     shard_shape = (32, 1024)
     shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
     sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, shard_spec)
+    out_shard_shape = (32, 128)
+    out_shard_spec = ttnn.ShardSpec(shard_grid, out_shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
+    out_sharded_mem_config = ttnn.MemoryConfig(
+        ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, out_shard_spec
+    )
     input_tensor_in0 = ttnn.to_memory_config(input_tensor_in0, sharded_mem_config)
 
     # in1 shard config
@@ -203,7 +208,7 @@ def test_ttnn_matmul_dram_sharded(device, m_size, k_size, n_size):
         input_tensor_in0,
         input_tensor_in1,
         program_config=program_config,
-        memory_config=sharded_mem_config,
+        memory_config=out_sharded_mem_config,
         dtype=ttnn.bfloat16,
         compute_kernel_config=compute_kernel_config,
     )
diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py
index 1bb4cb64bf6..cc775e2ec53 100644
--- a/tests/ttnn/unit_tests/operations/test_matmul.py
+++ b/tests/ttnn/unit_tests/operations/test_matmul.py
@@ -2089,7 +2089,7 @@ def test_interleaved_input_sharded_output_matmul(device):
 
     out_mem_config = ttnn.create_sharded_memory_config(
         shape=(32, 256),
-        core_grid=ttnn.CoreGrid(x=1, y=8),
+        core_grid=ttnn.CoreGrid(x=8, y=1),
         strategy=ttnn.ShardStrategy.WIDTH,
         orientation=ttnn.ShardOrientation.ROW_MAJOR,
     )
@@ -2100,8 +2100,8 @@ def test_interleaved_input_sharded_output_matmul(device):
 
     # Block sharded
     out_mem_config = ttnn.create_sharded_memory_config(
-        shape=(32, 256),
-        core_grid=ttnn.CoreGrid(x=1, y=8),
+        shape=(256, 256),
+        core_grid=ttnn.CoreGrid(x=1, y=1),
         strategy=ttnn.ShardStrategy.BLOCK,
         orientation=ttnn.ShardOrientation.ROW_MAJOR,
     )
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
index 3d72eb55267..f86614e667f 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1426,10 +1426,10 @@ void Matmul::validate(
 
     TT_FATAL(optional_input_tensors.size() == 1, "Error");
 
+    const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
     if (is_optional_output_tensor) {
         const auto& optional_output_tensor_c = optional_output_tensors.at(0);
         const auto& optional_output_tensor_shape = optional_output_tensor_c->get_logical_shape();
-        const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
         TT_FATAL(
             optional_output_tensor_shape == output_tensor_spec.logical_shape(),
             "Shape of Optional Output Tensor {} doesnt match Output Tensor {}",
@@ -1446,6 +1446,23 @@ void Matmul::validate(
             "tensor {}",
             optional_output_tensor_c->memory_config(),
             this->output_mem_config);
+    } else {
+        TT_FATAL(
+            output_tensor_spec.memory_config().memory_layout == this->output_mem_config.memory_layout,
+            "Mismatch between computed {} and provided {} mem config memory layout",
+            output_tensor_spec.memory_config().memory_layout,
+            this->output_mem_config.memory_layout);
+        TT_FATAL(
+            output_tensor_spec.memory_config().buffer_type == this->output_mem_config.buffer_type,
+            "Mismatch between computed {} and provided {} mem config buffer type",
+            output_tensor_spec.memory_config().buffer_type,
+            this->output_mem_config.buffer_type);
+        // Needs to be an assert, there are too many existing models not specifying shard spec correctly.
+        TT_ASSERT(
+            output_tensor_spec.memory_config() == this->output_mem_config,
+            "Mismatch between computed {} and provided {} mem config",
+            output_tensor_spec.memory_config(),
+            this->output_mem_config);
     }
 
     TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been automatically populated");
diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
index 5dff12923c7..d729a5f5247 100644
--- a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
@@ -270,6 +270,7 @@ void py_module(py::module& module) {
 
             - Note: there are various additional constraints related to specific program configs chosen.
               Please look at the error messages carefully and fix problems appropriately.
+            - Note: if memory_config is provided, it will be compared to what matmul wants to use. The shard spec can be set to None to not compare it.
            - Note: If optional output tensor is specified, then dtype and memory config need to be checked as follows:
              - if they are default then they should be set based on optional output tensor
              - if the are not default then they should be compared and if there is a difference an error is reported
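
Usage sketch (not part of the patch): how the new memory_config comparison behaves from Python, assuming a single available Tenstorrent device. The shapes and the 8x1 grid mirror the width-sharded case fixed in test_matmul.py above and are illustrative only; whether the exact config passes depends on what matmul computes for these inputs.

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    a = ttnn.from_torch(torch.randn(32, 1024), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    b = ttnn.from_torch(torch.randn(1024, 256), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    # Width-sharded output config intended to match what matmul computes for
    # a (32, 256) output on an 8x1 grid: one (32, 32) shard per core.
    out_mem_config = ttnn.create_sharded_memory_config(
        shape=(32, 256),
        core_grid=ttnn.CoreGrid(x=8, y=1),
        strategy=ttnn.ShardStrategy.WIDTH,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
    )
    # Full comparison: layout, buffer type, and (in debug builds) shard spec
    # are all checked against matmul's computed output config.
    out = ttnn.matmul(a, b, memory_config=out_mem_config)

    # Clearing the shard spec keeps the memory-layout and buffer-type checks
    # but lets matmul derive the shard spec itself.
    out_mem_config.shard_spec = None
    out = ttnn.matmul(a, b, memory_config=out_mem_config)

    ttnn.close_device(device)

Making the shard-spec comparison a TT_ASSERT (active in debug builds) rather than a TT_FATAL avoids breaking existing models that pass a stale shard spec, as the validate() comment notes, while the memory-layout and buffer-type checks still fail fast on real mismatches.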