Fix regression on conv1d auto-shard

Add a conv1d PyTorch sweep to the nightly tests to prevent further regressions.
Pavle Josipovic committed Mar 6, 2025
1 parent 769d3d0 commit c3b830e
Showing 5 changed files with 57 additions and 5 deletions.
16 changes: 16 additions & 0 deletions tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
@@ -1636,3 +1636,19 @@ def test_conv2d_localrun_fail_only(device, input_spec):
        device,
    )[0]
    assert pcc, messsage


@pytest.mark.parametrize("input_spec", parameters["short_sweep_suite_conv1d"]["input_specs"])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_conv2d_localrun_conv1d(device, input_spec):
    pcc, message = run_conv1d_short_sweep(
        input_spec,
        device,
    )[0]
    assert pcc, message


failing_parameters_conv1d = [
    # [batch_size, output_channels, input_channels, input_length, kernel_size, stride, pad, groups, dilation, bias]
    [1, 768, 768, 3000, 3, 2, 1, 1, 1, True],
]
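Each input_spec row in `short_sweep_suite_conv1d` follows the field order given in the comment above. As a rough, illustrative sketch (the unpacking order is assumed from that comment, and `run_conv1d_short_sweep` may build its reference differently), the failing spec listed here corresponds to a PyTorch reference along these lines:

```python
import torch


# Hypothetical helper: build a torch reference output from one sweep spec.
# Field order is assumed from the comment in failing_parameters_conv1d above.
def torch_conv1d_reference(input_spec):
    (batch_size, output_channels, input_channels, input_length,
     kernel_size, stride, pad, groups, dilation, bias) = input_spec

    torch.manual_seed(0)
    x = torch.randn(batch_size, input_channels, input_length)
    conv = torch.nn.Conv1d(
        input_channels,
        output_channels,
        kernel_size,
        stride=stride,
        padding=pad,
        groups=groups,
        dilation=dilation,
        bias=bias,
    )
    return conv(x)


# The failing spec above: a stride-2 conv over a length-3000 signal.
out = torch_conv1d_reference([1, 768, 768, 3000, 3, 2, 1, 1, 1, True])
print(out.shape)  # torch.Size([1, 768, 1500])
```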
@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from tests.sweep_framework.sweep_utils.conv2d_common import run_conv1d_short_sweep
from tests.sweep_framework.sweeps.conv2d.short.conv2d_short_sweep import parameters as parameters_ttnn_pytorch
from tests.sweep_framework.sweeps.conv2d.short.conv2d_short_sweep import (
    failing_parameters_conv1d as failing_parameters_ttnn_pytorch,
)

from models.utility_functions import (
    is_wormhole_b0,
)

import pytest


@pytest.mark.parametrize("input_spec", parameters_ttnn_pytorch["short_sweep_suite_conv1d"]["input_specs"])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_ttnn_pytorch_sweep(device, input_spec):
    if device.core_grid.y != 8 and is_wormhole_b0():
        pytest.skip("Needs 8x8 grid for wormhole_b0")

    # Check if input_spec is in failing_parameters
    if input_spec in failing_parameters_ttnn_pytorch:
        pytest.skip(f"Skipping test for failing input_spec: {input_spec}")

    pcc, message = run_conv1d_short_sweep(
        input_spec,
        device,
    )[0]
    assert pcc, message
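For reference, the `(pcc, message)` pair unpacked above comes from `run_conv1d_short_sweep`, which compares the device output against a torch reference; a pass/fail flag based on a Pearson correlation coefficient is a common way such checks are written. A minimal sketch of that kind of check (helper name and threshold are illustrative, not the actual sweep implementation):

```python
import torch


def pcc_check(expected: torch.Tensor, actual: torch.Tensor, threshold: float = 0.99):
    # Pearson correlation coefficient between the flattened tensors.
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    pcc = torch.corrcoef(stacked)[0, 1].item()
    return pcc >= threshold, f"PCC: {pcc:.5f} (threshold {threshold})"


# Mirrors the assert pattern used in the tests above.
reference = torch.randn(1, 768, 1500)
passed, message = pcc_check(reference, reference + 1e-3 * torch.randn_like(reference))
assert passed, message
```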
10 changes: 6 additions & 4 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
@@ -372,11 +372,12 @@ bool use_matmul_for_1x1_conv(
        padding[1] == 0 && dilation[0] == 1 && dilation[1] == 1 && groups == 1 && (not is_width_sharded);
}

bool is_1d_conv(uint32_t kernel_width, uint32_t image_width) { return kernel_width == 1 && image_width == 1; }

bool is_1d_deptwise_conv(
    uint32_t groups, uint32_t input_channels, uint32_t output_channels, uint32_t kernel_width, uint32_t image_width) {
    bool is_depthwise_conv = groups == input_channels && groups == output_channels;
-   bool is_conv1d = kernel_width == 1 && image_width == 1;
-   return is_depthwise_conv && is_conv1d;
+   return is_depthwise_conv && is_1d_conv(kernel_width, image_width);
}

template <typename DeviceType>
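The new `is_1d_conv` helper captures how conv1d reaches these kernels: the 1D op is expressed as a conv2d whose width dimension has been collapsed, so both the kernel width and the input image width end up equal to 1. A small PyTorch illustration of that shape bookkeeping (the exact internal layout ttnn uses may differ; this only shows why the width-equals-one check identifies a 1D conv):

```python
import torch

# A conv1d over a length-L signal can be expressed as a conv2d on an
# (L, 1) image with a (kernel_size, 1) filter. After that mapping both
# kernel_width and image_width are 1, which is the condition is_1d_conv checks.
x1d = torch.randn(1, 16, 32)     # (N, C, L)
w1d = torch.randn(8, 16, 3)      # (O, I, K)

x2d = x1d.unsqueeze(-1)          # (N, C, L, 1)  -> image width == 1
w2d = w1d.unsqueeze(-1)          # (O, I, K, 1)  -> kernel width == 1

out1d = torch.nn.functional.conv1d(x1d, w1d, stride=2, padding=1)
out2d = torch.nn.functional.conv2d(x2d, w2d, stride=(2, 1), padding=(1, 0))

assert torch.allclose(out1d, out2d.squeeze(-1), atol=1e-5)
```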
@@ -711,6 +712,7 @@ Conv2dConfig determine_conv_config_for_auto_shard(
        Conv2dConfig conv_config;
    };

    const bool conv_is_1d = is_1d_conv(kernel_size[1], input_width);
    const bool conv_is_1d_deptwise =
        is_1d_deptwise_conv(groups, in_channels, out_channels, kernel_size[1], input_width);

@@ -753,7 +755,7 @@ Conv2dConfig determine_conv_config_for_auto_shard(
        compute_grid_size,
        shard_orientation,
        !is_mm_conv,
-       is_out_tiled,
+       true,
        conv_config.act_block_h_override);

    const ParallelConfig output_parallel_config = determine_output_parallel_config(
@@ -817,7 +819,7 @@ Conv2dConfig determine_conv_config_for_auto_shard(
    core_count_and_size height = get_l1_usage_for_sharding(TensorMemoryLayout::HEIGHT_SHARDED, conv_config);

    // 1d depthwise convs support only height sharding
-   if (conv_is_1d_deptwise) {
+   if (conv_is_1d) {
        return height.conv_config;
    }

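The behavioral fix is in the hunk above: auto-shard now short-circuits to the height-sharded config for every 1D conv, not only for 1D depthwise convs, so non-depthwise conv1d cases no longer fall through to the cross-layout comparison (which, per the commit title, regressed conv1d). Alongside this, the input parallel-config call now passes true where it previously forwarded is_out_tiled. A Python mirror of the selection flow (the names are placeholders for the C++ locals shown above, not ttnn API, and the fallback comparison is simplified):

```python
from dataclasses import dataclass


@dataclass
class CoreCountAndSize:
    core_count: int
    size: int          # estimated L1 usage in bytes
    conv_config: dict  # stand-in for Conv2dConfig


def pick_auto_shard_config(conv_is_1d: bool, l1_usage_by_layout: dict) -> dict:
    # 1D convs (including 1D depthwise convs) only support height sharding,
    # so they return the height-sharded config before any comparison.
    height = l1_usage_by_layout["HEIGHT_SHARDED"]
    if conv_is_1d:
        return height.conv_config

    # Simplified stand-in for the remaining layout comparison: prefer the
    # candidate with the smallest estimated L1 footprint.
    best = min(l1_usage_by_layout.values(), key=lambda c: c.size)
    return best.conv_config
```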
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
@@ -35,6 +35,8 @@ bool use_matmul_for_1x1_conv(
    uint32_t groups,
    const Conv2dConfig& conv_config);

bool is_1d_conv(uint32_t kernel_width, uint32_t image_width);

bool is_1d_deptwise_conv(
    uint32_t groups, uint32_t input_channels, uint32_t output_channels, uint32_t kernel_width, uint32_t image_width);
sliding_window::ParallelConfig determine_parallel_config(
@@ -621,7 +621,7 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_

    uint32_t input_width = ashape[2];
    uint32_t input_channels = ashape[3];
-   bool is_conv1d = filter_w == 1 && input_width == 1;
+   bool is_conv1d = is_1d_conv(filter_w, input_width);
    bool is_conv_1d_depthwise_conv =
        is_1d_deptwise_conv(groups, input_channels, output_channels, filter_w, input_width);

