#17536: add matmul output mem config validation #18721

Open · wants to merge 5 commits into main · changes from all commits
10 changes: 7 additions & 3 deletions models/demos/llama3/tt/llama_mlp.py
@@ -90,6 +90,8 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
 
         # In decode mode (seqlen <= 32) do DRAM sharded matmuls
         # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4
+        out_memory_config = x.memory_config()
+        out_memory_config.shard_spec = None
         w1_out = ttnn.linear(
             x,
             self.w1,
@@ -101,7 +103,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_1 else None,
             dtype=ttnn.bfloat8_b if TG else ttnn.bfloat16,
             program_config=pc_1,
-            memory_config=x.memory_config(),
+            memory_config=out_memory_config,
         )
 
         w3_out = ttnn.linear(
@@ -115,7 +117,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_3 else None,
             dtype=ttnn.bfloat16,
             program_config=pc_3,
-            memory_config=x.memory_config(),
+            memory_config=out_memory_config,
         )
         ttnn.deallocate(x)
 
@@ -192,6 +194,8 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
         if mode == "decode":
             w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG)
 
+        w2_out_memory_config = w2_in.memory_config()
+        w2_out_memory_config.shard_spec = None
         w2_out = ttnn.linear(
             w2_in,
             self.w2,
@@ -201,7 +205,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             memory_config=(
                 (ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG)
                 if TG
-                else w2_in.memory_config()
+                else w2_out_memory_config
             ),
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None,
         )
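Both call sites above follow the same pattern: copy the input's memory config, clear its shard spec, and pass that as the requested output config. With the validation added in matmul_op.cpp below, a config without a shard spec is only checked for memory layout and buffer type, so matmul remains free to pick the output shard spec itself. A minimal sketch of the pattern, assuming `x` is a sharded ttnn.Tensor already on device and `weight`/`pc` come from the surrounding model code:

```python
import ttnn

def linear_relaxed(x: "ttnn.Tensor", weight: "ttnn.Tensor", pc) -> "ttnn.Tensor":
    out_memory_config = x.memory_config()  # same layout/buffer type as the input
    out_memory_config.shard_spec = None    # leave the shard spec to matmul
    return ttnn.linear(
        x,
        weight,
        program_config=pc,
        memory_config=out_memory_config,   # validated on layout/buffer type only
    )
```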
7 changes: 6 additions & 1 deletion tests/ttnn/unit_tests/operations/test_experimental.py
@@ -169,6 +169,11 @@ def test_ttnn_matmul_dram_sharded(device, m_size, k_size, n_size):
     shard_shape = (32, 1024)
     shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
     sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, shard_spec)
+    out_shard_shape = (32, 128)
+    out_shard_spec = ttnn.ShardSpec(shard_grid, out_shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
+    out_sharded_mem_config = ttnn.MemoryConfig(
+        ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, out_shard_spec
+    )
     input_tensor_in0 = ttnn.to_memory_config(input_tensor_in0, sharded_mem_config)
 
     # in1 shard config
@@ -203,7 +208,7 @@ def test_ttnn_matmul_dram_sharded(device, m_size, k_size, n_size):
         input_tensor_in0,
         input_tensor_in1,
         program_config=program_config,
-        memory_config=sharded_mem_config,
+        memory_config=out_sharded_mem_config,
         dtype=ttnn.bfloat16,
         compute_kernel_config=compute_kernel_config,
     )
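The separate output config is needed because in0's shard width covers k_size while the output's covers n_size, so reusing `sharded_mem_config` for the output would now trip the new validation. A sketch of the width-sharding arithmetic involved (the core count is an assumption for illustration, not taken from the test):

```python
def width_shard_shape(height: int, width: int, num_cores: int) -> tuple[int, int]:
    # Width sharding keeps the full height on every core and splits the width evenly.
    assert width % num_cores == 0, "width must divide evenly across the shard grid"
    return (height, width // num_cores)

NUM_CORES = 8  # assumed shard grid size
assert width_shard_shape(32, 1024 * NUM_CORES, NUM_CORES) == (32, 1024)  # in0: k_size split
assert width_shard_shape(32, 128 * NUM_CORES, NUM_CORES) == (32, 128)    # out: n_size split
```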
6 changes: 3 additions & 3 deletions tests/ttnn/unit_tests/operations/test_matmul.py
@@ -2089,7 +2089,7 @@ def test_interleaved_input_sharded_output_matmul(device):
 
     out_mem_config = ttnn.create_sharded_memory_config(
         shape=(32, 256),
-        core_grid=ttnn.CoreGrid(x=1, y=8),
+        core_grid=ttnn.CoreGrid(x=8, y=1),
         strategy=ttnn.ShardStrategy.WIDTH,
         orientation=ttnn.ShardOrientation.ROW_MAJOR,
     )
@@ -2100,8 +2100,8 @@ def test_interleaved_input_sharded_output_matmul(device):
 
     # Block sharded
     out_mem_config = ttnn.create_sharded_memory_config(
-        shape=(32, 256),
-        core_grid=ttnn.CoreGrid(x=1, y=8),
+        shape=(256, 256),
+        core_grid=ttnn.CoreGrid(x=1, y=1),
         strategy=ttnn.ShardStrategy.BLOCK,
         orientation=ttnn.ShardOrientation.ROW_MAJOR,
     )
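Both fixes bring the requested config in line with what matmul computes: the (32, 256) width-sharded output apparently wants its shards laid out along an x=8 row of cores, and a single-core block shard must cover the whole (256, 256) output. A hedged model of the per-core shard shapes these configs describe (strategy semantics inferred, not taken from the PR):

```python
def per_core_shard(shape: tuple[int, int], grid_x: int, grid_y: int, strategy: str) -> tuple[int, int]:
    m, n = shape
    if strategy == "width":
        return (m, n // (grid_x * grid_y))   # split the last dim across all cores
    if strategy == "block":
        return (m // grid_y, n // grid_x)    # 2D split across the grid
    raise ValueError(f"unknown strategy: {strategy}")

assert per_core_shard((32, 256), grid_x=8, grid_y=1, strategy="width") == (32, 32)
assert per_core_shard((256, 256), grid_x=1, grid_y=1, strategy="block") == (256, 256)
```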
20 changes: 19 additions & 1 deletion ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1426,10 +1426,10 @@ void Matmul::validate(
 
     TT_FATAL(optional_input_tensors.size() == 1, "Error");
 
+    const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
     if (is_optional_output_tensor) {
         const auto& optional_output_tensor_c = optional_output_tensors.at(0);
         const auto& optional_output_tensor_shape = optional_output_tensor_c->get_logical_shape();
-        const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}, optional_input_tensors).at(0);
         TT_FATAL(
             optional_output_tensor_shape == output_tensor_spec.logical_shape(),
             "Shape of Optional Output Tensor {} doesnt match Output Tensor {}",
@@ -1446,6 +1446,24 @@ void Matmul::validate(
             "tensor {}",
             optional_output_tensor_c->memory_config(),
             this->output_mem_config);
+    } else if (this->output_mem_config.shard_spec.has_value()) {
+        TT_FATAL(
+            output_tensor_spec.memory_config() == this->output_mem_config,
+            "Mismatch between computed {} and provided {} mem config",
+            output_tensor_spec.memory_config(),
+            this->output_mem_config);
+    } else {
+        // TODO: try to change these to fatals and fix test_llama_model.py in APC
+        TT_ASSERT(
+            output_tensor_spec.memory_config().memory_layout == this->output_mem_config.memory_layout,
+            "Mismatch between computed {} and provided {} mem config memory layout",
+            output_tensor_spec.memory_config().memory_layout,
+            this->output_mem_config.memory_layout);
+        TT_ASSERT(
+            output_tensor_spec.memory_config().buffer_type == this->output_mem_config.buffer_type,
+            "Mismatch between computed {} and provided {} mem config buffer type",
+            output_tensor_spec.memory_config().buffer_type,
+            this->output_mem_config.buffer_type);
     }
 
     TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been automatically populated");
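Net effect of the new branches: when no optional output tensor is passed, a caller-provided memory config with a shard spec must match the computed one exactly (TT_FATAL), while a config without a shard spec is checked only for memory layout and buffer type, and only via TT_ASSERT for now per the TODO. A small Python model of that branch logic (names mirror the C++; this is not the real ttnn API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MemConfig:  # stand-in for ttnn's MemoryConfig
    memory_layout: str
    buffer_type: str
    shard_spec: Optional[tuple] = None

def validate_output_mem_config(computed: MemConfig, provided: MemConfig) -> None:
    if provided.shard_spec is not None:
        # Caller pinned a shard spec: require full equality (TT_FATAL in the C++).
        if computed != provided:
            raise RuntimeError(f"Mismatch between computed {computed} and provided {provided} mem config")
    else:
        # No shard spec: compare only layout and buffer type (TT_ASSERT, i.e.
        # debug-only until the TODO above is resolved).
        assert computed.memory_layout == provided.memory_layout, "memory layout mismatch"
        assert computed.buffer_type == provided.buffer_type, "buffer type mismatch"
```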
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
@@ -270,6 +270,7 @@ void py_module(py::module& module) {
             - Note: there are various additional constraints related to specific program
               configs chosen. Please look at the error messages carefully and fix
               problems appropriately.
+            - Note: if memory_config is provided, it will be compared to what matmul wants to use. The shard spec can be set to None to not compare it.
             - Note: If optional output tensor is specified, then dtype and memory config need to be checked as follows:
               - if they are default then they should be set based on optional output tensor
               - if the are not default then they should be compared and if there is a difference an error is reported