From 3523ad74358cbfec1ab434fa79b05238185c8eeb Mon Sep 17 00:00:00 2001
From: Borys Bradel
Date: Thu, 6 Mar 2025 17:15:39 +0000
Subject: [PATCH] #17536: add extra matmul output mem config check, doc, and
 fix test

---
 models/demos/llama3/tt/llama_mlp.py               | 10 +++++++---
 .../ttnn/operations/matmul/device/matmul_op.cpp   | 14 ++++++++++++--
 ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp |  1 +
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py
index 4ea55b8865b..664b080e92c 100644
--- a/models/demos/llama3/tt/llama_mlp.py
+++ b/models/demos/llama3/tt/llama_mlp.py
@@ -90,6 +90,8 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
 
         # In decode mode (seqlen <= 32) do DRAM sharded matmuls
         # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4
+        out_memory_config = x.memory_config()
+        out_memory_config.shard_spec = None
         w1_out = ttnn.linear(
             x,
             self.w1,
@@ -101,7 +103,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_1 else None,
             dtype=ttnn.bfloat8_b if TG else ttnn.bfloat16,
             program_config=pc_1,
-            memory_config=x.memory_config(),
+            memory_config=out_memory_config,
         )
 
         w3_out = ttnn.linear(
@@ -115,7 +117,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_3 else None,
             dtype=ttnn.bfloat16,
             program_config=pc_3,
-            memory_config=x.memory_config(),
+            memory_config=out_memory_config,
         )
 
         ttnn.deallocate(x)
@@ -192,6 +194,8 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
         if mode == "decode":
             w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG)
 
+        w2_out_memory_config = w2_in.memory_config()
+        w2_out_memory_config.shard_spec = None
         w2_out = ttnn.linear(
             w2_in,
             self.w2,
@@ -201,7 +205,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             memory_config=(
                 (ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG)
                 if TG
-                else w2_in.memory_config()
+                else w2_out_memory_config
             ),
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None,
         )
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
index 7df4dbbb38d..8e48d2bec19 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1426,10 +1426,10 @@ void Matmul::validate(
     TT_FATAL(optional_input_tensors.size() == 1, "Error");
 
+    const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
     if (is_optional_output_tensor) {
         const auto& optional_output_tensor_c = optional_output_tensors.at(0);
         const auto& optional_output_tensor_shape = optional_output_tensor_c->get_logical_shape();
-        const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
         TT_FATAL(
             optional_output_tensor_shape == output_tensor_spec.logical_shape(),
             "Shape of Optional Output Tensor {} doesnt match Output Tensor {}",
             optional_output_tensor_shape,
@@ -1447,12 +1447,22 @@ void Matmul::validate(
             optional_output_tensor_c->memory_config(),
             this->output_mem_config);
     } else if (this->output_mem_config.shard_spec.has_value()) {
-        const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
         TT_FATAL(
             output_tensor_spec.memory_config() == this->output_mem_config,
             "Mismatch between computed {} and provided {} mem config",
computed {} and provided {} mem config", output_tensor_spec.memory_config(), this->output_mem_config); + } else { + TT_FATAL( + output_tensor_spec.memory_config().memory_layout == this->output_mem_config.memory_layout, + "Mismatch between computed {} and provided {} mem config memory layout", + output_tensor_spec.memory_config().memory_layout, + this->output_mem_config.memory_layout); + TT_FATAL( + output_tensor_spec.memory_config().buffer_type == this->output_mem_config.buffer_type, + "Mismatch between computed {} and provided {} mem config buffer type", + output_tensor_spec.memory_config().buffer_type, + this->output_mem_config.buffer_type); } TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been automatically populated"); diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp index 5dff12923c7..d729a5f5247 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp @@ -270,6 +270,7 @@ void py_module(py::module& module) { - Note: there are various additional constraints related to specific program configs chosen. Please look at the error messages carefully and fix problems appropriately. + - Note: if memory_config is provided, it will be compared to what matmul wants to use. The shard spec can be set to None to not compare it. - Note: If optional output tensor is specified, then dtype and memory config need to be checked as follows: - if they are default then they should be set based on optional output tensor - if the are not default then they should be compared and if there is a difference an error is reported