Commit 3523ad7

#17536: add extra matmul output mem config check, doc, and fix test

bbradelTT committed Mar 6, 2025
1 parent 0e8d833 commit 3523ad7
Showing 3 changed files with 20 additions and 5 deletions.
10 changes: 7 additions & 3 deletions models/demos/llama3/tt/llama_mlp.py
@@ -90,6 +90,8 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
 
         # In decode mode (seqlen <= 32) do DRAM sharded matmuls
         # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4
+        out_memory_config = x.memory_config()
+        out_memory_config.shard_spec = None
         w1_out = ttnn.linear(
             x,
             self.w1,
@@ -101,7 +103,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_1 else None,
             dtype=ttnn.bfloat8_b if TG else ttnn.bfloat16,
             program_config=pc_1,
-            memory_config=x.memory_config(),
+            memory_config=out_memory_config,
         )
 
         w3_out = ttnn.linear(
@@ -115,7 +117,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_3 else None,
             dtype=ttnn.bfloat16,
             program_config=pc_3,
-            memory_config=x.memory_config(),
+            memory_config=out_memory_config,
         )
         ttnn.deallocate(x)
 
@@ -192,6 +194,8 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
         if mode == "decode":
             w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG)
 
+        w2_out_memory_config = w2_in.memory_config()
+        w2_out_memory_config.shard_spec = None
         w2_out = ttnn.linear(
             w2_in,
             self.w2,
@@ -201,7 +205,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             memory_config=(
                 (ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG)
                 if TG
-                else w2_in.memory_config()
+                else w2_out_memory_config
             ),
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None,
         )
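The llama_mlp changes all follow one pattern: reuse the input's memory config for the matmul output but clear its shard spec, so the stricter shard-spec comparison added to Matmul::validate below is skipped. A minimal sketch of that pattern, assuming a ttnn device is already open and x and w are existing device tensors (setup omitted; the function name is illustrative, not from the commit):

    import ttnn

    def linear_with_relaxed_mem_check(x: ttnn.Tensor, w: ttnn.Tensor) -> ttnn.Tensor:
        # Start from the input's memory config (layout, buffer type, shard spec).
        out_memory_config = x.memory_config()
        # Clearing the shard spec asks matmul to match only the memory layout
        # and buffer type; the exact output shard spec is left to matmul.
        out_memory_config.shard_spec = None
        return ttnn.linear(x, w, memory_config=out_memory_config)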
14 changes: 12 additions & 2 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1426,10 +1426,10 @@ void Matmul::validate(
 
     TT_FATAL(optional_input_tensors.size() == 1, "Error");
 
+    const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
     if (is_optional_output_tensor) {
         const auto& optional_output_tensor_c = optional_output_tensors.at(0);
         const auto& optional_output_tensor_shape = optional_output_tensor_c->get_logical_shape();
-        const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
         TT_FATAL(
             optional_output_tensor_shape == output_tensor_spec.logical_shape(),
             "Shape of Optional Output Tensor {} doesnt match Output Tensor {}",
@@ -1447,12 +1447,22 @@ void Matmul::validate(
             optional_output_tensor_c->memory_config(),
             this->output_mem_config);
     } else if (this->output_mem_config.shard_spec.has_value()) {
-        const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
         TT_FATAL(
             output_tensor_spec.memory_config() == this->output_mem_config,
             "Mismatch between computed {} and provided {} mem config",
             output_tensor_spec.memory_config(),
             this->output_mem_config);
+    } else {
+        TT_FATAL(
+            output_tensor_spec.memory_config().memory_layout == this->output_mem_config.memory_layout,
+            "Mismatch between computed {} and provided {} mem config memory layout",
+            output_tensor_spec.memory_config().memory_layout,
+            this->output_mem_config.memory_layout);
+        TT_FATAL(
+            output_tensor_spec.memory_config().buffer_type == this->output_mem_config.buffer_type,
+            "Mismatch between computed {} and provided {} mem config buffer type",
+            output_tensor_spec.memory_config().buffer_type,
+            this->output_mem_config.buffer_type);
     }
 
     TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been automatically populated");
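In effect, validation of a user-provided output memory config now has three branches. A Python paraphrase of the C++ logic above (illustrative only; the real checks are the TT_FATAL chain in Matmul::validate):

    def check_output_mem_config(computed, provided):
        # Branch 1 (an optional output tensor was supplied) is unchanged
        # by this commit and omitted here.
        if provided.shard_spec is not None:
            # A full shard spec was requested: require an exact match.
            assert computed == provided, "Mismatch between computed and provided mem config"
        else:
            # New branch: no shard spec was requested, so compare only the
            # memory layout and the buffer type.
            assert computed.memory_layout == provided.memory_layout, "memory layout mismatch"
            assert computed.buffer_type == provided.buffer_type, "buffer type mismatch"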
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
@@ -270,6 +270,7 @@ void py_module(py::module& module) {
               - Note: there are various additional constraints related to specific program
                 configs chosen. Please look at the error messages carefully and fix
                 problems appropriately.
+              - Note: if memory_config is provided, it will be compared to the memory config matmul wants to use. The shard spec can be set to None so that it is not compared.
               - Note: If optional output tensor is specified, then dtype and memory config need to be checked as follows:
                 - if they are default then they should be set based on optional output tensor
                 - if they are not default then they should be compared and if there is a difference an error is reported
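A hypothetical caller-side example of the documented behavior: construct a memory config with no shard spec, so matmul validates only the memory layout and buffer type (tensor setup omitted; a and b are assumed to be existing device tensors):

    import ttnn

    # Request an L1 interleaved output without pinning a shard spec; matmul
    # validation then checks only TensorMemoryLayout and BufferType.
    mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1)
    out = ttnn.matmul(a, b, memory_config=mem_config)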
