Add test case and allow user to choose whether to enable split reader or not for conv2d.

Signed-off-by: Nilaykumar Patel <nkpatel@tenstorrent.com>
nkpatel-tt committed Feb 28, 2025
1 parent 7c64e91 commit 1da69fe
Showing 10 changed files with 119 additions and 15 deletions.
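The new flag is surfaced on `Conv2dConfig` through the pybind changes below. A minimal usage sketch (hedged: the surrounding `ttnn.conv2d` call and tensor setup are omitted; argument names and defaults are taken from the bindings in this commit):

```python
import ttnn

# All Conv2dConfig arguments are keyword-only (see py::kw_only() in the
# pybind diff below); enable_halo_split_reader is the flag added by this
# commit and defaults to False, leaving the halo op's split reader disabled.
conv_config = ttnn.Conv2dConfig(
    dtype=ttnn.bfloat16,
    weights_dtype=ttnn.bfloat16,
    enable_halo_split_reader=True,
)
```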
81 changes: 81 additions & 0 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -72,6 +72,7 @@ def run_conv(
weight_mesh_mapper=None,
output_mesh_composer=None,
enable_split_reader=False,
enable_halo_split_reader=False,
activation="",
):
if isinstance(device, ttnn.MeshDevice):
@@ -138,6 +139,7 @@ def run_conv(
enable_subblock_padding=False,
output_layout=output_layout,
activation=activation,
enable_halo_split_reader=enable_halo_split_reader,
)
compute_config = ttnn.init_device_compute_kernel_config(
device.arch(),
@@ -2852,3 +2854,82 @@ def test_block_sharding_relu_act_block_h(
shard_layout=shard_layout,
activation=activation,
)

@pytest.mark.parametrize("batch", [1])
@pytest.mark.parametrize(
"output_channels, input_channels, input_height, input_width",
(
(4, 32, 288, 288),
(32, 48, 284, 284),
(48, 56, 280, 280),
(56, 64, 272, 272),
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat16],
)
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
@pytest.mark.parametrize(
"kernel, dilation, padding",
[
[5, 2, 2],
[3, 8, 1],
],
)
@pytest.mark.parametrize("stride", [1])
@pytest.mark.parametrize("enable_halo_split_reader", [True])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384*2}], indirect=True)
def test_halo_split_reader(
device,
torch_tensor_map,
batch,
output_channels,
input_channels,
input_height,
input_width,
weights_dtype,
activations_dtype,
math_fidelity,
kernel,
dilation,
padding,
stride,
enable_halo_split_reader
):
config_override = {}

run_conv(
device=device,
torch_tensor_map=torch_tensor_map,
activations_dtype=activations_dtype,
weights_dtype=weights_dtype,
batch_size=batch,
output_channels=output_channels,
input_channels=input_channels,
input_height=input_height,
input_width=input_width,
filter_height=kernel,
filter_width=kernel,
stride_h=stride,
stride_w=stride,
pad_h=padding,
pad_w=padding,
config_override=config_override,
dilation=dilation,
math_fidelity=math_fidelity,
output_layout=ttnn.TILE_LAYOUT,
debug=False,
groups=1,
has_bias=True,
shard_layout=None,
memory_config=None,
input_mesh_mapper=None,
weight_mesh_mapper=None,
output_mesh_composer=None,
enable_halo_split_reader=enable_halo_split_reader,
)
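The new test should be selectable on its own; a hedged invocation, with the path taken from the diff header above:

```
pytest tests/ttnn/unit_tests/operations/test_new_conv2d.py::test_halo_split_reader
```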
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -181,7 +181,8 @@ Result conv2d(
parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
0,
input_tensor_post_tm.memory_config(),
true);
true,
conv_config.enable_halo_split_reader);

if (conv_config.deallocate_activation) {
input_tensor_post_tm.deallocate(/*force*/ true);
5 changes: 4 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -335,6 +335,7 @@ void py_bind_conv2d(py::module& module) {
bool,
bool,
bool,
bool,
bool>(),
py::kw_only(),
py::arg("dtype") = DataType::BFLOAT16,
@@ -354,7 +355,8 @@
py::arg("enable_act_double_buffer") = false,
py::arg("enable_weights_double_buffer") = false,
py::arg("enable_split_reader") = false,
py::arg("enable_subblock_padding") = false);
py::arg("enable_subblock_padding") = false,
py::arg("enable_halo_split_reader") = false);
py_conv_config.def_readwrite("dtype", &Conv2dConfig::dtype);
py_conv_config.def_readwrite("weights_dtype", &Conv2dConfig::weights_dtype);
py_conv_config.def_readwrite("activation", &Conv2dConfig::activation);
@@ -373,6 +375,7 @@
py_conv_config.def_readwrite("enable_weights_double_buffer", &Conv2dConfig::enable_weights_double_buffer);
py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader);
py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding);
py_conv_config.def_readwrite("enable_halo_split_reader", &Conv2dConfig::enable_halo_split_reader);

py_conv_config.def("__repr__", [](const Conv2dConfig& config) { return fmt::format("{}", config); });

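Because the field is also bound with `def_readwrite` above, it can be toggled on an existing config object; a short hedged sketch:

```python
# Post-construction assignment via the read-write binding above.
conv_config = ttnn.Conv2dConfig()
conv_config.enable_halo_split_reader = True
```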
8 changes: 6 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
@@ -73,6 +73,8 @@ struct Conv2dConfig {
bool enable_split_reader = false;

bool enable_subblock_padding = false;

bool enable_halo_split_reader = false;
static constexpr auto attribute_names = std::make_tuple(
"dtype",
"weights_dtype",
@@ -91,7 +93,8 @@
"enable_act_double_buffer",
"enable_weights_double_buffer",
"enable_split_reader",
"enable_subblock_padding");
"enable_subblock_padding",
"enable_halo_split_reader");
const auto attribute_values() const {
return std::make_tuple(
std::cref(this->dtype),
@@ -111,7 +114,8 @@
std::cref(this->enable_act_double_buffer),
std::cref(this->enable_weights_double_buffer),
std::cref(this->enable_split_reader),
std::cref(this->enable_subblock_padding));
std::cref(this->enable_subblock_padding),
std::cref(this->enable_halo_split_reader));
}
};

@@ -56,20 +56,24 @@ void copy_sticks_async(
if constexpr (enable_split_reader) {
if constexpr (is_remote_config) {
if constexpr (is_reader) {
// Skip every odd iteration for readers in remote configuration
if (iteration % 2 == 1) {
continue;
}
} else {
// Skip every even iteration for writer in remote configuration
if (iteration % 2 == 0) {
continue;
}
}
} else {
if constexpr (is_reader) {
// Skip every even iteration for readers in local configuration
if (iteration % 2 == 0) {
continue;
}
} else {
// Skip every odd iteration for writer in local configuration
if (iteration % 2 == 1) {
continue;
}
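The comments added above describe an interleaving scheme: the reader and writer kernels walk the same iteration space and each skips alternating parities, with the roles swapped between the local and remote configurations. A standalone sketch of that partition (Python for illustration; names are invented, and the real kernel copies sticks rather than indices):

```python
def kept_iterations(num_iterations, is_reader, is_remote_config):
    # Remote config: the reader keeps even iterations, the writer odd ones.
    # Local config: the roles are swapped.
    skip_parity = (1 if is_reader else 0) if is_remote_config else (0 if is_reader else 1)
    return [i for i in range(num_iterations) if i % 2 != skip_parity]

# Reader and writer together cover every iteration exactly once:
assert sorted(
    kept_iterations(8, True, True) + kept_iterations(8, False, True)
) == list(range(8))
```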
@@ -22,5 +22,5 @@ tt::tt_metal::operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
Tensor& output_tensor,
const bool capture_buffers,  // Used by halo op to cache internally created config buffers with the program.
// Untilize with Halo V2 op takes them as inputs from the user, so doesn't capture.
const bool enable_split_reader = false);
const bool enable_split_reader = true);
} // namespace ttnn::operations::data_movement::detail
@@ -130,7 +130,8 @@ operation::ProgramWithCallbacks HaloDeviceOperation::create_program(
remote_read_,
transpose_mcast_,
output_tensor,
/*capture_buffers=*/true)};
/*capture_buffers=*/true,
enable_split_reader_)};
}

Tensor halo_op(
@@ -141,7 +142,8 @@ Tensor halo_op(
bool transpose_mcast,
uint32_t reshard_num_cores_nhw,
const MemoryConfig& output_memory_config,
bool is_out_tiled) {
bool is_out_tiled,
bool enable_split_reader) {
TT_FATAL(input_tensor.memory_config().is_sharded(), "Halo expects sharded input tensor");
TT_FATAL(
input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ||
@@ -159,7 +161,8 @@
transpose_mcast,
reshard_num_cores_nhw,
output_memory_config,
is_out_tiled](
is_out_tiled,
enable_split_reader](
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -192,7 +195,8 @@ Tensor halo_op(
.reshard_num_cores_nhw_ = reshard_num_cores_nhw,
.max_out_nsticks_per_core_ = max_out_nsticks_per_core,
.output_memory_config_ = output_memory_config,
.is_out_tiled_ = is_out_tiled},
.is_out_tiled_ = is_out_tiled,
.enable_split_reader_ = enable_split_reader},
{input_tensor});
};

@@ -25,6 +25,7 @@ struct HaloDeviceOperation {
uint32_t max_out_nsticks_per_core_;
MemoryConfig output_memory_config_;
bool is_out_tiled_;
bool enable_split_reader_;

void validate(const std::vector<Tensor>& input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
@@ -41,7 +42,8 @@
"reshard_num_cores_nhw_",
"max_out_nsticks_per_core_",
"output_memory_config_",
"is_out_tiled_");
"is_out_tiled_",
"enable_split_reader_");
const auto attribute_values() const {
return std::make_tuple(
std::cref(config_),
@@ -52,7 +54,8 @@
std::cref(reshard_num_cores_nhw_),
std::cref(max_out_nsticks_per_core_),
std::cref(output_memory_config_),
std::cref(is_out_tiled_));
std::cref(is_out_tiled_),
std::cref(enable_split_reader_));
}
};

@@ -64,7 +67,8 @@ Tensor halo_op(
bool transpose_mcast = true,
uint32_t reshard_num_cores_nhw = 0,
const MemoryConfig& output_memory_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
bool is_out_tiled = true);
bool is_out_tiled = true,
bool enable_split_reader = false);

} // namespace halo

6 changes: 4 additions & 2 deletions ttnn/cpp/ttnn/operations/sliding_window/halo/halo.cpp
@@ -16,7 +16,8 @@ Tensor HaloOperation::invoke(
bool transpose_mcast,
uint32_t reshard_num_cores_nhw,
const MemoryConfig& output_memory_config,
bool is_out_tiled) {
bool is_out_tiled,
bool enable_split_reader) {
return halo_op(
input_tensor,
config,
@@ -25,6 +26,7 @@
transpose_mcast,
reshard_num_cores_nhw,
std::move(output_memory_config),
is_out_tiled);
is_out_tiled,
enable_split_reader);
}
}; // namespace ttnn::operations::sliding_window::halo
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/sliding_window/halo/halo.hpp
@@ -20,7 +20,8 @@ struct HaloOperation {
bool transpose_mcast = true,
uint32_t reshard_num_cores_nhw = 0,
const MemoryConfig& output_memory_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
bool is_out_tiled = true);
bool is_out_tiled = true,
bool enable_split_reader = false);

// invoke can be overloaded as many times as needed to provide all desired APIs
};
