From 1da69fe636bcff0383fb25ec8d2acc16210a2fee Mon Sep 17 00:00:00 2001
From: Nilaykumar Patel
Date: Fri, 28 Feb 2025 09:53:25 +0000
Subject: [PATCH] Add test case and allow the user to choose whether to
 enable the split reader for conv2d.

Signed-off-by: Nilaykumar Patel
---
 .../unit_tests/operations/test_new_conv2d.py  | 81 +++++++++++++++++++
 .../ttnn/operations/conv/conv2d/conv2d.cpp    |  3 +-
 .../operations/conv/conv2d/conv2d_pybind.cpp  |  5 +-
 .../conv/conv2d/device/conv2d_op.hpp          |  8 +-
 .../device/kernels/dataflow/halo_gather.cpp   |  4 +
 .../untilize_with_halo_v2_program_factory.hpp |  2 +-
 .../halo/device/halo_device_operation.cpp     | 12 ++-
 .../halo/device/halo_device_operation.hpp     | 10 ++-
 .../operations/sliding_window/halo/halo.cpp   |  6 +-
 .../operations/sliding_window/halo/halo.hpp   |  3 +-
 10 files changed, 119 insertions(+), 15 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py
index dbc28079e16..811611deb65 100644
--- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py
+++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -72,6 +72,7 @@ def run_conv(
     weight_mesh_mapper=None,
     output_mesh_composer=None,
     enable_split_reader=False,
+    enable_halo_split_reader=False,
     activation="",
 ):
     if isinstance(device, ttnn.MeshDevice):
@@ -138,6 +139,7 @@ def run_conv(
         enable_subblock_padding=False,
         output_layout=output_layout,
         activation=activation,
+        enable_halo_split_reader=enable_halo_split_reader,
     )
     compute_config = ttnn.init_device_compute_kernel_config(
         device.arch(),
@@ -2852,3 +2854,82 @@ def test_block_sharding_relu_act_block_h(
         shard_layout=shard_layout,
         activation=activation,
     )
+
+@pytest.mark.parametrize("batch", [1])
+@pytest.mark.parametrize(
+    "output_channels, input_channels, input_height, input_width",
+    (
+        (4, 32, 288, 288),
+        (32, 48, 284, 284),
+        (48, 56, 280, 280),
+        (56, 64, 272, 272),
+    ),
+)
+@pytest.mark.parametrize(
+    "weights_dtype",
+    [ttnn.bfloat16],
+)
+@pytest.mark.parametrize(
+    "activations_dtype",
+    [ttnn.bfloat16],
+)
+@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
+@pytest.mark.parametrize(
+    "kernel, dilation, padding",
+    [
+        [5, 2, 2],
+        [3, 8, 1],
+    ],
+)
+@pytest.mark.parametrize("stride", [1])
+@pytest.mark.parametrize("enable_halo_split_reader", [True])
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384*2}], indirect=True)
+def test_halo_split_reader(
+    device,
+    torch_tensor_map,
+    batch,
+    output_channels,
+    input_channels,
+    input_height,
+    input_width,
+    weights_dtype,
+    activations_dtype,
+    math_fidelity,
+    kernel,
+    dilation,
+    padding,
+    stride,
+    enable_halo_split_reader
+):
+    config_override = {}
+
+    run_conv(
+        device=device,
+        torch_tensor_map=torch_tensor_map,
+        activations_dtype=activations_dtype,
+        weights_dtype=weights_dtype,
+        batch_size=batch,
+        output_channels=output_channels,
+        input_channels=input_channels,
+        input_height=input_height,
+        input_width=input_width,
+        filter_height=kernel,
+        filter_width=kernel,
+        stride_h=stride,
+        stride_w=stride,
+        pad_h=padding,
+        pad_w=padding,
+        config_override=config_override,
+        dilation=dilation,
+        math_fidelity=math_fidelity,
+        output_layout=ttnn.TILE_LAYOUT,
+        debug=False,
+        groups=1,
+        has_bias=True,
+        shard_layout=None,
+        memory_config=None,
+        input_mesh_mapper=None,
+        weight_mesh_mapper=None,
+        output_mesh_composer=None,
+        enable_halo_split_reader=enable_halo_split_reader,
+    )
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index a3928a36629..b910326870d 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -181,7 +181,8 @@ Result conv2d(
         parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
         0,
         input_tensor_post_tm.memory_config(),
-        true);
+        true,
+        conv_config.enable_halo_split_reader);
 
     if (conv_config.deallocate_activation) {
         input_tensor_post_tm.deallocate(/*force*/ true);
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
index 0591ed02d0c..af7d8e397ee 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -335,6 +335,7 @@ void py_bind_conv2d(py::module& module) {
                 bool,
                 bool,
                 bool,
+                bool,
                 bool>(),
             py::kw_only(),
             py::arg("dtype") = DataType::BFLOAT16,
@@ -354,7 +355,8 @@ void py_bind_conv2d(py::module& module) {
             py::arg("enable_act_double_buffer") = false,
             py::arg("enable_weights_double_buffer") = false,
             py::arg("enable_split_reader") = false,
-            py::arg("enable_subblock_padding") = false);
+            py::arg("enable_subblock_padding") = false,
+            py::arg("enable_halo_split_reader") = false);
     py_conv_config.def_readwrite("dtype", &Conv2dConfig::dtype);
     py_conv_config.def_readwrite("weights_dtype", &Conv2dConfig::weights_dtype);
     py_conv_config.def_readwrite("activation", &Conv2dConfig::activation);
@@ -373,6 +375,7 @@ void py_bind_conv2d(py::module& module) {
     py_conv_config.def_readwrite("enable_weights_double_buffer", &Conv2dConfig::enable_weights_double_buffer);
     py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader);
     py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding);
+    py_conv_config.def_readwrite("enable_halo_split_reader", &Conv2dConfig::enable_halo_split_reader);
 
     py_conv_config.def("__repr__", [](const Conv2dConfig& config) { return fmt::format("{}", config); });
 
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
index 04557524b76..79882019fb4 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
@@ -73,6 +73,8 @@ struct Conv2dConfig {
     bool enable_split_reader = false;
     bool enable_subblock_padding = false;
+
+    bool enable_halo_split_reader = false;
 
     static constexpr auto attribute_names = std::make_tuple(
         "dtype",
         "weights_dtype",
@@ -91,7 +93,8 @@ struct Conv2dConfig {
         "enable_act_double_buffer",
         "enable_weights_double_buffer",
         "enable_split_reader",
-        "enable_subblock_padding");
+        "enable_subblock_padding",
+        "enable_halo_split_reader");
     const auto attribute_values() const {
         return std::make_tuple(
             std::cref(this->dtype),
@@ -111,7 +114,8 @@ struct Conv2dConfig {
             std::cref(this->enable_act_double_buffer),
             std::cref(this->enable_weights_double_buffer),
             std::cref(this->enable_split_reader),
-            std::cref(this->enable_subblock_padding));
+            std::cref(this->enable_subblock_padding),
+            std::cref(this->enable_halo_split_reader));
     }
 };
 
diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp
index 3e72a149547..a7806862a8d 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp
@@ -56,20 +56,24 @@ void copy_sticks_async(
         if constexpr (enable_split_reader) {
             if constexpr (is_remote_config) {
                 if constexpr (is_reader) {
+                    // Skip every odd iteration for the reader in the remote configuration
                     if (iteration % 2 == 1) {
                         continue;
                     }
                 } else {
+                    // Skip every even iteration for the writer in the remote configuration
                     if (iteration % 2 == 0) {
                         continue;
                     }
                 }
             } else {
                 if constexpr (is_reader) {
+                    // Skip every even iteration for the reader in the local configuration
                     if (iteration % 2 == 0) {
                         continue;
                     }
                 } else {
+                    // Skip every odd iteration for the writer in the local configuration
                     if (iteration % 2 == 1) {
                         continue;
                     }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.hpp
index e2858d37741..d9bbd84b195 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.hpp
@@ -22,5 +22,5 @@ tt::tt_metal::operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
     Tensor& output_tensor,
     const bool capture_buffers,  // Used by halo op to cache internally created config buffers with the program;
                                  // Untilize with Halo V2 op takes them as inputs from the user, so doesn't capture
-    const bool enable_split_reader = false);
+    const bool enable_split_reader = true);
 }  // namespace ttnn::operations::data_movement::detail
diff --git a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp
index 8a52e6534aa..e6d4b0ccb17 100644
--- a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp
@@ -130,7 +130,8 @@ operation::ProgramWithCallbacks HaloDeviceOperation::create_program(
         remote_read_,
         transpose_mcast_,
         output_tensor,
-        /*capture_buffers=*/true)};
+        /*capture_buffers=*/true,
+        enable_split_reader_)};
 }
 
 Tensor halo_op(
@@ -141,7 +142,8 @@ Tensor halo_op(
     bool transpose_mcast,
     uint32_t reshard_num_cores_nhw,
     const MemoryConfig& output_memory_config,
-    bool is_out_tiled) {
+    bool is_out_tiled,
+    bool enable_split_reader) {
     TT_FATAL(input_tensor.memory_config().is_sharded(), "Halo expects sharded input tensor");
     TT_FATAL(
         input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ||
@@ -159,7 +161,8 @@ Tensor halo_op(
         transpose_mcast,
         reshard_num_cores_nhw,
         output_memory_config,
-        is_out_tiled](
+        is_out_tiled,
+        enable_split_reader](
        const std::vector<Tensor>& input_tensors,
        const std::vector<std::optional<const Tensor>>& optional_input_tensors,
        const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -192,7 +195,8 @@ Tensor halo_op(
             .reshard_num_cores_nhw_ = reshard_num_cores_nhw,
             .max_out_nsticks_per_core_ = max_out_nsticks_per_core,
             .output_memory_config_ = output_memory_config,
-            .is_out_tiled_ = is_out_tiled},
+            .is_out_tiled_ = is_out_tiled,
+            .enable_split_reader_ = enable_split_reader},
         {input_tensor});
     };
 
diff --git a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.hpp b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.hpp
index 88e502181f1..c02fbf15e6a 100644
--- a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.hpp
@@ -25,6 +25,7 @@ struct HaloDeviceOperation {
     uint32_t max_out_nsticks_per_core_;
     MemoryConfig output_memory_config_;
     bool is_out_tiled_;
+    bool enable_split_reader_;
 
     void validate(const std::vector<Tensor>& input_tensors) const;
     std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
@@ -41,7 +42,8 @@ struct HaloDeviceOperation {
         "reshard_num_cores_nhw_",
         "max_out_nsticks_per_core_",
         "output_memory_config_",
-        "is_out_tiled_");
+        "is_out_tiled_",
+        "enable_split_reader_");
     const auto attribute_values() const {
         return std::make_tuple(
             std::cref(config_),
@@ -52,7 +54,8 @@ struct HaloDeviceOperation {
             std::cref(reshard_num_cores_nhw_),
             std::cref(max_out_nsticks_per_core_),
             std::cref(output_memory_config_),
-            std::cref(is_out_tiled_));
+            std::cref(is_out_tiled_),
+            std::cref(enable_split_reader_));
     }
 };
 
@@ -64,7 +67,8 @@ Tensor halo_op(
     bool transpose_mcast = true,
     uint32_t reshard_num_cores_nhw = 0,
     const MemoryConfig& output_memory_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
-    bool is_out_tiled = true);
+    bool is_out_tiled = true,
+    bool enable_split_reader = false);
 
 }  // namespace halo
diff --git a/ttnn/cpp/ttnn/operations/sliding_window/halo/halo.cpp b/ttnn/cpp/ttnn/operations/sliding_window/halo/halo.cpp
index b054af4cc72..d7f8cc0ec40 100644
--- a/ttnn/cpp/ttnn/operations/sliding_window/halo/halo.cpp
+++ b/ttnn/cpp/ttnn/operations/sliding_window/halo/halo.cpp
@@ -16,7 +16,8 @@ Tensor HaloOperation::invoke(
     bool transpose_mcast,
     uint32_t reshard_num_cores_nhw,
     const MemoryConfig& output_memory_config,
-    bool is_out_tiled) {
+    bool is_out_tiled,
+    bool enable_split_reader) {
     return halo_op(
         input_tensor,
         config,
@@ -25,6 +26,7 @@ Tensor HaloOperation::invoke(
         transpose_mcast,
         reshard_num_cores_nhw,
         std::move(output_memory_config),
-        is_out_tiled);
+        is_out_tiled,
+        enable_split_reader);
 }
 };  // namespace ttnn::operations::sliding_window::halo
diff --git a/ttnn/cpp/ttnn/operations/sliding_window/halo/halo.hpp b/ttnn/cpp/ttnn/operations/sliding_window/halo/halo.hpp
index 31df09955ea..07f2d313a02 100644
--- a/ttnn/cpp/ttnn/operations/sliding_window/halo/halo.hpp
+++ b/ttnn/cpp/ttnn/operations/sliding_window/halo/halo.hpp
@@ -20,7 +20,8 @@ struct HaloOperation {
         bool transpose_mcast = true,
         uint32_t reshard_num_cores_nhw = 0,
         const MemoryConfig& output_memory_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
-        bool is_out_tiled = true);
+        bool is_out_tiled = true,
+        bool enable_split_reader = false);
 
     // invoke can be overloaded as many times as needed to provide all desired APIs
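
Usage sketch (not applied by git; a minimal illustration of driving the new
flag from Python, mirroring the first parametrization of
test_halo_split_reader above). Only the `enable_halo_split_reader` kwarg on
`ttnn.Conv2dConfig` comes from this patch; the `ttnn.open_device` and
`ttnn.conv2d` argument names are assumed from the existing ttnn API used by
the test suite.

    import torch
    import ttnn

    # The test doubles l1_small_size for the halo split reader.
    device = ttnn.open_device(device_id=0, l1_small_size=16384 * 2)

    torch_input = torch.randn(1, 288, 288, 32, dtype=torch.bfloat16)  # NHWC
    torch_weight = torch.randn(4, 32, 5, 5, dtype=torch.bfloat16)     # OIHW

    conv_config = ttnn.Conv2dConfig(
        dtype=ttnn.bfloat16,
        weights_dtype=ttnn.bfloat16,
        # New flag added by this patch: split the halo gather between the
        # reader and writer dataflow kernels.
        enable_halo_split_reader=True,
    )

    tt_input = ttnn.from_torch(torch_input, ttnn.bfloat16)
    tt_weight = ttnn.from_torch(torch_weight, ttnn.bfloat16)

    tt_output = ttnn.conv2d(
        input_tensor=tt_input,
        weight_tensor=tt_weight,
        device=device,
        in_channels=32,
        out_channels=4,
        batch_size=1,
        input_height=288,
        input_width=288,
        kernel_size=(5, 5),
        stride=(1, 1),
        padding=(2, 2),
        dilation=(2, 2),
        conv_config=conv_config,
    )

    ttnn.close_device(device)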
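For reference, the work split that the halo_gather.cpp change selects can be
summarized as: remote-config sticks go to the reader on even iterations and
to the writer on odd iterations, while local-config sticks use the opposite
parity, so the two kernels partition the iteration space without overlap.
A small sketch of that scheme, in illustrative Python with a hypothetical
helper name (the kernel itself expresses this with the `continue` guards
added above):

    def iterations_for_kernel(num_iterations, is_reader, is_remote_config):
        """Iteration indices one kernel processes under the split-reader scheme."""
        picked = []
        for iteration in range(num_iterations):
            if is_remote_config:
                # Reader skips odd iterations; writer skips even ones.
                skip = iteration % 2 == 1 if is_reader else iteration % 2 == 0
            else:
                # Local config: reader skips even iterations; writer skips odd ones.
                skip = iteration % 2 == 0 if is_reader else iteration % 2 == 1
            if not skip:
                picked.append(iteration)
        return picked

    # Reader and writer together cover every iteration exactly once.
    reader = iterations_for_kernel(8, is_reader=True, is_remote_config=True)
    writer = iterations_for_kernel(8, is_reader=False, is_remote_config=True)
    assert sorted(reader + writer) == list(range(8))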