#0: separate validation of conv weight and bias. (#15990)
### Ticket
#0

### Problem description
Weight and bias need to be validated separately. Conv currently assumes that if the weight is on the host, then the bias is on the host too.

### What's changed
Deleted the `validate_weight_and_bias` function; weight and bias are now validated separately.
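
For illustration, a minimal sketch of the intended split (hypothetical, self-contained Python; the actual checks live in ttnn's conv2d C++ code and are more involved): each tensor is validated on its own, so a host-side weight no longer implies a host-side bias.

```python
# Hypothetical sketch of per-tensor validation, not the real ttnn code.
from dataclasses import dataclass
from typing import Optional


@dataclass
class TensorInfo:
    on_host: bool
    layout: str  # e.g. "ROW_MAJOR" or "TILE"


def validate_conv_weight(weight: TensorInfo) -> None:
    # The weight is checked purely on its own properties.
    if weight.on_host and weight.layout != "ROW_MAJOR":
        raise ValueError("host-side conv weight must be row-major")


def validate_conv_bias(bias: Optional[TensorInfo]) -> None:
    # The bias is optional and checked independently; its location is no
    # longer inferred from where the weight happens to live.
    if bias is None:
        return
    if bias.on_host and bias.layout != "ROW_MAJOR":
        raise ValueError("host-side conv bias must be row-major")
```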
shwetankTT authored Dec 18, 2024
1 parent 5d0170e commit 999327f
Showing 6 changed files with 279 additions and 67 deletions.
2 changes: 1 addition & 1 deletion models/demos/convnet_mnist/tests/test_performance.py
@@ -119,7 +119,7 @@ def test_perf_device_bare_metal_convnet_mnist(batch_size, expected_perf):
subdir = "ttnn_convnet_mnist"
num_iterations = 1
margin = 0.03
- expected_perf = 1753.5 if is_grayskull() else 2705.5
+ expected_perf = 1800 if is_grayskull() else 2800.5

command = f"pytest tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist.py"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
2 changes: 2 additions & 0 deletions models/demos/vgg/tt/ttnn_vgg.py
@@ -114,6 +114,7 @@ def ttnn_vgg16(
tt_weight = parameters.features[conv_feature_ids[iter_conv_id]].weight
tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT)
tt_bias = parameters.features[conv_feature_ids[iter_conv_id]].bias
+ tt_bias = ttnn.to_layout(ttnn.from_device(tt_bias), layout=ttnn.ROW_MAJOR_LAYOUT)
# Call ttnn.conv
conv_op_cache = {}
[tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
@@ -242,6 +243,7 @@ def ttnn_vgg11(
tt_weight = parameters.features[conv_feature_ids_2[iter_conv_id]].weight
tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT)
tt_bias = parameters.features[conv_feature_ids_2[iter_conv_id]].bias
+ tt_bias = ttnn.to_layout(ttnn.from_device(tt_bias), layout=ttnn.ROW_MAJOR_LAYOUT)

# Call ttnn.conv
conv_op_cache = {}
131 changes: 130 additions & 1 deletion tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py
@@ -186,12 +186,141 @@ def test_prepare_conv_weights(
compute_config=compute_config,
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)
torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]
torch_output_tensor = torch_output_tensor.reshape(torch_out_golden_tensor.shape)

pcc = 0.99
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc)
logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
assert passing


@skip_for_grayskull()
@skip_for_blackhole()
# @skip_for_wormhole_b0()
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override",
(
# rn50 layer1
(8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
(16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
(20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
),
)
@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"])
@pytest.mark.parametrize("has_bias", [True, False], ids=["has_bias", "no_bias"])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 2**15}], indirect=True)
def test_prepare_bias(
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
packer_l1_acc,
config_override,
has_bias,
device,
):
if device.core_grid.y == 7:
pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range")

if batch_size == 20 and (
output_channels == 64 or (stride_h == 2 and (output_channels == 256 or output_channels == 128))
):
pytest.skip("Skipping test because it won't fit in L1!")

inp_shape = (batch_size, input_channels, input_height, input_width)
conv_weight_shape = (output_channels, input_channels, filter_height, filter_width)
torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16)
torch_input_tensor = torch.randn(inp_shape, dtype=torch.bfloat16)
torch_bias_tensor = torch.randn((1, 1, 1, output_channels), dtype=torch.bfloat16) if has_bias else None

torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input_tensor,
torch_weight_tensor,
bias=torch_bias_tensor.reshape(-1) if has_bias else None,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
dilation=(1, 1),
groups=1,
).permute(0, 2, 3, 1)

tt_input_tensor = ttnn.from_torch(torch_input_tensor.transpose(-3, -2).transpose(-2, -1), ttnn.bfloat16)
tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16) if has_bias else None

conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat16,
weights_dtype=ttnn.bfloat16,
input_channels_alignment=(16 if input_channels == 16 and input_height == 115 else 32),
enable_act_double_buffer=False,
enable_split_reader=False,
enable_subblock_padding=False,
)
compute_config = ttnn.init_device_compute_kernel_config(device.arch(), packer_l1_acc=packer_l1_acc)
if config_override and "act_block_h" in config_override:
conv_config.act_block_h_override = config_override["act_block_h"]

if config_override and "act_block_w_div" in config_override:
conv_config.act_block_w_div = config_override["act_block_w_div"]

if config_override and "num_cores_nhw" in config_override:
if config_override["num_cores_nhw"] == 98:
conv_config.core_grid = ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (11, 7)), ttnn.CoreRange((0, 8), (1, 8))})
conv_config.override_sharding_config = True
print("Setting num_cores_nhw to 98")

conv_kwargs = {
"input_layout": ttnn.ROW_MAJOR_LAYOUT,
"in_channels": input_channels,
"out_channels": output_channels,
"batch_size": batch_size,
"input_height": input_height,
"input_width": input_width,
"kernel_size": (filter_height, filter_width),
"stride": (stride_h, stride_w),
"padding": (pad_h, pad_w),
"dilation": (1, 1),
"groups": 1,
"device": device,
"conv_config": conv_config,
}

tt_input_tensor = ttnn.to_device(tt_input_tensor, device)

tt_bias_tensor_formatted = (
ttnn.prepare_conv_bias(
bias_tensor=tt_bias_tensor, input_memory_config=tt_input_tensor.memory_config(), **conv_kwargs
)
if has_bias
else None
)

tt_bias_tensor_formatted = ttnn.to_device(tt_bias_tensor_formatted, device) if has_bias else None
conv_kwargs.pop("input_layout")  # drop the first entry (input_layout) before calling conv2d
tt_output_tensor_on_device = ttnn.conv2d(
input_tensor=tt_input_tensor,
weight_tensor=tt_weight_tensor,
bias_tensor=tt_bias_tensor_formatted,
**conv_kwargs,
compute_config=compute_config,
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)

torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]
torch_output_tensor = torch_output_tensor.reshape(torch_out_golden_tensor.shape)

pcc = 0.99
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc)
logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
7 changes: 1 addition & 6 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -83,12 +83,7 @@ Result conv2d(

ShardOrientation shard_orientation =
conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
- auto num_cores_c = shard_orientation == ShardOrientation::COL_MAJOR ? device->compute_with_storage_grid_size().y : device->compute_with_storage_grid_size().x;
- auto elem_size = conv_config.weights_dtype == DataType::BFLOAT8_B ? 1 : 2;
- bool is_non_tile_mul_width =
-     (conv_config.shard_layout == TensorMemoryLayout::BLOCK_SHARDED) && conv_config.act_block_h_override == 0 &&
-     (conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) &&
-     conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0;
+ bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels);

DeviceComputeKernelConfig compute_config = compute_config_.value_or(
init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false));
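
The deleted lines spell out the condition that the new `check_non_tile_mul_width` helper presumably now encapsulates. Restated in Python for readability (an approximation of the deleted C++ logic, not the helper's actual signature or source):

```python
# Rough Python restatement of the deleted C++ condition.
def is_non_tile_mul_width(device, conv_config, in_channels):
    grid = device.compute_with_storage_grid_size()
    # COL_MAJOR shard orientation (transpose_shards=True) uses the y dimension.
    num_cores_c = grid.y if conv_config.transpose_shards else grid.x
    elem_size = 1 if conv_config.weights_dtype == ttnn.bfloat8_b else 2
    return (
        conv_config.shard_layout == ttnn.TensorMemoryLayout.BLOCK_SHARDED
        and conv_config.act_block_h_override == 0
        and conv_config.weights_dtype in (ttnn.bfloat8_b, ttnn.bfloat16)
        and conv_config.output_layout == ttnn.ROW_MAJOR_LAYOUT
        and (elem_size * in_channels) % (16 * num_cores_c) == 0
    )
```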