From f3d8fac9a2e2a603d7b6ed7d15afe9529ad1b1a3 Mon Sep 17 00:00:00 2001 From: Kartik Paigwar <132708568+kpaigwar@users.noreply.github.com> Date: Thu, 6 Mar 2025 13:51:35 -0500 Subject: [PATCH] Trying again: Revert Fix Minimal All Reduce for Llama shapes (#18731) ### Description This is a duplicate PR for https://github.com/tenstorrent/tt-metal/pull/18217 which got reverted due to symlink issue. The issue is now fixed and post-commit pipelines are passing. ### Checklist - [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/runs/13702739699) CI passes - [x] [TG Nightly](https://github.com/tenstorrent/tt-metal/actions/runs/13688165179) CI passes --------- Co-authored-by: avoraTT Co-authored-by: yugaoTT --- .../tests/test_ccl_async_perf_TG_llama.py | 120 +++++ models/perf/device_perf_utils.py | 77 +++ tests/nightly/tg/ccl/test_new_all_reduce.py | 1 + .../ccl/test_all_gather_TG_post_commit.py | 98 ++-- .../operations/ccl/test_all_reduce_async.py | 6 +- .../operations/ccl/test_ccl_async_TG_llama.py | 289 +++------- .../operations/ccl/test_ccl_common.py | 1 - .../operations/ccl/test_new_all_reduce.py | 484 +++++++++++++++++ ttnn/CMakeLists.txt | 2 + .../device/all_gather_async_op.cpp | 38 +- .../device/all_gather_async_op.hpp | 4 +- ..._gather_async_program_minimal_variants.cpp | 24 +- ...er.cpp => llama_shapes_sharded_reader.cpp} | 0 ...er.cpp => llama_shapes_sharded_writer.cpp} | 0 .../ccl/all_reduce_async/all_reduce_async.cpp | 25 + .../ccl/all_reduce_async/all_reduce_async.hpp | 11 + .../all_reduce_async_pybind.cpp | 33 ++ .../device/all_reduce_async_op.cpp | 305 +++++++++++ .../device/all_reduce_async_op.hpp | 141 +++++ ..._reduce_async_program_minimal_variants.cpp | 493 ++++++++++++++++++ .../device/kernels/compute/reduction.cpp | 61 +++ .../kernels/dataflow/reduction_receiver.cpp | 27 + .../device/kernels/dataflow/worker_reader.cpp | 57 ++ .../device/kernels/dataflow/worker_writer.cpp | 194 +++++++ 24 files changed, 2212 insertions(+), 279 deletions(-) create mode 100644 models/demos/llama3/tests/test_ccl_async_perf_TG_llama.py create mode 120000 tests/nightly/tg/ccl/test_new_all_reduce.py create mode 100644 tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py rename ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/{llama_post_binary_matmul_shape_reader.cpp => llama_shapes_sharded_reader.cpp} (100%) rename ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/{llama_post_binary_matmul_shape_writer.cpp => llama_shapes_sharded_writer.cpp} (100%) create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/reduction_receiver.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp diff --git a/models/demos/llama3/tests/test_ccl_async_perf_TG_llama.py b/models/demos/llama3/tests/test_ccl_async_perf_TG_llama.py new file mode 100644 index 
00000000000..da5f469cf62 --- /dev/null +++ b/models/demos/llama3/tests/test_ccl_async_perf_TG_llama.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +from loguru import logger +import ttnn + +from models.perf.benchmarking_utils import BenchmarkData, BenchmarkProfiler +from models.perf.device_perf_utils import run_device_perf_detailed + + +@pytest.mark.parametrize( + "ag_type, warmup_iters, perf_target_us", + [ + ("sdpa", 10, 11), + ("binary_mult", 10, 12), + ("layernorm", 10, 8), + ], +) +@pytest.mark.models_device_performance_bare_metal +def test_ag_tg_llama_perf( + ag_type, + warmup_iters, + perf_target_us, +): + profiler = BenchmarkProfiler() + benchmark_data = BenchmarkData() + step_name = f"all_gather_{ag_type}" + + subdir = "llama_ccl_perf" + command = ( + f"pytest tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py::test_all_gather_tg_llama -k {ag_type}" + ) + cols = ["DEVICE KERNEL"] + op_name = "AllGatherAsync" + warmup_iters = warmup_iters * 32 # 5 iterations per device + + profiler.start("run") + profiler.start(step_name) + results = run_device_perf_detailed(command, subdir, cols, op_name, has_signposts=True, warmup_iters=warmup_iters) + profiler.end(step_name) + profiler.end("run") + + # Get the measured performance + measured_min_us = results[cols[0]]["MIN"] / 1000 + measured_max_us = results[cols[0]]["MAX"] / 1000 + measured_avg_us = results[cols[0]]["AVG"] / 1000 + measured_std_us = results[cols[0]]["STD"] / 1000 + + logger.info(f"Measured performance: {measured_avg_us:.3f} us vs. target: {perf_target_us} us") + + # Save the measurement + benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-min-us", measured_min_us) + benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-max-us", measured_max_us) + benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-avg-us", measured_avg_us) + benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-std-us", measured_std_us) + benchmark_data.save_partial_run_json( + profiler, + run_type=f"all_gather", + ml_model_name="llama70b-tg-ccl", + ) + + assert measured_avg_us < perf_target_us, f"Performance target not met: {measured_avg_us} us > {perf_target_us} us" + + +@pytest.mark.parametrize( + "ar_type, warmup_iters, perf_target_us", + [ + ("ff2", 10, 29), + ("qkv", 10, 25), + ("ff1", 10, 30), + ("lm_head", 10, 70), + ], +) +@pytest.mark.models_device_performance_bare_metal +def test_ar_tg_llama_perf( + ar_type, + warmup_iters, + perf_target_us, +): + profiler = BenchmarkProfiler() + benchmark_data = BenchmarkData() + step_name = f"all_reduce_{ar_type}" + + subdir = "llama_ccl_perf" + command = ( + f"pytest tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py::test_all_reduce_tg_llama -k {ar_type}" + ) + cols = ["DEVICE KERNEL"] + op_name = "AllReduceAsync" + warmup_iters = warmup_iters * 32 # 5 iterations per device + + profiler.start("run") + profiler.start(step_name) + results = run_device_perf_detailed(command, subdir, cols, op_name, has_signposts=True, warmup_iters=warmup_iters) + profiler.end(step_name) + profiler.end("run") + + # Get the measured performance + measured_min_us = results[cols[0]]["MIN"] / 1000 + measured_max_us = results[cols[0]]["MAX"] / 1000 + measured_avg_us = results[cols[0]]["AVG"] / 1000 + measured_std_us = results[cols[0]]["STD"] / 1000 + + logger.info(f"Measured performance: {measured_avg_us:.3f} us vs. 
target: {perf_target_us} us") + + # Save the measurement + benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-min-us", measured_min_us) + benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-max-us", measured_max_us) + benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-avg-us", measured_avg_us) + benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-std-us", measured_std_us) + benchmark_data.save_partial_run_json( + profiler, + run_type=f"all_reduce", + ml_model_name="llama70b-tg-ccl", + ) + + assert measured_avg_us < perf_target_us, f"Performance target not met: {measured_avg_us} us > {perf_target_us} us" diff --git a/models/perf/device_perf_utils.py b/models/perf/device_perf_utils.py index 7810be77750..dd3baf18a49 100644 --- a/models/perf/device_perf_utils.py +++ b/models/perf/device_perf_utils.py @@ -4,10 +4,14 @@ import json import time +import pandas as pd + from loguru import logger +from collections import defaultdict from tt_metal.tools.profiler.common import clear_profiler_runtime_artifacts from tt_metal.tools.profiler.process_model_log import ( + get_latest_ops_log_filename, post_process_ops_log, run_device_profiler, get_samples_per_s, @@ -49,6 +53,79 @@ def run_device_perf(command, subdir, num_iterations, cols, batch_size, has_signp return post_processed_results +# TODO: Move into process_model_log.py (#18698) +def post_process_ops_log_detailed( + output_logs_subdir, columns, sum_vals=True, op_name="", has_signposts=False, detailed=False, warmup_iters=0 +): + filename = get_latest_ops_log_filename(output_logs_subdir) + df = pd.read_csv(filename) + + if has_signposts: + # there are explicit start and stop points in the model we want to measure between + markers = df[df["OP TYPE"] == "signpost"]["OP CODE"] + start = markers[markers == "start"].index[0] + stop = markers[markers == "stop"].index[0] + df = df.iloc[start + 1 : stop] + if op_name != "": + df = df[df["OP CODE"] == op_name] + + if warmup_iters > 0: + df = df.iloc[warmup_iters:] + + results = {} + for col in columns: + df_filtered = df[df[col] != "-"] + if sum_vals: + results[col] = df_filtered[col].astype(float).sum() + else: + results[col] = df_filtered[col].astype(float).to_numpy() + + if detailed: + results[f"AVG {col}"] = df_filtered[col].astype(float).mean() + results[f"MIN {col}"] = df_filtered[col].astype(float).min() + results[f"MAX {col}"] = df_filtered[col].astype(float).max() + results[f"STD {col}"] = df_filtered[col].astype(float).std() + + return results + + +def run_device_perf_detailed(command, subdir, cols, op_name="", has_signposts=False, warmup_iters=0): + duration_cols = [col + " DURATION [ns]" for col in cols] + + clear_profiler_runtime_artifacts() + + results = {} + for d_col in duration_cols: + results[f"AVG {d_col}"] = 0 + results[f"MIN {d_col}"] = float("inf") + results[f"MAX {d_col}"] = -float("inf") + results[f"STD {d_col}"] = 0 + + run_device_profiler(command, subdir) + r = post_process_ops_log_detailed( + subdir, duration_cols, op_name=op_name, has_signposts=has_signposts, detailed=True, warmup_iters=warmup_iters + ) + for d_col in duration_cols: + results[f"AVG {d_col}"] = r[f"AVG {d_col}"] + results[f"MIN {d_col}"] = r[f"MIN {d_col}"] + results[f"MAX {d_col}"] = r[f"MAX {d_col}"] + results[f"STD {d_col}"] = r[f"STD {d_col}"] + + post_processed_results = defaultdict(dict) + for col, d_col in zip(cols, duration_cols): + post_processed_results[col]["AVG"] = results[f"AVG {d_col}"] + 
post_processed_results[col]["MIN"] = results[f"MIN {d_col}"] + post_processed_results[col]["MAX"] = results[f"MAX {d_col}"] + post_processed_results[col]["STD"] = results[f"STD {d_col}"] + + logger.info( + f"\nTest: {command}" + f"\nPerformance statistics for op: {op_name}" + f"\n{json.dumps(post_processed_results, indent=4)}" + ) + return post_processed_results + + def check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=False): expected_results = {} failed = False diff --git a/tests/nightly/tg/ccl/test_new_all_reduce.py b/tests/nightly/tg/ccl/test_new_all_reduce.py new file mode 120000 index 00000000000..e4c9af8b57a --- /dev/null +++ b/tests/nightly/tg/ccl/test_new_all_reduce.py @@ -0,0 +1 @@ +../../../ttnn/unit_tests/operations/ccl/test_new_all_reduce.py \ No newline at end of file diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py index 184bb454c1d..efac85c97c2 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py @@ -15,6 +15,9 @@ create_global_semaphore_with_same_address, ) from models.perf.benchmarking_utils import BenchmarkProfiler +from tracy import signpost + +NUM_BUFFERS = 8 def report_mismatches(golden, actual, max_printable=None): @@ -64,6 +67,7 @@ def run_with_trace( n_worker=None, n_buffer=None, num_iter=20, + warmup_iters=0, use_all_gather_async=False, profiler=BenchmarkProfiler(), ): @@ -98,47 +102,66 @@ def run_with_trace( # Capture trace logger.info("Capturing trace") - trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) - for i in range(num_iter): - if use_all_gather_async: - logger.info("Running all-gather async") - tt_out_tensor = ttnn.experimental.all_gather_async( - input_tensor, - dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - topology=ttnn.Topology.Linear, - multi_device_global_semaphore=ccl_semaphore_handles[i] - if type(ccl_semaphore_handles) == list - else ccl_semaphore_handles, - num_links=num_links, - memory_config=output_mem_config, - subdevice_id=worker_sub_device_id, - enable_persistent_fabric_mode=enable_persistent_fabric, - ) - else: - tt_out_tensor = ttnn.all_gather( - input_tensor, - dim=dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - num_links=num_links, - memory_config=output_mem_config, - topology=all_gather_topology, - ) - ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) - ttnn.synchronize_device(mesh_device) + + def capture_trace(n_iters): + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + for i in range(n_iters): + if use_all_gather_async: + tt_out_tensor = ttnn.experimental.all_gather_async( + input_tensor, + dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + topology=ttnn.Topology.Linear, + multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS] + if type(ccl_semaphore_handles) == list + else ccl_semaphore_handles, + num_links=num_links, + memory_config=output_mem_config, + subdevice_id=worker_sub_device_id, + enable_persistent_fabric_mode=enable_persistent_fabric, + ) + else: + tt_out_tensor = ttnn.all_gather( + input_tensor, + dim=dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + num_links=num_links, + memory_config=output_mem_config, + topology=all_gather_topology, + ) + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + ttnn.synchronize_device(mesh_device) + return trace_id + + if warmup_iters > 0: + 
trace_id_warmup = capture_trace(warmup_iters) + trace_id = capture_trace(num_iter) # Run the op logger.info("Starting Trace perf test...") + profiler.start("all-gather-async-trace-warmup") + if warmup_iters > 0: + ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False) + ttnn.release_trace(mesh_device, trace_id_warmup) + ttnn.synchronize_device(mesh_device) + profiler.end("all-gather-async-trace-warmup") + profiler.start("all-gather-async-trace") + signpost("start") ttnn.execute_trace(mesh_device, trace_id, blocking=False) ttnn.release_trace(mesh_device, trace_id) ttnn.synchronize_device(mesh_device) + signpost("stop") profiler.end("all-gather-async-trace") - logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s") - logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s") - logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us") + time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration( + "all-gather-async-trace-warmup" + ) + effective_iter = num_iter - warmup_iters + logger.info(f"Time taken e2e: {time_taken} s") + logger.info(f"Time per iter e2e: {time_taken / effective_iter} s") + logger.info(f"Time per iter e2e: {time_taken / effective_iter * 1e6} us") return tt_out_tensor @@ -160,6 +183,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( output_shard_spec: ttnn.ShardSpec = None, num_all_gather_instances: int = 1, num_iters: int = 1, + warmup_iters: int = 0, cluster_axis: int = 0, tile=(32, 32), trace_mode=False, @@ -257,7 +281,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( # create global semaphore handles ccl_semaphore_handles = [ - create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters) + create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(NUM_BUFFERS) ] try: # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor) @@ -274,11 +298,13 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( enable_persistent_fabric=enable_persistent_fabric, all_gather_topology=ttnn.Topology.Linear, num_iter=num_iters, + warmup_iters=warmup_iters, use_all_gather_async=use_all_gather_async, profiler=profiler, ) else: + signpost("start") for i in range(num_iters): if use_all_gather_async: logger.info("Running all-gather async") @@ -288,7 +314,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( cluster_axis=cluster_axis, mesh_device=mesh_device, topology=ttnn.Topology.Linear, - multi_device_global_semaphore=ccl_semaphore_handles[i], + multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS], num_links=num_links, memory_config=output_mem_config, subdevice_id=worker_sub_device_id, @@ -305,7 +331,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( topology=ttnn.Topology.Linear, ) ttnn.synchronize_device(mesh_device, sub_device_ids=sub_device_stall_group) - + signpost("stop") except Exception as e: logger.error(f"Exception: {e}") raise e diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_async.py b/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_async.py index 4adef9d5990..542a7765a99 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_async.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_async.py @@ -45,8 +45,7 @@ def run_all_reduce_test( if teardown_persistent_fabric: assert enable_persistent_fabric - for d in mesh_device.get_devices(): - 
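The per-iteration numbers logged above come from subtracting the warmup trace duration from the main trace duration; assuming a roughly constant per-iteration cost and similar trace launch overhead, the difference corresponds to (num_iter - warmup_iters) steady-state iterations. A tiny sketch with hypothetical timings:

```python
# Hypothetical numbers to illustrate the arithmetic used above: subtracting the
# warmup trace cancels the fixed launch overhead plus warmup_iters iterations.
num_iter, warmup_iters = 55, 10
t_main, t_warmup = 0.0140, 0.0026  # seconds, illustrative only

time_taken = t_main - t_warmup
effective_iter = num_iter - warmup_iters
print(f"Time per iter e2e: {time_taken / effective_iter * 1e6:.2f} us")
```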
ttnn.enable_program_cache(d) + ttnn.synchronize_device(mesh_device) sub_device_stall_group = [] compute_grid_size = mesh_device.compute_with_storage_grid_size() @@ -324,8 +323,7 @@ def run_all_reduce_with_mesh_tensor_along_row( if teardown_persistent_fabric: assert enable_persistent_fabric - for d in mesh_device.get_devices(): - ttnn.enable_program_cache(d) + ttnn.synchronize_device(mesh_device) sub_device_stall_group = [] compute_grid_size = mesh_device.compute_with_storage_grid_size() diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py index fe967467e14..60f1f8a65e5 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py @@ -6,25 +6,18 @@ import pytest from loguru import logger import ttnn -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull -from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import ( - create_and_load_sub_device_manager_with_fabric_interface, - teardown_fabric_interface, - create_global_semaphore_with_same_address, -) from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import ( run_line_all_gather_on_TG_with_mesh_tensor_along_rows, ) -from tests.ttnn.unit_tests.operations.ccl.test_reduce_scatter_TG_nightly import ( - run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows, -) -from tests.ttnn.unit_tests.operations.ccl.test_all_reduce_async import ( - run_all_reduce_with_mesh_tensor_along_row, +from tests.ttnn.unit_tests.operations.ccl.test_new_all_reduce import ( + run_all_reduce_impl, ) -from models.perf.benchmarking_utils import BenchmarkProfiler +from models.perf.benchmarking_utils import BenchmarkData, BenchmarkProfiler + +NUM_ITERATIONS = 55 PREFETCHER_NOC1_RING = [ (6, 6), @@ -74,6 +67,13 @@ def get_core_range_set(output_core_grid): return output_core_range_set +CORE_RANGE_SET_1x1 = ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)), + } +) + + # Enumerate the post-commit cases explicitly @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( @@ -89,14 +89,14 @@ def get_core_range_set(output_core_grid): ], ) @pytest.mark.parametrize( - "num_iters", + "num_iters, warmup_iters", [ - 5000, + (NUM_ITERATIONS, 10), ], ) @pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR]) @pytest.mark.parametrize( - "tensor_mem_layout, output_shape, dim, input_shard_shape,input_shard_grid,output_shard_shape, output_shard_grid, layout, perf_target_us", + "tensor_mem_layout, output_shape, dim, input_shard_shape,input_shard_grid,output_shard_shape, output_shard_grid, layout", ( ( # AllGather after SDPA ttnn.TensorMemoryLayout.HEIGHT_SHARDED, @@ -112,7 +112,6 @@ def get_core_range_set(output_core_grid): } ), ttnn.TILE_LAYOUT, - 32, ), ( # AllGather after Binary Mult+Silu ttnn.TensorMemoryLayout.WIDTH_SHARDED, @@ -123,15 +122,29 @@ def get_core_range_set(output_core_grid): (32, 160), get_core_range_set(PREFETCHER_NOC1_RING), ttnn.TILE_LAYOUT, - 25, + ), + ( # AllGather for layernorm + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + (1, 1, 32, 128), + 3, + (32, 32), + CORE_RANGE_SET_1x1, + (32, 128), + CORE_RANGE_SET_1x1, + ttnn.TILE_LAYOUT, ), ), + ids=[ + "sdpa", + "binary_mult", + "layernorm", + ], ) @pytest.mark.parametrize("replication_factor", [8]) @pytest.mark.parametrize("enable_async", 
[True]) @pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) @pytest.mark.parametrize("device_params", [{"trace_region_size": 17068032}], indirect=True) -def test_line_all_gather_sharded_on_TG_rows_llama( +def test_all_gather_tg_llama( mesh_device, num_devices, output_shape, @@ -150,7 +163,7 @@ def test_line_all_gather_sharded_on_TG_rows_llama( enable_async, replication_factor, num_iters, - perf_target_us, + warmup_iters, ): if len(mesh_device.get_devices()) != 32: pytest.skip("Not TG!") @@ -185,6 +198,7 @@ def test_line_all_gather_sharded_on_TG_rows_llama( function_level_defaults, enable_async=enable_async, num_iters=num_iters, + warmup_iters=warmup_iters, input_shard_spec=input_shard_spec, output_shard_spec=output_shard_spec, num_all_gather_instances=replication_factor, @@ -197,231 +211,78 @@ def test_line_all_gather_sharded_on_TG_rows_llama( teardown_persistent_fabric=True, ) - latency_us = profiler.get_duration("all-gather-async-trace") / num_iters * 1e6 - if perf_target_us is not None: - assert ( - latency_us < perf_target_us - ), f"Measured latency {latency_us} us is greater than target {perf_target_us} us" - @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( - "num_devices, num_links", + "output_shape, cluster_axis, num_links, input_num_cores, output_num_cores", [ - (4, 2), + ([1, 1, 32, 2048], 0, 4, 24, 16), # FF2/DO all reduce + ([1, 1, 32, 1280], 1, 3, 24, 40), # QKV all reduce + ([1, 1, 32, 3584], 1, 3, 24, 24), # FF1 all reduce + ([1, 1, 32, 16 * 1024], 1, 3, 32, 32), # LM head all reduce ], -) -@pytest.mark.parametrize( - "tensor_mem_layout, per_chip_input_shape, dim, input_shard_shape,shard_grid,layout", - ( - ( # ReduceScatter After FF1/3 (~100 us) - ttnn.TensorMemoryLayout.WIDTH_SHARDED, - (1, 1, 32, 3840), - 3, - (32, 160), - ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 2))}), - ttnn.TILE_LAYOUT, - ), - ), -) -@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR]) -@pytest.mark.parametrize( - "input_dtype", - [ - ttnn.bfloat16, - # ttnn.bfloat8_b, - ], -) -@pytest.mark.parametrize( - "buffer_type", - [ - ttnn.BufferType.L1, - ], -) -@pytest.mark.parametrize("enable_async", [True]) -@pytest.mark.parametrize("replication_factor", [8]) -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) -@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) -def test_line_reduce_scatter_sharded_on_TG_rows_llama( - mesh_device, - num_devices, - per_chip_input_shape, - tensor_mem_layout, - input_shard_shape, - shard_grid, - shard_grid_orientation, - dim, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async, - replication_factor, - num_iters=10, -): - if len(mesh_device.get_devices()) != 32: - pytest.skip("Not TG!") - input_shard_spec = ttnn.ShardSpec( - shard_grid, - input_shard_shape, - shard_grid_orientation, - ) - - logger.warning("sharding not used due to issue #16699") - - run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows( - mesh_device, - num_devices, - per_chip_input_shape, - ttnn.TensorMemoryLayout.INTERLEAVED, # tensor_mem_layout, - dim, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async=enable_async, - # input_shard_spec=input_shard_spec, - num_iters=num_iters, - num_reduce_scatter_instances=replication_factor, - cluster_axis=1, - 
use_reduce_scatter_async=True, - enable_persistent_fabric=True, - create_persistent_fabric=True, - teardown_persistent_fabric=True, - ) - - -# Enumerate the post-commit cases explicitly -@skip_for_grayskull("Requires eth connected devices to run") -@pytest.mark.parametrize( - "num_devices, num_links, per_chip_output_shape, layout", - [ - (4, 1, [1, 1, 32, 1280], ttnn.TILE_LAYOUT), # AllReduce after QKV (~110 us) + ids=[ + "ff2", + "qkv", + "ff1", + "lm_head", ], ) @pytest.mark.parametrize( "input_dtype", [ - ttnn.bfloat16, + ttnn.bfloat8_b, ], ) @pytest.mark.parametrize( - "buffer_type", + "num_iters, warmup_iters", [ - ttnn.BufferType.L1, + (NUM_ITERATIONS, 10), ], ) -@pytest.mark.parametrize("replication_factor", [8]) # 1, 8]) @pytest.mark.parametrize("enable_async", [True]) -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) -@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) -def test_line_all_reduce_on_TG_rows_llama( - mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async, - replication_factor, - num_iters=10, -): - if len(mesh_device.get_devices()) != 32: - pytest.skip("Not TG!") - - logger.warning("sharding not used due to issue #16699") - - run_all_reduce_with_mesh_tensor_along_row( - mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async=enable_async, - num_iters=num_iters, - num_all_reduce_instances=replication_factor, - cluster_axis=1, - enable_persistent_fabric=True, - create_persistent_fabric=True, - teardown_persistent_fabric=True, - ) - - -@skip_for_grayskull("Requires eth connected devices to run") -@pytest.mark.parametrize( - "num_devices, num_links, per_chip_output_shape, layout", - [ - (8, 1, [1, 1, 32, 2048], ttnn.TILE_LAYOUT), # AllReduce after DO and AllReduce after FF2 (~240 us) - # multi-links fail https://github.com/tenstorrent/tt-metal/issues/16699 - ], -) +@pytest.mark.parametrize("trace_mode", [True]) @pytest.mark.parametrize( - "input_dtype", - [ - ttnn.bfloat16, - ], + "device_params", + [{"trace_region_size": 23887872}], + indirect=True, ) @pytest.mark.parametrize( - "buffer_type", + "mesh_device", [ - ttnn.BufferType.L1, + (8, 4), ], + indirect=True, ) -@pytest.mark.parametrize("enable_async", [True]) -@pytest.mark.parametrize("replication_factor", [4]) -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) -@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) -def test_line_all_reduce_on_TG_cols_llama( +def test_all_reduce_tg_llama( mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, + output_shape, + cluster_axis, input_dtype, - layout, - buffer_type, + num_links, + input_num_cores, + output_num_cores, + num_iters, + warmup_iters, + enable_async, + trace_mode, use_program_cache, function_level_defaults, - enable_async, - replication_factor, - num_iters=10, ): - if len(mesh_device.get_devices()) != 32: - pytest.skip("Not TG!") - - logger.warning("sharding not used due to issue #16699") + profiler = BenchmarkProfiler() - run_all_reduce_with_mesh_tensor_along_row( + run_all_reduce_impl( mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, + output_shape, + cluster_axis, input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - 
enable_async=enable_async, + num_links, + input_num_cores, + output_num_cores, num_iters=num_iters, - num_all_reduce_instances=replication_factor, - cluster_axis=0, - enable_persistent_fabric=True, - create_persistent_fabric=True, - teardown_persistent_fabric=True, + warmup_iters=warmup_iters, + enable_async=enable_async, + trace_mode=trace_mode, + validate_all=False, + profiler=profiler, ) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py index ff41460d520..1029c113749 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py @@ -39,7 +39,6 @@ def teardown_fabric_interface(mesh_device): def create_global_semaphore_with_same_address(mesh_device, cores, initial_value): semaphore_handles = ttnn.create_global_semaphore_with_same_address(mesh_device, cores, initial_value) addrs = ttnn.get_global_semaphore_address(semaphore_handles) - logger.debug(f"from remote semaphore handle addresses: {addrs}") # assert all addresses are the same assert len(set(addrs)) == 1 return semaphore_handles diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py new file mode 100644 index 00000000000..5cc3df68615 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -0,0 +1,484 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import math +from time import time +from loguru import logger +import ttnn +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc +from models.utility_functions import skip_for_grayskull +from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import ( + create_and_load_sub_device_manager_with_fabric_interface, + teardown_fabric_interface, + create_global_semaphore_with_same_address, +) + +from tests.tt_eager.python_api_testing.unit_testing.misc.test_matmul_1d_gather_in0 import ( + num_cores_to_rectangle_grid, + round_up, +) +from models.perf.benchmarking_utils import BenchmarkProfiler +from tracy import signpost + + +def check_mesh_tensor_alloc(tensor): + device_tensors = ttnn.get_device_tensors(tensor) + buffer_addr = device_tensors[0].buffer_address() + + if len(device_tensors) > 1: + for i in range(1, len(device_tensors)): + addr = device_tensors[i].buffer_address() + if not addr == buffer_addr: + return False + return True + + +def run_all_reduce_impl( + mesh_device, + output_shape, + cluster_axis, + input_dtype, + num_links, + input_num_cores, + output_num_cores, + loopback_size=1, + num_iters=1, + warmup_iters=0, + enable_async=False, + trace_mode=False, + validate_all=True, + profiler=BenchmarkProfiler(), +): + cluster_shape = (8, 4) + + create_persistent_fabric = True + teardown_persistent_fabric = True + enable_persistent_fabric = True + if num_iters < 1: + pytest.fail("num_iters must be >= 1") + # Use Async mode based on test input config + mesh_device.enable_async(enable_async) + + if enable_async: + logger.info(f"Using Async Mode for All Gather Op Dispatch") + + ################################## + ##### Set up fabric stuff + ################################## + compute_grid_size = mesh_device.compute_with_storage_grid_size() + ccl_sub_device_crs = ttnn.CoreRangeSet( + {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1))} + ) + worker_sub_device = 
ttnn.SubDevice( + [ + ccl_sub_device_crs, + ] + ) + worker_sub_device_id = ttnn.SubDeviceId(0) + sub_device_stall_group = [worker_sub_device_id] + if create_persistent_fabric: + mesh_sub_device_manager_id = create_and_load_sub_device_manager_with_fabric_interface( + mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric + ) + mesh_device.set_sub_device_stall_group(sub_device_stall_group) + + # create global semaphore handles + num_buffers = 8 + ccl_semaphore_handles = [ + create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_buffers) + ] + + logger.info(f"Output shape: {output_shape}") + + try: + ################################## + ##### Set up input tensors/configs + ################################## + + ##### FF2 Case ##### + M, N = output_shape[2:] + N_per_shard = round_up(math.ceil(N / input_num_cores), ttnn.TILE_SIZE) + output_N_per_shard = round_up(math.ceil(N / output_num_cores), ttnn.TILE_SIZE) + input_shape = [*cluster_shape, M, N] + intermediate_shape = [*input_shape[:-1], N * cluster_shape[cluster_axis]] + + CORE_RANGE = [(x, y) for y in range(compute_grid_size.y) for x in range(compute_grid_size.x)] + core_range_set = ttnn.CoreRangeSet( + [ + ttnn.CoreRange( + ttnn.CoreCoord(x, y), + ttnn.CoreCoord(x, y), + ) + for x, y in CORE_RANGE[:input_num_cores] + ] + ) + input_mem_config = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + core_range_set, + [M, N_per_shard], + ttnn.ShardOrientation.ROW_MAJOR, + ), + ) + output_core_range_set = ttnn.CoreRangeSet( + [ + ttnn.CoreRange( + ttnn.CoreCoord(x, y), + ttnn.CoreCoord(x, y), + ) + for x, y in CORE_RANGE[:output_num_cores] + ] + ) + output_mem_config = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + output_core_range_set, + [M, output_N_per_shard], + ttnn.ShardOrientation.ROW_MAJOR, + ), + ) + intermediate_mem_config = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + output_core_range_set, + [M, output_N_per_shard * cluster_shape[cluster_axis]], + ttnn.ShardOrientation.ROW_MAJOR, + ), + ) + + logger.info(f"Input shape: {input_shape[2:]}, Padded shape: {[M, N_per_shard * input_num_cores]}") + input_tensor = torch.randn(input_shape) + tt_input_tensor = ttnn.from_torch( + input_tensor, + device=mesh_device, + layout=ttnn.TILE_LAYOUT, + dtype=input_dtype, + memory_config=input_mem_config, + mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(0, 1), mesh_shape=cluster_shape), + ) + check_mesh_tensor_alloc(tt_input_tensor) + + intermediate_tensor = torch.zeros(intermediate_shape) + tt_intermediate_tensors = [] + for i in range(num_buffers): + tt_intermediate_tensor = ttnn.from_torch( + intermediate_tensor, + device=mesh_device, + layout=ttnn.TILE_LAYOUT, + dtype=input_dtype, + memory_config=intermediate_mem_config, + mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(0, 1), mesh_shape=cluster_shape), + ) + + # Validate that the tensor is allocated in same location across devices + check_mesh_tensor_alloc(tt_intermediate_tensor) + tt_intermediate_tensors.append(tt_intermediate_tensor) + + # All-Reduce Golden + # Inputs reduce sequentially for 10 iters + output_tensor_goldens_list = [] + for i in range(num_iters): + if i % loopback_size == 0: + ar_input_tensor = input_tensor + + output_tensor_goldens_list.append(torch.sum(ar_input_tensor, dim=cluster_axis)) + ar_input_tensor = torch.concat( + 
[output_tensor_goldens_list[-1].unsqueeze(cluster_axis)] * cluster_shape[cluster_axis], dim=cluster_axis + ) + + ################################## + ##### Run the op + ################################## + + def run_op(n_iters, store_all_results=True): + outs = [] + for i in range(n_iters): + if i % loopback_size == 0: + tt_input = tt_input_tensor + + out = ttnn.experimental.all_reduce_async( + tt_input, + tt_intermediate_tensors[i % num_buffers], + cluster_axis=cluster_axis, + mesh_device=mesh_device, + multi_device_global_semaphore=ccl_semaphore_handles[i % num_buffers], + memory_config=output_mem_config, + topology=ttnn.Topology.Linear, + num_links=num_links, + subdevice_id=worker_sub_device_id, + ) + if not trace_mode: + ttnn.synchronize_device(mesh_device) + if store_all_results: + outs.append(out) + + # Loop back the output to the input + if loopback_size != 1: + tt_input = ttnn.reshard(out, input_mem_config) + + if store_all_results: + return outs + else: + return [out] + + if trace_mode: + ##### Compile Model ##### + logger.info("Compiling model") + tt_outs = run_op(num_iters, store_all_results=validate_all) + + ##### Capture Trace ##### + logger.info("Capturing trace") + if warmup_iters > 0: + trace_id_warmup = ttnn.begin_trace_capture(mesh_device, cq_id=0) + tt_outs = run_op(warmup_iters, store_all_results=validate_all) + ttnn.end_trace_capture(mesh_device, trace_id_warmup, cq_id=0) + ttnn.synchronize_device(mesh_device) + + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + tt_outs = run_op(num_iters, store_all_results=validate_all) + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + ttnn.synchronize_device(mesh_device) + + ##### Run Trace ##### + logger.info("Starting Trace perf test...") + profiler.start("all-reduce-async-trace-warmup") + if warmup_iters > 0: + ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False) + ttnn.release_trace(mesh_device, trace_id_warmup) + ttnn.synchronize_device(mesh_device) + profiler.end("all-reduce-async-trace-warmup") + + signpost("start") + profiler.start("all-reduce-async-trace") + ttnn.execute_trace(mesh_device, trace_id, blocking=False) + ttnn.release_trace(mesh_device, trace_id) + ttnn.synchronize_device(mesh_device) + profiler.end("all-reduce-async-trace") + signpost("stop") + time_taken = profiler.get_duration("all-reduce-async-trace") - profiler.get_duration( + "all-reduce-async-trace-warmup" + ) + effective_iter = num_iters - warmup_iters + logger.info(f"Time taken e2e: {time_taken} s") + logger.info(f"Time per iter e2e: {time_taken / effective_iter} s") + logger.info(f"Time per iter e2e: {time_taken / effective_iter * 1e6} us") + + else: + signpost("start") + tt_outs = run_op(num_iters, store_all_results=validate_all) + signpost("stop") + + ################################## + ##### Validation + ################################## + def validate(tt_out_tensor, output_tensor): + for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): + # get_device_tensors returns row major, so we need to select the correct golden tensor + if cluster_axis == 0: + output_tensor_ = output_tensor[i % cluster_shape[not (cluster_axis)]].unsqueeze(0).unsqueeze(0) + else: + output_tensor_ = output_tensor[i // cluster_shape[cluster_axis]].unsqueeze(0).unsqueeze(0) + + tt_output_tensor = t.cpu().to_torch() + # logger.info(f"Checking for device {t.device().id()}") + + if input_dtype == ttnn.bfloat16: + eq, output = comp_pcc(tt_output_tensor, output_tensor_) + else: + eq, output = comp_pcc(tt_output_tensor, output_tensor_) + assert 
eq, f"{i} FAILED: {output}" + logger.info(f"PCC output is: {output}") + + if validate_all: + for tensor_index in range(len(tt_outs)): + tt_out_tensor = tt_outs[tensor_index] + output_tensor = output_tensor_goldens_list[tensor_index] + validate(tt_out_tensor, output_tensor) + else: + tt_out_tensor = tt_outs[-1] + output_tensor = output_tensor_goldens_list[-1] + validate(tt_out_tensor, output_tensor) + + for i in range(mesh_device.get_num_devices()): + reshard_op_cnt = 1 if loopback_size > 1 else 0 + assert ( + mesh_device.get_devices()[i].num_program_cache_entries() == 1 + reshard_op_cnt + or mesh_device.get_devices()[i].num_program_cache_entries() == num_iters + reshard_op_cnt + ), f"Device {i} has {mesh_device.get_devices()[i].num_program_cache_entries()} program cache entries" + + finally: + if enable_persistent_fabric and teardown_persistent_fabric: + mesh_device.reset_sub_device_stall_group() + t1 = time() + teardown_fabric_interface(mesh_device) + t2 = time() + logger.info(f"Teardown time: {t2 - t1}") + + +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.timeout(900) +@pytest.mark.parametrize( + "output_shape, cluster_axis, num_links, input_num_cores, output_num_cores", + [ + ([1, 1, 32, 2048], 0, 4, 24, 16), # FF2/DO all reduce + ([1, 1, 32, 1280], 1, 3, 24, 40), # QKV all reduce + ([1, 1, 32, 3584], 1, 3, 24, 24), # FF1 all reduce + ([1, 1, 32, 2048], 0, 3, 24, 16), # FF2/DO all reduce + ([1, 1, 32, 16 * 1024], 1, 3, 32, 32), # LM Head all reduce + ([1, 1, 32, 1280], 1, 2, 24, 40), # QKV all reduce + ([1, 1, 32, 3584], 1, 2, 24, 24), # FF1 all reduce + ([1, 1, 32, 2048], 0, 2, 24, 16), # FF2/DO all reduce + ([1, 1, 32, 16 * 1024], 1, 2, 32, 32), # LM Head all reduce + ([1, 1, 32, 1280], 1, 1, 24, 40), # QKV all reduce + ([1, 1, 32, 3584], 1, 1, 24, 24), # FF1 all reduce + ([1, 1, 32, 2048], 0, 1, 24, 16), # FF2/DO all reduce + ([1, 1, 32, 16 * 1024], 1, 1, 32, 32), # LM Head all reduce + ], +) +@pytest.mark.parametrize( + "input_dtype", + [ + ttnn.bfloat16, + ttnn.bfloat8_b, + ], +) +@pytest.mark.parametrize( + "num_iters, warmup_iters", + [ + (1000, 100), + ], +) +@pytest.mark.parametrize("enable_async", [True]) +@pytest.mark.parametrize("trace_mode", [True]) +@pytest.mark.parametrize( + "device_params", + [{"trace_region_size": 23887872}], + indirect=True, +) +@pytest.mark.parametrize( + "mesh_device", + [ + (8, 4), + ], + indirect=True, +) +def test_all_reduce( + mesh_device, + output_shape, + cluster_axis, + input_dtype, + num_links, + input_num_cores, + output_num_cores, + num_iters, + warmup_iters, + enable_async, + trace_mode, + use_program_cache, + function_level_defaults, +): + if len(mesh_device.get_devices()) != 32: + pytest.skip("Not TG!") + + profiler = BenchmarkProfiler() + + run_all_reduce_impl( + mesh_device, + output_shape, + cluster_axis, + input_dtype, + num_links, + input_num_cores, + output_num_cores, + num_iters=num_iters, + warmup_iters=warmup_iters, + enable_async=enable_async, + trace_mode=trace_mode, + validate_all=False, + profiler=profiler, + ) + + time_taken = profiler.get_duration("all-reduce-async-trace") - profiler.get_duration( + "all-reduce-async-trace-warmup" + ) + effective_iter = num_iters - warmup_iters + latency_us = time_taken / effective_iter * 1e6 + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {latency_us} us") + + +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.timeout(600) +@pytest.mark.parametrize( + "output_shape, cluster_axis, num_links, input_num_cores, 
output_num_cores", + [ + ([1, 1, 32, 1280], 1, 1, 24, 40), # QKV all reduce + ([1, 1, 32, 3584], 1, 1, 24, 24), # FF1 all reduce + ([1, 1, 32, 2048], 0, 1, 24, 16), # FF2/DO all reduce + ], +) +@pytest.mark.parametrize( + "input_dtype", + [ + ttnn.bfloat8_b, + ], +) +@pytest.mark.parametrize( + "num_iters, warmup_iters", + [ + (100, 10), + ], +) +@pytest.mark.parametrize("enable_async", [True]) +@pytest.mark.parametrize("trace_mode", [True]) +@pytest.mark.parametrize( + "device_params", + [{"trace_region_size": 23887872}], + indirect=True, +) +@pytest.mark.parametrize( + "mesh_device", + [ + (8, 4), + ], + indirect=True, +) +def test_all_reduce_loopback( + mesh_device, + output_shape, + cluster_axis, + input_dtype, + num_links, + input_num_cores, + output_num_cores, + num_iters, + warmup_iters, + enable_async, + trace_mode, + use_program_cache, + function_level_defaults, +): + if len(mesh_device.get_devices()) != 32: + pytest.skip("Not TG!") + + run_all_reduce_impl( + mesh_device, + output_shape, + cluster_axis, + input_dtype, + num_links, + input_num_cores, + output_num_cores, + loopback_size=4, + num_iters=num_iters, + warmup_iters=warmup_iters, + enable_async=enable_async, + trace_mode=trace_mode, + validate_all=False, + ) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 2b7d8265ade..f7bfd486a3c 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -260,6 +260,8 @@ set(CCL_SRC_EXPERIMENTAL ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/multi_core/all_gather_matmul_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_reduce/all_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.cpp diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp index cfd21f7150f..1dc8c1cae46 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp @@ -146,7 +146,7 @@ AllGatherAsyncVersion AllGatherAsync::select_version(const Tensor& input_tensor) log_trace(tt::LogOp, "[select_version] output_is_sharded: {}", output_is_sharded); if (input_is_sharded && output_is_sharded) { - // Check for first llama post binary matmul case + // Check for llama post binary mult+silu case if (input_tensor_shape[0] == 1 && input_tensor_shape[1] == 1 && input_tensor_shape[2] == 32 && input_tensor_shape[3] == 960 && input_tensor_memory_config.buffer_type == BufferType::L1 && output_mem_config.buffer_type == BufferType::L1 && @@ -157,10 +157,13 @@ AllGatherAsyncVersion AllGatherAsync::select_version(const Tensor& input_tensor) output_mem_config.shard_spec->shape[0] == 32 && output_mem_config.shard_spec->shape[1] == 
160 && input_shard_num_cores == 30 && output_shard_num_cores == 24) { - return AllGatherAsyncVersion::LLAMA_POST_BINARY_MATMUL; + log_trace( + tt::LogOp, + "Matching conditions for Llama post binary mult+silu, using LLAMA_MINIMAL_SHARDED implementation"); + return AllGatherAsyncVersion::LLAMA_MINIMAL_SHARDED; } - // Check for second llama post binary matmul case + // Check for llama post SDPA case if (input_tensor_shape[0] == 1 && input_tensor_shape[1] == 8 && input_tensor_shape[2] == 32 && input_tensor_shape[3] == 128 && input_tensor_memory_config.buffer_type == BufferType::L1 && output_mem_config.buffer_type == BufferType::L1 && @@ -171,11 +174,26 @@ AllGatherAsyncVersion AllGatherAsync::select_version(const Tensor& input_tensor) output_mem_config.shard_spec->shape[0] == 32 && output_mem_config.shard_spec->shape[1] == 128 && input_shard_num_cores == 8 && output_shard_num_cores == 32) { - log_trace(tt::LogOp, "All conditions matched for LLAMA_POST_BINARY_MATMUL case"); - return AllGatherAsyncVersion::LLAMA_POST_BINARY_MATMUL; + log_trace(tt::LogOp, "Matching conditions for Llama post SDPA, using LLAMA_MINIMAL_SHARDED implementation"); + return AllGatherAsyncVersion::LLAMA_MINIMAL_SHARDED; + } + + // Check for llama rms norm case + if (input_tensor_shape[0] == 1 && input_tensor_shape[1] == 1 && input_tensor_shape[2] == 32 && + input_tensor_shape[3] == 32 && input_tensor_memory_config.buffer_type == BufferType::L1 && + output_mem_config.buffer_type == BufferType::L1 && + input_tensor_memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED && + output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED && + input_tensor_memory_config.shard_spec->shape[0] == 32 && + input_tensor_memory_config.shard_spec->shape[1] == 32 && output_mem_config.shard_spec->shape[0] == 32 && + output_mem_config.shard_spec->shape[1] == 128 && input_shard_num_cores == 1 && + output_shard_num_cores == 1) { + log_trace( + tt::LogOp, "Matching conditions for Llama rms norm case, using LLAMA_MINIMAL_SHARDED implementation"); + return AllGatherAsyncVersion::LLAMA_MINIMAL_SHARDED; } } - log_trace(tt::LogOp, "All conditions matched for generic case"); + log_trace(tt::LogOp, "Using generic implementation"); return AllGatherAsyncVersion::GENERIC; } @@ -207,11 +225,9 @@ tt::tt_metal::operation::ProgramWithCallbacks AllGatherAsync::create_program( this->sub_device_id, this->enable_persistent_fabric_mode); - case AllGatherAsyncVersion::LLAMA_POST_BINARY_MATMUL: - log_trace( - tt::LogOp, - "Detected all gather specialized shape. all_gather_async_llama_post_binary_matmul is called"); - return all_gather_async_llama_post_binary_matmul( + case AllGatherAsyncVersion::LLAMA_MINIMAL_SHARDED: + log_trace(tt::LogOp, "Detected all gather specialized shape. 
all_gather_async_llama_sharded is called"); + return all_gather_async_llama_sharded( input_tensors[0], this->forward_device, this->backward_device, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp index b947120f463..bd66bd923aa 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp @@ -28,7 +28,7 @@ using ccl::EriscDatamoverBuilder; enum class AllGatherAsyncVersion { GENERIC = 0, MINIMAL_INTERLEAVED_32 = 1, - LLAMA_POST_BINARY_MATMUL = 2, + LLAMA_MINIMAL_SHARDED = 2, }; struct AllGatherAsync { @@ -141,7 +141,7 @@ tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_minimal_interleav const GlobalSemaphore& semaphore, const std::optional& sub_device_id, bool enable_persistent_fabric_mode); -tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul( +tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_llama_sharded( const Tensor& input_tensor, std::optional forward_device, std::optional backward_device, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp index e8191564e2a..4952a3586d0 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp @@ -276,7 +276,7 @@ tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_minimal_interleav return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } -tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul( +tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_llama_sharded( const Tensor& input_tensor, std::optional forward_device, std::optional backward_device, @@ -292,8 +292,7 @@ tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_llama_post_binary tt::tt_metal::Program program{}; const bool enable_async_output_tensor = false; TT_FATAL( - enable_persistent_fabric_mode, - "only persistent fabric mode is supported for all_gather_async_llama_post_binary_matmul"); + enable_persistent_fabric_mode, "only persistent fabric mode is supported for all_gather_async_llama_sharded"); IDevice* device = input_tensor.device(); bool is_first_chip = ring_index == 0; @@ -386,7 +385,7 @@ tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_llama_post_binary auto worker_sender_reader_kernel_id = tt::tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" - "llama_post_binary_matmul_shape_reader.cpp", + "llama_shapes_sharded_reader.cpp", sender_worker_core_range, reader_kernel_config); @@ -409,7 +408,7 @@ tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_llama_post_binary auto worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" - "llama_post_binary_matmul_shape_writer.cpp", + "llama_shapes_sharded_writer.cpp", sender_worker_core_range, writer_kernel_config); @@ -418,14 +417,17 @@ 
tt::tt_metal::operation::ProgramWithCallbacks all_gather_async_llama_post_binary // semaphore auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true); auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true); - auto cores_per_device = output_cores_vec.size() / ring_size; + auto cores_per_device = (output_cores_vec.size() + ring_size - 1) / ring_size; + uint32_t start_core_index_for_device = output_cores_vec.size() / ring_size * ring_index; + uint32_t end_core_index_for_device = start_core_index_for_device + cores_per_device; TT_FATAL( - output_cores_vec.size() % ring_size == 0, - "output sharded cores must be divisible by num_links for this work distribution scheme"); + output_cores_vec.size() % ring_size == 0 || output_cores_vec.size() == 1, + "output sharded cores ( {} ) must be divisible by ring_size ( {} ) or 1 for this work distribution scheme", + output_cores_vec.size(), + ring_size); auto output_cores_this_device = std::vector( - output_cores_vec.begin() + ring_index * cores_per_device, - output_cores_vec.begin() + (ring_index + 1) * cores_per_device); - + output_cores_vec.begin() + start_core_index_for_device, output_cores_vec.begin() + end_core_index_for_device); + log_trace(tt::LogOp, "output_cores_this_device: {}", output_cores_this_device); for (uint32_t link = 0; link < num_links; link++) { CoreCoord core = sender_worker_cores[link]; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_shapes_sharded_reader.cpp similarity index 100% rename from ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp rename to ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_shapes_sharded_reader.cpp diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_shapes_sharded_writer.cpp similarity index 100% rename from ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp rename to ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_shapes_sharded_writer.cpp diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp index fabb376d55c..b7c33163ab0 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp @@ -6,6 +6,7 @@ #include "cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp" #include "ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp" +#include "device/all_reduce_async_op.hpp" #include "cpp/ttnn/global_semaphore.hpp" namespace ttnn::operations::experimental::ccl { @@ -106,4 +107,28 @@ ttnn::Tensor ExecuteAllReduceAsync::invoke( true); } +ttnn::Tensor ExecuteAllReduceAsync::invoke( + const ttnn::Tensor& input_tensor, + ttnn::Tensor& buffer_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const std::optional& memory_config, + ttnn::ccl::Topology
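The output-core split in the hunk above distributes `output_cores_vec` across the ring with a ceiling division so that the single-core (rms norm) case still resolves to one core per device. A plain-Python sketch of the same indexing, with hypothetical core lists:

```python
# Illustrative sketch (hypothetical values) of the per-device output-core slice.
def cores_for_device(output_cores, ring_size, ring_index):
    cores_per_device = (len(output_cores) + ring_size - 1) // ring_size  # ceiling division
    start = len(output_cores) // ring_size * ring_index
    return output_cores[start : start + cores_per_device]

print(cores_for_device(list(range(24)), 4, 1))  # 24 cores, ring of 4 -> cores 6..11
print(cores_for_device([0], 4, 3))              # 1 core, ring of 4 -> [0] on every device
```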
topology, + const std::optional num_preferred_links, + std::optional worker_subdevice_id_opt) { + MemoryConfig out_memory_config = memory_config.value_or(input_tensor.memory_config()); + return ttnn::operations::experimental::ccl::all_reduce_async( + input_tensor, + buffer_tensor, + cluster_axis, + mesh_device, + topology, + multi_device_global_semaphore, + out_memory_config, + num_preferred_links, + worker_subdevice_id_opt, + true); +} + } // namespace ttnn::operations::experimental::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp index 05efd2ba14e..a49f2afcda2 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp @@ -41,6 +41,17 @@ struct ExecuteAllReduceAsync { ttnn::ccl::Topology topology, const std::optional num_preferred_links, std::optional worker_subdevice_id_opt); + + static ttnn::Tensor invoke( + const ttnn::Tensor& input_tensor, + ttnn::Tensor& buffer_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const std::optional& memory_config, + ttnn::ccl::Topology topology, + const std::optional num_preferred_links, + std::optional worker_subdevice_id_opt); }; } // namespace ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp index ebbed189db9..886b1a3e8a6 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp @@ -94,6 +94,39 @@ void bind_all_reduce_async(pybind11::module& module, const ccl_operation_t& oper py::arg("memory_config") = std::nullopt, py::arg("topology") = ttnn::ccl::Topology::Linear, py::arg("num_links") = std::nullopt, + py::arg("subdevice_id") = std::nullopt}, + + ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const ttnn::Tensor& input_tensor, + ttnn::Tensor& buffer_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const ttnn::MemoryConfig& memory_config, + ttnn::ccl::Topology topology, + const std::optional num_links, + std::optional worker_subdevice_id_opt) -> ttnn::Tensor { + return self( + input_tensor, + buffer_tensor, + cluster_axis, + mesh_device, + multi_device_global_semaphore, + memory_config, + topology, + num_links, + worker_subdevice_id_opt); + }, + py::arg("input_tensor"), + py::arg("buffer_tensor"), + py::arg("cluster_axis"), + py::arg("mesh_device"), + py::arg("multi_device_global_semaphore"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("topology") = ttnn::ccl::Topology::Linear, + py::arg("num_links") = std::nullopt, py::arg("subdevice_id") = std::nullopt}); } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp new file mode 100644 index 00000000000..77dc39cc087 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp @@ -0,0 +1,305 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent 
AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "all_reduce_async_op.hpp" +#include "ttnn/operations/math.hpp" +#include "cpp/ttnn/global_semaphore.hpp" + +#include + +#include "ttnn/tensor/tensor_utils.hpp" + +namespace ttnn { +namespace ccl { +namespace all_reduce_detail { + +AllReduceAsync create_all_reduce_async_struct( + const Tensor& input_tensor, + const uint32_t num_links, + const std::optional& memory_config, + const std::vector& devices, + const ttnn::ccl::Topology topology, + const std::vector& semaphores, + std::optional& sub_device_id, + bool enable_persistent_fabric_mode) { + uint32_t num_devices = devices.size(); + + std::optional forward_device = std::nullopt; + std::optional backward_device = std::nullopt; + std::optional semaphore = std::nullopt; + uint32_t device_index = 0; // Initialize device index + for (uint32_t i = 0; i < num_devices; ++i) { + if (devices.at(i) == input_tensor.device()) { + device_index = i; + semaphore = semaphores.at(i); // Get raw pointer + if (i != 0) { + backward_device = devices.at(i - 1); + } + if (i != num_devices - 1) { + forward_device = devices.at(i + 1); + } + } + } + + return ttnn::AllReduceAsync{ + forward_device, + backward_device, + num_links, + num_devices, + device_index, + memory_config.value_or(input_tensor.memory_config()), + topology, + semaphore.value(), + sub_device_id, + enable_persistent_fabric_mode}; +} + +} // namespace all_reduce_detail +} // namespace ccl + +void AllReduceAsync::validate(const std::vector& input_tensors) const { + TT_FATAL(input_tensors.size() == 2, "Error, Input tensor size should be 2 but has {}", input_tensors.size()); + const auto& input_tensor = input_tensors[0]; + const auto& buffer_tensor = input_tensors[1]; + const auto& layout = input_tensors[0].get_layout(); + const auto& dtype = input_tensors[0].get_dtype(); + const auto& page_size = input_tensors[0].buffer()->page_size(); + TT_FATAL(page_size % input_tensors[0].buffer()->alignment() == 0, "All Gather currently requires aligned pages"); + TT_FATAL( + this->ring_size % 2 == 0, + "AllReduceAsync currently only supports even number of blocks in the reduction kernel."); + + TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to all_reduce need to be on device!"); + TT_FATAL(input_tensor.buffer() != nullptr, "Operands to all_reduce need to be allocated in buffers on device!"); + + TT_FATAL(buffer_tensor.storage_type() == StorageType::DEVICE, "Operands to all_reduce need to be on device!"); + TT_FATAL(buffer_tensor.buffer() != nullptr, "Operands to all_reduce need to be allocated in buffers on device!"); + + TT_FATAL(this->num_links > 0, "Error, num_links should be more than 0 but has {}", this->num_links); + TT_FATAL( + this->num_links <= input_tensor.device()->compute_with_storage_grid_size().y, + "Worker cores used by links are parallelizaed over rows"); + + TT_FATAL( + input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED, + "Unsupported memory layout for input tensor{}.", + input_tensor.memory_config().memory_layout); + + TT_FATAL( + buffer_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED, + "Unsupported memory layout for buffer tensor {}.", + buffer_tensor.memory_config().memory_layout); + TT_FATAL( + this->output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED, + "Unsupported memory layout for output tensor {}.", + this->output_mem_config.memory_layout); + + TT_FATAL( + 
buffer_tensor.memory_config().shard_spec->grid.contains(this->output_mem_config.shard_spec->grid), + "The output tensor must reside on a subset of the cores of the buffer tensor"); + + const uint32_t output_shard_shape_volume = + this->output_mem_config.shard_spec->shape[0] * this->output_mem_config.shard_spec->shape[1]; + const uint32_t buffer_shard_shape_volume = + buffer_tensor.memory_config().shard_spec->shape[0] * buffer_tensor.memory_config().shard_spec->shape[1]; + TT_FATAL( + output_shard_shape_volume * this->ring_size <= buffer_shard_shape_volume, + "The shard size for the buffer must be large enough to hold the intermediate tensor. Require at least {} but " + "has {}", + output_shard_shape_volume * this->ring_size, + buffer_shard_shape_volume); +} + +std::vector AllReduceAsync::compute_output_specs(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors[0]; + auto shape = input_tensor.get_logical_shape(); + auto output_tensor_layout = + input_tensor.get_tensor_spec().tensor_layout().with_memory_config(this->output_mem_config); + return {TensorSpec(shape, output_tensor_layout)}; +} + +tt::tt_metal::operation::ProgramWithCallbacks AllReduceAsync::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + tt::log_debug(tt::LogOp, "DEBUG: create_program is called"); + + auto input_tensor_shape = input_tensors[0].get_padded_shape(); + auto input_tensor_buffer_layout = input_tensors[0].buffer()->buffer_layout(); + auto input_tensor_page_layout = input_tensors[0].layout(); + + auto input_tensor_memory_config = input_tensors[0].memory_config(); + auto output_tensor_memory_config = output_tensors[0].memory_config(); + uint32_t input_shard_num_cores = input_tensor_memory_config.shard_spec->grid.num_cores(); + uint32_t output_shard_num_cores = output_tensor_memory_config.shard_spec->grid.num_cores(); + + tt::log_debug(tt::LogOp, "input_tensor_shape: {}", input_tensor_shape); + tt::log_debug(tt::LogOp, "input_tensor_memory_config: {}", input_tensor_memory_config); + tt::log_debug(tt::LogOp, "output_tensor_memory_config: {}", output_tensor_memory_config); + tt::log_debug(tt::LogOp, "input_shard_num_cores: {}", input_shard_num_cores); + tt::log_debug(tt::LogOp, "output_shard_num_cores: {}", output_shard_num_cores); + tt::log_debug( + tt::LogOp, "input_tensor_memory_config.shard_spec->shape: {}", input_tensor_memory_config.shard_spec->shape); + tt::log_debug( + tt::LogOp, "output_tensor_memory_config.shard_spec->shape: {}", output_tensor_memory_config.shard_spec->shape); + + tt::log_debug(tt::LogOp, "Running TG Llama specific all_reduce_async_minimal_multi_core_with_workers"); + return all_reduce_async_minimal_multi_core_with_workers( + input_tensors[0], + input_tensors[1], + this->forward_device, + this->backward_device, + output_tensors[0], + this->num_links, + this->ring_size, + this->ring_index, + this->topology, + this->semaphore, + this->sub_device_id, + this->enable_persistent_fabric_mode); +} + +const tt::tt_metal::operation::Hash AllReduceAsync::compute_program_hash( + const std::vector& input_tensors) const { + auto input_shape = input_tensors[0].get_padded_shape(); + auto input_memory_layout = input_tensors[0].get_layout(); + auto input_dtype = input_tensors[0].get_dtype(); + auto input_memory_config = input_tensors[0].memory_config(); + + return operation::hash_operation( + this->num_links, + this->ring_size, + this->ring_index, + this->output_mem_config, + this->topology, + input_shape, + input_memory_layout, + 
input_dtype, + input_memory_config); +} + +namespace operations { +namespace experimental { +namespace ccl { + +Tensor all_reduce_async( + const Tensor& input_tensor, + Tensor& buffer_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const std::optional& memory_config, + const std::optional num_preferred_links, + std::optional subdevice_id, + bool enable_persistent_fabric_mode) { + TT_FATAL( + topology == ttnn::ccl::Topology::Linear, + "This all_reduce API with cluster_axis is currently supported only for the Linear topology"); + const auto mesh_view = mesh_device.get_view(); + auto devices = input_tensor.get_workers(); + std::size_t num_devices = (cluster_axis == 0) ? mesh_view.num_rows() : mesh_view.num_cols(); + + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + std::vector semaphores = multi_device_global_semaphore.global_semaphores; + + operation::launch_op( + [num_preferred_links, + memory_config, + mesh_view, + cluster_axis, + num_devices, + topology, + semaphores, + subdevice_id, + enable_persistent_fabric_mode]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + const auto& input_device_tensor = input_tensors.at(0); + + TT_FATAL( + mesh_view.is_mesh_2d(), + "all-gather invoked with cluster_axis API on >2D mesh, which is currently unsupported"); + const auto coordinate = mesh_view.find_device(input_device_tensor.device()->id()); + std::vector devices = (cluster_axis == 0) ? mesh_view.get_devices_on_column(coordinate[1]) + : mesh_view.get_devices_on_row(coordinate[0]); + + const auto& input_tensor = input_tensors.at(0); + const auto& buffer_tensor = input_tensors.at(1); + + return operation::run( + ttnn::ccl::all_reduce_detail::create_all_reduce_async_struct( + input_device_tensor, + num_preferred_links.has_value() ? num_preferred_links.value() : 1, + memory_config, + devices, + topology, + semaphores, + subdevice_id, + enable_persistent_fabric_mode), + {input_tensor, buffer_tensor}); + }, + {input_tensor, buffer_tensor}, + output_tensors); + return output_tensors.at(0); +} + +} // namespace ccl +} // namespace experimental +} // namespace operations + +std::tuple> choose_worker_cores( + size_t num_links, + size_t num_workers_per_link, + bool persistent_fabric_mode, + IDevice* device, + const std::optional& sub_device_id, + const std::optional& reserved_core_range) { + std::tuple> result; + CoreRangeSet sender_worker_core_range; + if (persistent_fabric_mode) { + const size_t num_workers_preferred = num_workers_per_link * num_links; + auto available_cores = device->worker_cores( + tt::tt_metal::HalProgrammableCoreType::TENSIX, + sub_device_id.has_value() ? *sub_device_id : device->get_sub_device_ids().at(0)); + if (reserved_core_range.has_value()) { + available_cores = available_cores.subtract(*reserved_core_range); + } + if (available_cores.num_cores() < num_workers_preferred) { + log_warning( + tt::LogOp, + "AllGather is being launched on a subdevice with fewer worker cores available than ideal. Ideally {} " + "cores ({} per link and {} links) are made available but only {} are available. 
This may lead to " + "performance loss.", + num_workers_preferred, + num_workers_per_link, + num_links, + available_cores.num_cores()); + } + for (const auto& cr : available_cores.ranges()) { + auto start = cr.start_coord; + auto end = cr.end_coord; + for (size_t y = start.y; y <= end.y; y++) { + for (size_t x = start.x; x <= end.x; x++) { + sender_worker_core_range = + sender_worker_core_range.merge(CoreRangeSet(CoreRange(CoreCoord(x, y), CoreCoord(x, y)))); + if (sender_worker_core_range.num_cores() == num_workers_preferred) { + break; + } + } + if (sender_worker_core_range.num_cores() == num_workers_preferred) { + break; + } + } + if (sender_worker_core_range.num_cores() == num_workers_preferred) { + break; + } + } + } else { + sender_worker_core_range = + CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(num_workers_per_link - 1, num_links - 1))); + } + return {sender_worker_core_range, corerange_to_cores(sender_worker_core_range, std::nullopt, true)}; +} + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp new file mode 100644 index 00000000000..a826ee761b3 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp @@ -0,0 +1,141 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/ccl/ccl_op_fusion.hpp" +#include +#include "cpp/ttnn/global_semaphore.hpp" + +#include "ttnn/run_operation.hpp" + +#include +#include + +namespace ttnn { + +using ccl::EriscDatamoverBuilder; + +struct AllReduceAsync { + std::optional forward_device; + std::optional backward_device; + const uint32_t num_links; + const uint32_t ring_size; + const uint32_t ring_index; + const MemoryConfig output_mem_config; + const ccl::Topology topology; + const GlobalSemaphore semaphore; + std::optional sub_device_id; + bool enable_persistent_fabric_mode; + + AllReduceAsync( + std::optional forward_device, + std::optional backward_device, + uint32_t num_links, + uint32_t ring_size, + uint32_t ring_index, + MemoryConfig output_mem_config, + ccl::Topology topology, + GlobalSemaphore semaphore, + std::optional& sub_device_id, + bool enable_persistent_fabric_mode) : + forward_device(forward_device), + backward_device(backward_device), + num_links(num_links), + ring_size(ring_size), + ring_index(ring_index), + output_mem_config(output_mem_config), + topology(topology), + semaphore(semaphore), + sub_device_id(sub_device_id), + enable_persistent_fabric_mode(enable_persistent_fabric_mode) {} + + // Add attributes method for reflection + auto attributes() const { + using tt::stl::reflection::Attribute; + std::vector> attrs; + + attrs.emplace_back("num_links", num_links); + attrs.emplace_back("ring_size", ring_size); + attrs.emplace_back("ring_index", ring_index); + attrs.emplace_back("output_mem_config", output_mem_config); + attrs.emplace_back("topology", topology); + attrs.emplace_back("semaphore", semaphore); + + return attrs; + } + + void validate(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; + 
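> Editor's note: `choose_worker_cores` above walks the sub-device's Tensix worker cores, skips any reserved range (here the output-tensor cores, so senders never overlap the reduction cores), collects cores one at a time until it has `num_workers_per_link * num_links` of them, and warns when the grid cannot supply that many. A simplified host-side sketch of the same greedy selection, with `(x, y)` pairs standing in for `CoreCoord`/`CoreRangeSet`:

```cpp
#include <cstddef>
#include <iostream>
#include <set>
#include <utility>
#include <vector>

using Core = std::pair<int, int>;  // (x, y), stand-in for CoreCoord

// Greedily pick `needed` worker cores from the available grid, skipping any
// cores reserved for other purposes (e.g. output-tensor shards), and warn if
// the grid cannot supply the preferred number.
std::vector<Core> choose_worker_cores_sketch(
    const std::vector<Core>& available, const std::set<Core>& reserved, std::size_t needed) {
    std::vector<Core> picked;
    for (const Core& c : available) {
        if (picked.size() == needed) break;
        if (reserved.count(c)) continue;  // don't overlap the reduction/output cores
        picked.push_back(c);
    }
    if (picked.size() < needed) {
        std::cerr << "warning: only " << picked.size() << " of " << needed
                  << " preferred worker cores are available; expect performance loss\n";
    }
    return picked;
}

int main() {
    std::vector<Core> grid;
    for (int y = 0; y < 2; ++y)
        for (int x = 0; x < 4; ++x) grid.push_back({x, y});
    std::set<Core> reserved = {{0, 0}, {1, 0}};  // pretend these hold output shards
    auto workers = choose_worker_cores_sketch(grid, reserved, 4);  // e.g. 4 links x 1 worker
    for (auto [x, y] : workers) std::cout << "(" << x << "," << y << ") ";
    std::cout << "\n";
    return 0;
}
```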
tt::tt_metal::operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, std::vector& output_tensors) const; + const tt::tt_metal::operation::Hash compute_program_hash(const std::vector& input_tensors) const; +}; + +namespace ccl { +namespace all_reduce_async_detail { +AllReduceAsync create_all_reduce_async_struct( + const Tensor& input_tensor, + const uint32_t num_links, + const std::optional& memory_config, + const std::vector& devices, + const ccl::Topology topology, + const std::vector& semaphores, + std::optional sub_device_id, + bool enable_persistent_fabric_mode); + +} // namespace all_reduce_async_detail +} // namespace ccl + +std::tuple> choose_worker_cores( + size_t num_links, + size_t num_workers_per_link, + bool persistent_fabric_mode, + IDevice* device, + const std::optional& sub_device_id, + const std::optional& reserved_core_range = std::nullopt); + +tt::tt_metal::operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers( + const Tensor& input_tensor, + const Tensor& buffer_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + ccl::Topology topology, + const GlobalSemaphore& semaphore, + const std::optional& sub_device_id, + bool enable_persistent_fabric_mode); + +namespace operations { +namespace experimental { +namespace ccl { + +Tensor all_reduce_async( + const Tensor& input_tensor, + Tensor& buffer_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const std::optional& memory_config = std::nullopt, + const std::optional num_preferred_links = std::nullopt, + std::optional sub_device_id = std::nullopt, + bool enable_persistent_fabric_mode = false); + +} // namespace ccl +} // namespace experimental +} // namespace operations + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp new file mode 100644 index 00000000000..54372f82bc3 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -0,0 +1,493 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +/// +#include + +#include +#include +#include +#include "ttnn/tensor/tensor_impl.hpp" +#include "all_reduce_async_op.hpp" +#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/math.hpp" +#include +#include +#include +#include +#include "cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" +#include "cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" + +#include "cpp/ttnn/operations/ccl/common/uops/command_lowering.hpp" + +#include "cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" +#include "cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp" +#include +#include +#include +#include +using namespace tt::constants; + +namespace ttnn { + +using namespace ccl; + +CoreRangeSet cores_to_corerangeset(const std::vector& cores) { + std::vector core_ranges; + for 
(const auto& core : cores) { + core_ranges.push_back(CoreRange(core)); + } + return CoreRangeSet(core_ranges); +} + +tt::tt_metal::operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers( + const Tensor& input_tensor, + const Tensor& buffer_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + ccl::Topology topology, + const GlobalSemaphore& semaphore, + const std::optional& sub_device_id, + bool enable_persistent_fabric_mode) { + tt::tt_metal::Program program{}; + + IDevice* device = input_tensor.device(); + bool is_first_chip = ring_index == 0; + bool is_last_chip = ring_index == ring_size - 1; + log_trace( + tt::LogOp, + "DEBUG: device: {}, is_first_chip: {}, is_last_chip: {}", + input_tensor.device()->id(), + is_first_chip, + is_last_chip); + + std::optional local_fabric_handle = + ttnn::ccl::EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( + device, + forward_device.value_or(nullptr), + backward_device.value_or(nullptr), + &program, + enable_persistent_fabric_mode, + num_links); + + // Get OP Config, topology config + std::vector input_tensors = {input_tensor}; + std::vector output_tensors = {output_tensor}; + const auto& op_config = ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); + LineTopology line_topology(ring_size, ring_index); + const size_t num_targets_forward = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD); + const size_t num_targets_backward = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD); + // Tensor Info + const auto input_tensor_num_pages = input_tensor.buffer()->num_pages(); + const auto input_tensor_cores = input_tensor.memory_config().shard_spec->grid; + const auto input_tensor_shard_shape = input_tensor.memory_config().shard_spec->shape; + const auto input_tensor_shard_num_pages = input_tensor_shard_shape[0] * input_tensor_shard_shape[1] / TILE_HW; + const auto num_input_cores = input_tensor_cores.num_cores(); + const auto output_tensor_num_pages = output_tensor.buffer()->num_pages(); + const auto output_tensor_cores = output_tensor.memory_config().shard_spec->grid; + const auto output_tensor_shard_shape = output_tensor.memory_config().shard_spec->shape; + const auto output_tensor_shard_num_pages = output_tensor_shard_shape[0] * output_tensor_shard_shape[1] / TILE_HW; + const auto num_output_cores = output_tensor_cores.num_cores(); + + // Get worker cores, assuming 1 worker per link + std::optional reserved_cores = output_tensor_cores; + uint32_t num_workers_per_link = 1; + const auto [sender_worker_core_range, sender_worker_cores] = choose_worker_cores( + num_links, num_workers_per_link, enable_persistent_fabric_mode, device, sub_device_id, reserved_cores); + + tt::log_debug(tt::LogOp, "input_tensor_num_pages: {}", input_tensor_num_pages); + tt::log_debug(tt::LogOp, "input_tensor_cores: {}", input_tensor_cores); + tt::log_debug(tt::LogOp, "input_tensor_shard_shape: {}", input_tensor_shard_shape); + tt::log_debug(tt::LogOp, "input_tensor_shard_num_pages: {}", input_tensor_shard_num_pages); + tt::log_debug(tt::LogOp, "output_tensor_cores: {}", output_tensor_cores); + tt::log_debug(tt::LogOp, "output_tensor_shard_shape: {}", output_tensor_shard_shape); + tt::log_debug(tt::LogOp, "output_tensor_shard_num_pages: {}", output_tensor_shard_num_pages); + + // L1 Scratch CB 
Creation + const size_t packet_size_bytes = local_fabric_handle->get_edm_buffer_size_bytes(); + uint32_t l1_scratch_cb_page_size_bytes = op_config.get_page_size(); + uint32_t num_pages_per_packet = packet_size_bytes / l1_scratch_cb_page_size_bytes; + uint32_t cb_num_pages = input_tensor_num_pages; // TODO: Reduce this to double-buffer packet-size? + uint32_t src0_cb_index = tt::CBIndex::c_0; + tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_num_pages * l1_scratch_cb_page_size_bytes, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, l1_scratch_cb_page_size_bytes); + CBHandle cb_src0_workers = CreateCircularBuffer(program, sender_worker_core_range, cb_src0_config); + // Set aside a buffer we can use for storing packet headers in (particularly for atomic incs) + const auto reserved_packet_header_CB_index = tt::CBIndex::c_3; + static constexpr auto num_packet_headers_storable = 8; + static constexpr auto packet_header_size_bytes = sizeof(tt::fabric::PacketHeader); + tt::tt_metal::CircularBufferConfig cb_reserved_packet_header_config = + tt::tt_metal::CircularBufferConfig( + num_packet_headers_storable * packet_header_size_bytes * 2, + {{reserved_packet_header_CB_index, tt::DataFormat::RawUInt32}}) + .set_page_size(reserved_packet_header_CB_index, packet_header_size_bytes); + auto reserved_packet_header_CB_handle = + CreateCircularBuffer(program, sender_worker_core_range, cb_reserved_packet_header_config); + + // Reduction kernel setup + auto all_cores = output_tensor_cores.merge(sender_worker_core_range); + auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true); + auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true); + + // Create output tensor splits + // TODO: Currently does not support output shards being split across multiple links + std::vector output_corerangeset_per_link; + std::vector num_output_cores_in_link(num_links, 0); + uint32_t output_cores_per_link = tt::div_up(output_tensor_cores.num_cores(), num_links); + uint32_t num_assigned_cores = 0; + for (uint32_t link = 0; link < num_links; link++) { + uint32_t num_cores_this_link = std::min(output_cores_per_link, num_output_cores - num_assigned_cores); + output_corerangeset_per_link.emplace_back( + cores_to_corerangeset(std::vector( + output_cores_vec.begin() + num_assigned_cores, + output_cores_vec.begin() + num_assigned_cores + num_cores_this_link)) + .merge_ranges()); + num_output_cores_in_link[link] = num_cores_this_link; + num_assigned_cores += num_cores_this_link; + } + + // Create output tensor page splits + std::vector output_tensor_pages_in_link(num_links, 0); + uint32_t num_assigned_pages = 0; + for (uint32_t link = 0; link < num_links; link++) { + uint32_t num_output_pages_per_link = output_tensor_shard_num_pages * num_output_cores_in_link[link]; + uint32_t num_pages_this_link = + std::min(num_output_pages_per_link, output_tensor_num_pages - num_assigned_pages); + output_tensor_pages_in_link[link] = num_pages_this_link; + num_assigned_pages += num_pages_this_link; + } + + // Create input tensor splits + /* + Overview of algorithm: + + - Ouput: each link gets assigned a start and end core index, since multiple links + may have to read different offesets within a shard on the same core + - First, assign all the necessary cores needed for a link. This may result in the link + containing extra pages. 
This will result in an overflow, which is used to detect + the tile offset (within a shard) for the next link + - Once you have the start_core_idx, the end_core_idx is calculated by + getting the upper bound on the number of cores needed to read the pages assigned + to the link, accounting for the tile offset. This calculation is done by dividing + the upper bound on the number of pages assigned to this link + (num_pages_this_link + input_tensor_tile_offset) by the number of pages in a shard. + This gives the number of cores needed to read the pages assigned to this link. + - If an overflow is detected, then the start_core_idx for the next link is set + to the end_core_idx of the current link. Ie, 2 links read from the same core + */ + std::vector> input_cores_idx_per_link(num_links, {0, 0}); + std::vector input_tensor_tile_offset_per_link(num_links, 0); + uint32_t start_core_idx = 0; + uint32_t num_pages_overflow = 0; + for (uint32_t link = 0; link < num_links; link++) { + uint32_t num_pages_this_link = output_tensor_pages_in_link[link]; + + // Get offset based on previous overflow + uint32_t input_tensor_tile_offset = + (input_tensor_shard_num_pages - num_pages_overflow) % input_tensor_shard_num_pages; + input_tensor_tile_offset_per_link[link] = input_tensor_tile_offset; + + uint32_t end_core_idx = std::min( + start_core_idx + tt::div_up(num_pages_this_link + input_tensor_tile_offset, input_tensor_shard_num_pages), + num_input_cores); + + // Num pages allocated based on number of input cores selected for this link + uint32_t num_pages_allocated = + (end_core_idx - start_core_idx) * input_tensor_shard_num_pages - input_tensor_tile_offset; + + // Update overflow + num_pages_overflow = num_pages_allocated - num_pages_this_link; + + // Store core indices + input_cores_idx_per_link[link] = {start_core_idx, end_core_idx}; + + // Set start index based on overflow + if (num_pages_overflow > 0) { + start_core_idx = end_core_idx - 1; + } else { + start_core_idx = end_core_idx; + } + } + + // Create reduction semaphores for each link + std::vector reduction_semaphore_ids(num_links, 0); + for (uint32_t link = 0; link < num_links; link++) { + reduction_semaphore_ids[link] = tt::tt_metal::CreateSemaphore(program, all_cores, 0); + } + + /* reduction cb */ + uint32_t reduction_CB_single_tile_size = output_tensor.get_tensor_spec().tile().get_tile_size(df); + uint32_t reduction_CB_tiles = output_tensor_shard_num_pages * ring_size; + uint32_t reduction_CB_size = reduction_CB_tiles * reduction_CB_single_tile_size; + + uint32_t reduction_cb_index = tt::CBIndex::c_1; + tt::tt_metal::CircularBufferConfig reduction_cb_config = + tt::tt_metal::CircularBufferConfig(reduction_CB_size, {{reduction_cb_index, df}}) + .set_page_size(reduction_cb_index, reduction_CB_single_tile_size) + .set_globally_allocated_address(*buffer_tensor.buffer()); + auto cb_reduction = tt::tt_metal::CreateCircularBuffer(program, all_cores, reduction_cb_config); + + /* out cb */ + uint32_t out_CB_single_tile_size = output_tensor.get_tensor_spec().tile().get_tile_size(df); + uint32_t out_CB_tiles = output_tensor_shard_num_pages; + uint32_t out_CB_size = out_CB_tiles * out_CB_single_tile_size; + + uint32_t out_cb_index = tt::CBIndex::c_2; + tt::tt_metal::CircularBufferConfig out_cb_config = + tt::tt_metal::CircularBufferConfig(out_CB_size, {{out_cb_index, df}}) + .set_page_size(out_cb_index, out_CB_single_tile_size) + .set_globally_allocated_address(*output_tensor.buffer()); // TODO: Remove once new cb attached for output + auto cb_out = 
tt::tt_metal::CreateCircularBuffer( + program, output_tensor_cores, out_cb_config); // TODO: This should be the output cores instead + + // Create reduction dataflow kernel + auto reduction_reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; + reduction_reader_kernel_config.compile_args = { + reduction_cb_index, // reduction_cb_index + reduction_CB_tiles, // total_num_reduction_tiles + }; + auto reduction_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/" + "reduction_receiver.cpp", + output_tensor_cores, + reduction_reader_kernel_config); + + // Create reduction dataflow kernel + auto reduction_kernel_config = tt::tt_metal::ComputeConfig{}; + reduction_kernel_config.compile_args = { + reduction_cb_index, // reduction_cb_index + out_cb_index, // out_cb_index + }; + auto reduction_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/" + "reduction.cpp", + output_tensor_cores, + reduction_kernel_config); + std::vector reduction_kernel_rt_args = { + ring_size, // num_blocks + output_tensor_shard_num_pages, // block_num_tiles + }; + tt::tt_metal::SetRuntimeArgs(program, reduction_kernel_id, output_tensor_cores, reduction_kernel_rt_args); + + // KERNEL CREATION + tt::tt_metal::NOC reader_noc = NOC::NOC_1; + tt::tt_metal::NOC writer_noc = NOC::NOC_0; + // Reader + std::vector reader_compile_args = { + ring_index, // my_chip_id + src0_cb_index, // cb0_id + op_config.get_page_size(), // tensor0_page_size + }; + log_trace(tt::LogOp, "Reader Compile Args:"); + auto worker_sender_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/" + "worker_reader.cpp", + sender_worker_core_range, + DataMovementConfig{ + .processor = DataMovementProcessor::RISCV_1, .noc = reader_noc, .compile_args = reader_compile_args}); + + // Writer + std::vector writer_compile_args = { + ring_index, // my_chip_id + reserved_packet_header_CB_index, // reserved_packet_header_cb_id + num_packet_headers_storable, // num_packet_headers_storable + src0_cb_index, // cb0_id + num_pages_per_packet, // packet_size_in_pages + op_config.get_page_size(), // tensor0_page_size + num_targets_forward, // num_targets_forward_direction + num_targets_backward, // num_targets_backward_direction + }; + log_trace(tt::LogOp, "Writer Compile Args:"); + auto worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/" + "worker_writer.cpp", + sender_worker_core_range, + DataMovementConfig{ + .processor = DataMovementProcessor::RISCV_0, .noc = writer_noc, .compile_args = writer_compile_args}); + + // Kernel Runtime Args + for (uint32_t link = 0; link < num_links; link++) { + CoreCoord core = sender_worker_cores[link]; + CoreCoord drain_sync_core = device->worker_core_from_logical_core(core); + uint32_t worker_num_tiles_to_read = output_tensor_pages_in_link[link]; + + uint32_t input_first_core_tile_start_offset = input_tensor_tile_offset_per_link[link]; + uint32_t output_first_core_tile_start_offset = 0; + + std::vector input_tensor_cores_x; + std::vector input_tensor_cores_y; + std::vector output_tensor_cores_x; + std::vector output_tensor_cores_y; + for (uint32_t i = input_cores_idx_per_link[link].first; i < input_cores_idx_per_link[link].second; i++) { + auto this_core = 
device->worker_core_from_logical_core(input_cores_vec[i]); + input_tensor_cores_x.push_back(this_core.x); + input_tensor_cores_y.push_back(this_core.y); + } + for (uint32_t i = output_cores_per_link * link; + i < output_cores_per_link * link + num_output_cores_in_link[link]; + i++) { + auto this_core = device->worker_core_from_logical_core(output_cores_vec[i]); + output_tensor_cores_x.push_back(this_core.x); + output_tensor_cores_y.push_back(this_core.y); + } + + std::optional forward_fabric_connection = + line_topology.is_first_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::FORWARD)); + std::optional backward_fabric_connection = + line_topology.is_last_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::BACKWARD)); + + // Set reader runtime args + std::vector reader_rt_args = { + input_tensor.buffer()->address(), // tensor_address0 + input_tensor_shard_num_pages, // num_tiles_per_core + worker_num_tiles_to_read, // num_tiles_to_read + input_first_core_tile_start_offset, // first_core_tile_start_offset + input_tensor_cores_x.size(), // num_cores + }; + reader_rt_args.insert(reader_rt_args.end(), input_tensor_cores_x.begin(), input_tensor_cores_x.end()); + reader_rt_args.insert(reader_rt_args.end(), input_tensor_cores_y.begin(), input_tensor_cores_y.end()); + log_trace(tt::LogOp, "Reader Runtime Args:"); + for (const auto& arg : reader_rt_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + tt::tt_metal::SetRuntimeArgs(program, worker_sender_reader_kernel_id, {core}, reader_rt_args); + + // Set writer runtime args + std::vector mcast_start_x; + std::vector mcast_start_y; + std::vector mcast_end_x; + std::vector mcast_end_y; + + uint32_t num_mcast_cores = 0; + for (const auto& range : output_corerangeset_per_link[link].ranges()) { + auto start_core = device->worker_core_from_logical_core(range.start_coord); + auto end_core = device->worker_core_from_logical_core(range.end_coord); + num_mcast_cores += (end_core.x - start_core.x + 1) * (end_core.y - start_core.y + 1); + bool mcast_range_contains_self = + start_core.x <= core.x && core.x <= end_core.x && start_core.y <= core.y && core.y <= end_core.y; + if (mcast_range_contains_self) { + num_mcast_cores -= 1; + } + if (writer_noc == NOC::NOC_1) { + std::swap(start_core, end_core); + } + mcast_start_x.push_back(start_core.x); + mcast_start_y.push_back(start_core.y); + mcast_end_x.push_back(end_core.x); + mcast_end_y.push_back(end_core.y); + } + + uint32_t out_ready_sem_wait_value = ring_size; + std::vector writer_rt_args = { + reduction_cb_index, // tensor_address0 + semaphore.address(), // out_ready_sem_bank_addr (absolute address) + output_tensor_shard_num_pages, // num_tiles_per_core + worker_num_tiles_to_read, // num_tiles_to_read + output_first_core_tile_start_offset, // first_core_tile_start_offset + output_tensor_cores_x.size(), // num_cores + num_mcast_cores, // num_mcast_cores + drain_sync_core.x, // out_ready_sem_noc0_x + drain_sync_core.y, // out_ready_sem_noc0_y + out_ready_sem_wait_value, // out_ready_sem_wait_value + reduction_semaphore_ids[link], // reduction_semaphore_id + mcast_start_x.size(), // num_mcast_ranges + link, // link + }; + writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_x.begin(), 
output_tensor_cores_x.end()); + writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_y.begin(), output_tensor_cores_y.end()); + + writer_rt_args.insert(writer_rt_args.end(), mcast_start_x.begin(), mcast_start_x.end()); + writer_rt_args.insert(writer_rt_args.end(), mcast_start_y.begin(), mcast_start_y.end()); + writer_rt_args.insert(writer_rt_args.end(), mcast_end_x.begin(), mcast_end_x.end()); + writer_rt_args.insert(writer_rt_args.end(), mcast_end_y.begin(), mcast_end_y.end()); + + log_trace(tt::LogOp, "Writer Runtime Args:"); + for (const auto& arg : writer_rt_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + writer_rt_args.push_back(forward_fabric_connection.has_value()); + if (forward_fabric_connection.has_value()) { + auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_teardown_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, {core}, 0); + append_worker_to_fabric_edm_sender_rt_args( + forward_fabric_connection.value(), + sender_worker_flow_control_semaphore_id, + sender_worker_teardown_semaphore_id, + sender_worker_buffer_index_semaphore_id, + writer_rt_args); + } + writer_rt_args.push_back(backward_fabric_connection.has_value()); + if (backward_fabric_connection.has_value()) { + auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_teardown_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, {core}, 0); + append_worker_to_fabric_edm_sender_rt_args( + backward_fabric_connection.value(), + sender_worker_flow_control_semaphore_id, + sender_worker_teardown_semaphore_id, + sender_worker_buffer_index_semaphore_id, + writer_rt_args); + } + tt::tt_metal::SetRuntimeArgs(program, worker_sender_writer_kernel_id, {core}, writer_rt_args); + + // Set reduction worker runtime args + std::vector reduction_reader_rt_args = { + reduction_semaphore_ids[link], // reduction_semaphore_id + }; + tt::tt_metal::SetRuntimeArgs( + program, reduction_reader_kernel_id, output_corerangeset_per_link[link], reduction_reader_rt_args); + } + + auto override_runtime_arguments_callback = + [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, sender_worker_cores, cb_out, cb_reduction]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + const auto& input = input_tensors[0]; + const auto& output = output_tensors[0]; + const auto& buffer_tensor = input_tensors[1]; + + auto semaphore = static_cast(operation)->semaphore; + + // update senders + auto& worker_reader_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_reader_kernel_id); + auto& worker_writer_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_writer_kernel_id); + for (const auto& core : sender_worker_cores) { + // reader + auto& worker_reader_sender_runtime_args = worker_reader_sender_runtime_args_by_core[core.x][core.y]; + worker_reader_sender_runtime_args[0] = input.buffer()->address(); + // writer + auto& worker_writer_sender_runtime_args = worker_writer_sender_runtime_args_by_core[core.x][core.y]; + worker_writer_sender_runtime_args[1] = semaphore.address(); + } + UpdateDynamicCircularBufferAddress(program, cb_out, *output.buffer()); + UpdateDynamicCircularBufferAddress(program, cb_reduction, *buffer_tensor.buffer()); + }; + + return 
{.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp new file mode 100644 index 00000000000..67336cf3288 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "compute_kernel_api/eltwise_binary.h" + +namespace NAMESPACE { +void MAIN { + constexpr uint32_t cb_in0 = get_compile_time_arg_val(0); + constexpr uint32_t cb_out0 = get_compile_time_arg_val(1); + constexpr uint32_t cb_in1 = cb_in0; + + uint32_t rt_args_idx = 0; + const uint32_t num_blocks = get_arg_val(rt_args_idx++); + const uint32_t block_num_tiles = get_arg_val(rt_args_idx++); + const uint32_t copy_first_block = num_blocks % 2 != 0; + + constexpr uint32_t max_dst_tiles = 8; // TODO: Make general + + cb_wait_front(cb_in0, num_blocks * block_num_tiles); + cb_reserve_back(cb_out0, block_num_tiles); + + binary_op_init_common(cb_in0, cb_in1, cb_out0); + add_tiles_init(cb_in0, cb_in1, true); + + uint32_t num_pack_iters = (block_num_tiles + max_dst_tiles - 1) / max_dst_tiles; + uint32_t block_num_tiles_cnt = 0; + + for (uint32_t p = 0; p < num_pack_iters; ++p) { + uint32_t num_tiles_to_pack = std::min(max_dst_tiles, block_num_tiles - block_num_tiles_cnt); + tile_regs_acquire(); + for (uint32_t block = 0; block < num_blocks; block += 2) { + if (copy_first_block && block == 0) { + // TODO: Future support + } else { + for (uint32_t i = 0; i < num_tiles_to_pack; ++i) { + add_tiles( + cb_in0, + cb_in1, + block * block_num_tiles + p * max_dst_tiles + i, + (block + 1) * block_num_tiles + p * max_dst_tiles + i, + i); + } + } + } + tile_regs_commit(); + + // Pack output tiles + tile_regs_wait(); + for (uint32_t i = 0; i < num_tiles_to_pack; ++i) { + pack_tile(i, cb_out0, p * max_dst_tiles + i); + } + tile_regs_release(); + + block_num_tiles_cnt += num_tiles_to_pack; + } + + cb_push_back(cb_out0, block_num_tiles); +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/reduction_receiver.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/reduction_receiver.cpp new file mode 100644 index 00000000000..ca63befeea9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/reduction_receiver.cpp @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" + +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + constexpr uint32_t cb_id = get_compile_time_arg_val(0); + constexpr uint32_t total_num_reduction_tiles = get_compile_time_arg_val(1); + + // runtime args + size_t arg_idx = 0; + const uint32_t signal_semaphore_addr = get_semaphore(get_arg_val(arg_idx++)); + + volatile tt_l1_ptr uint32_t* signal_semaphore_addr_ptr = + reinterpret_cast(signal_semaphore_addr); + + // 1. Wait for signal from All-Gather worker + noc_semaphore_wait(signal_semaphore_addr_ptr, VALID); + noc_semaphore_set(signal_semaphore_addr_ptr, 0); + + // 2. 
Signal compute kernel to start processing + cb_push_back(cb_id, total_num_reduction_tiles); +} diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp new file mode 100644 index 00000000000..3caa208cfd4 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include +#include +#include + +using address_t = uint32_t; + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +constexpr uint32_t my_chip_id = get_compile_time_arg_val(0); +constexpr uint32_t cb0_id = get_compile_time_arg_val(1); +constexpr uint32_t tensor0_page_size = get_compile_time_arg_val(2); + +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + size_t arg_idx = 0; + // Load the input tensor spec + address_t tensor_address0 = get_arg_val(arg_idx++); + uint32_t num_tiles_per_core = get_arg_val(arg_idx++); + uint32_t num_tiles_to_read = get_arg_val(arg_idx++); + uint32_t first_core_tile_start_offset = get_arg_val(arg_idx++); + uint32_t num_cores = get_arg_val(arg_idx++); + tt_l1_ptr uint32_t* core_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + tt_l1_ptr uint32_t* core_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + + // interleaved addrgen + uint32_t tiles_read = 0; + uint32_t shard_tile_id = first_core_tile_start_offset; + uint32_t core_id = 0; + while (tiles_read < num_tiles_to_read) { + uint32_t num_tiles_to_read_this_core = + std::min(num_tiles_per_core - shard_tile_id, num_tiles_to_read - tiles_read); + cb_reserve_back(cb0_id, num_tiles_to_read_this_core); + const uint32_t l1_write_addr = get_write_ptr(cb0_id); + uint64_t read_addr = get_noc_addr(core_noc_x[core_id], core_noc_y[core_id], tensor_address0); + read_addr += shard_tile_id * tensor0_page_size; + + noc_async_read(read_addr, l1_write_addr, num_tiles_to_read_this_core * tensor0_page_size); + noc_async_read_barrier(); + + cb_push_back(cb0_id, num_tiles_to_read_this_core); + tiles_read += num_tiles_to_read_this_core; + shard_tile_id = 0; + core_id++; + } +} diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp new file mode 100644 index 00000000000..29193212824 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp @@ -0,0 +1,194 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" +#include "cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_ccl_common.hpp" +#include +#include + +using address_t = uint32_t; + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + 
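> Editor's note: `worker_reader.cpp` above walks the width-sharded input one core at a time: it reserves CB space for however many tiles remain on the current shard (or remain to be read overall), issues one contiguous `noc_async_read` for that span, pushes the tiles, and moves to the next core with a zero in-shard offset. A host-side sketch of the same iteration, with the NOC read replaced by a vector copy from per-core shard buffers (the tile IDs, shard sizes, and offsets are illustrative values only):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the sharded-reader loop: read whole (or partial) shards core by
// core into a staging buffer, starting mid-shard only on the first core.
void read_tiles_sketch(
    const std::vector<std::vector<uint32_t>>& shards,  // one tile-id vector per input core
    uint32_t tiles_per_core,
    uint32_t tiles_to_read,
    uint32_t first_core_tile_offset,
    std::vector<uint32_t>& staging) {
    uint32_t tiles_read = 0;
    uint32_t shard_tile_id = first_core_tile_offset;
    uint32_t core_id = 0;
    while (tiles_read < tiles_to_read) {
        uint32_t n = std::min(tiles_per_core - shard_tile_id, tiles_to_read - tiles_read);
        // stands in for: cb_reserve_back + noc_async_read(read_addr, l1_write_addr, n * page_size)
        staging.insert(staging.end(),
                       shards[core_id].begin() + shard_tile_id,
                       shards[core_id].begin() + shard_tile_id + n);
        tiles_read += n;
        shard_tile_id = 0;  // only the first core starts at an offset
        ++core_id;
    }
}

int main() {
    // 3 input cores x 4 tiles per shard; this link starts 2 tiles into core 0
    std::vector<std::vector<uint32_t>> shards = {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}};
    std::vector<uint32_t> staging;
    read_tiles_sketch(shards, /*tiles_per_core=*/4, /*tiles_to_read=*/7,
                      /*first_core_tile_offset=*/2, staging);
    for (uint32_t t : staging) std::cout << t << ' ';  // prints: 2 3 4 5 6 7 8
    std::cout << '\n';
    return 0;
}
```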
+constexpr uint32_t my_chip_id = get_compile_time_arg_val(0); +constexpr uint32_t reserved_packet_header_cb_id = get_compile_time_arg_val(1); +constexpr uint32_t num_packet_headers_storable = get_compile_time_arg_val(2); +constexpr uint32_t cb0_id = get_compile_time_arg_val(3); +constexpr uint32_t packet_size_in_pages = get_compile_time_arg_val(4); +constexpr uint32_t tensor0_page_size = get_compile_time_arg_val(5); +constexpr uint32_t num_targets_forward_direction = get_compile_time_arg_val(6); +constexpr uint32_t num_targets_backward_direction = get_compile_time_arg_val(7); + +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + size_t arg_idx = 0; + // Load the input tensor spec + uint32_t reduction_input_cb_id = get_arg_val(arg_idx++); + address_t reduction_input_addr = get_write_ptr(reduction_input_cb_id); + + const size_t out_ready_sem_bank_addr = get_arg_val(arg_idx++); + uint32_t num_tiles_per_core = get_arg_val(arg_idx++); + uint32_t num_tiles_to_read = get_arg_val(arg_idx++); + uint32_t first_core_tile_start_offset = get_arg_val(arg_idx++); + uint32_t num_cores = get_arg_val(arg_idx++); + uint32_t num_mcast_cores = get_arg_val(arg_idx++); + const uint8_t out_ready_sem_noc0_x = get_arg_val(arg_idx++); + const uint8_t out_ready_sem_noc0_y = get_arg_val(arg_idx++); + uint32_t out_ready_sem_wait_value = get_arg_val(arg_idx++); + const uint32_t reduction_semaphore_send_addr = get_semaphore(get_arg_val(arg_idx++)); + const uint32_t num_mcast_ranges = get_arg_val(arg_idx++); + const uint32_t link = get_arg_val(arg_idx++); + + // Set up for mcasting to reduction workers + volatile tt_l1_ptr uint32_t* reduction_semaphore_send_addr_ptr = + reinterpret_cast(reduction_semaphore_send_addr); + noc_semaphore_set(reduction_semaphore_send_addr_ptr, VALID); + + tt_l1_ptr uint32_t* core_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + tt_l1_ptr uint32_t* core_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + + tt_l1_ptr uint32_t* mcast_dest_noc_start_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_mcast_ranges; + tt_l1_ptr uint32_t* mcast_dest_noc_start_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_mcast_ranges; + tt_l1_ptr uint32_t* mcast_dest_noc_end_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_mcast_ranges; + tt_l1_ptr uint32_t* mcast_dest_noc_end_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_mcast_ranges; + + size_t arg_for_fab = arg_idx; + auto fabric_connection = FabricConnectionManager::build_from_args(arg_idx); + + // packet header cb + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_addr_forward = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_addr_backward = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_seminc = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + + // pre-populate packet headers + volatile PACKET_HEADER_TYPE* pkt_hdr_forward = + reinterpret_cast(packet_header_buffer_addr_forward); + volatile PACKET_HEADER_TYPE* pkt_hdr_backward = + reinterpret_cast(packet_header_buffer_addr_backward); + 
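> Editor's note: the host-side `writer_rt_args.insert(...)` calls earlier in the program factory and the argument parsing in the writer kernel rely on the same flat layout — fixed scalars first, then variable-length coordinate lists whose lengths were passed as earlier scalars, consumed by advancing a running index. A small sketch of that contract with an ordinary `std::vector` standing in for the kernel's L1 argument block; the `WriterArgsView` struct, its field subset, and the sample values are made up for illustration and do not reproduce the kernel's exact argument order:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical view of a flat runtime-argument block: fixed scalars first,
// then variable-length core / mcast-range coordinate lists whose sizes are
// given by earlier scalars (mirrors the get_arg_val / get_arg_addr pattern).
struct WriterArgsView {
    uint32_t num_cores = 0;
    uint32_t num_mcast_ranges = 0;
    std::vector<uint32_t> core_x, core_y;
    std::vector<uint32_t> mcast_start_x, mcast_start_y, mcast_end_x, mcast_end_y;
};

WriterArgsView parse_writer_args_sketch(const std::vector<uint32_t>& args) {
    std::size_t idx = 0;
    auto scalar = [&] { return args[idx++]; };
    auto list = [&](uint32_t n) {
        std::vector<uint32_t> v(args.begin() + idx, args.begin() + idx + n);
        idx += n;  // advance past the list, just like arg_idx += num_cores in the kernel
        return v;
    };
    WriterArgsView v;
    v.num_cores = scalar();
    v.num_mcast_ranges = scalar();
    v.core_x = list(v.num_cores);
    v.core_y = list(v.num_cores);
    v.mcast_start_x = list(v.num_mcast_ranges);
    v.mcast_start_y = list(v.num_mcast_ranges);
    v.mcast_end_x = list(v.num_mcast_ranges);
    v.mcast_end_y = list(v.num_mcast_ranges);
    return v;
}

int main() {
    // 2 destination cores, 1 mcast range (values are illustrative only)
    std::vector<uint32_t> args = {2, 1, /*core_x*/ 1, 2, /*core_y*/ 0, 0,
                                  /*mcast start/end*/ 1, 0, 4, 3};
    auto v = parse_writer_args_sketch(args);
    std::cout << v.core_x[1] << ' ' << v.mcast_end_x[0] << '\n';  // prints: 2 4
    return 0;
}
```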
pkt_hdr_forward->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); + pkt_hdr_backward->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); + + if (fabric_connection.is_logically_connected()) { + fabric_connection.open(); + } + + // 1. mcast via fabric to remote tensor addresses + uint32_t tiles_read = 0; + uint32_t shard_tile_id = first_core_tile_start_offset; + uint32_t core_id = 0; + uint32_t writer_chip_offset = my_chip_id * num_tiles_per_core * tensor0_page_size; + + while (tiles_read < num_tiles_to_read) { + uint32_t num_tiles_to_read_this_core = std::min(num_tiles_per_core - shard_tile_id, packet_size_in_pages); + num_tiles_to_read_this_core = std::min(num_tiles_to_read - tiles_read, num_tiles_to_read_this_core); + cb_wait_front(cb0_id, num_tiles_to_read_this_core); + size_t l1_read_addr = get_read_ptr(cb0_id); + + uint64_t noc0_dest_noc_addr = + get_noc_addr(core_noc_x[core_id], core_noc_y[core_id], reduction_input_addr + writer_chip_offset); + + // Within-shard offset + noc0_dest_noc_addr += shard_tile_id * tensor0_page_size; + + write_and_advance_local_read_address_for_fabric_write( + noc0_dest_noc_addr, + pkt_hdr_forward, + pkt_hdr_backward, + fabric_connection, + l1_read_addr, + num_tiles_to_read_this_core * tensor0_page_size); + noc_async_writes_flushed(); + + cb_pop_front(cb0_id, num_tiles_to_read_this_core); + tiles_read += num_tiles_to_read_this_core; + shard_tile_id += num_tiles_to_read_this_core; + if (shard_tile_id >= num_tiles_per_core) { + shard_tile_id = 0; + core_id++; + } + } + + // 2. mcast output ready semaphore + auto* pkt_hdr = reinterpret_cast(packet_header_buffer_seminc); + uint64_t out_ready_sem_noc_addr_in_pkt = + safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr); + pkt_hdr->to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ + out_ready_sem_noc_addr_in_pkt, + static_cast(1), // increment 1 + 32}); + // Write the mcast packet (forward) + if (fabric_connection.has_forward_connection()) { + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); + fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + packet_header_buffer_seminc, sizeof(PACKET_HEADER_TYPE)); + } + // Write the mcast packet (backward) + if (fabric_connection.has_backward_connection()) { + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + fabric_connection.get_backward_connection().send_payload_flush_blocking_from_address( + packet_header_buffer_seminc, sizeof(PACKET_HEADER_TYPE)); + } + // increment locally + uint64_t out_ready_sem_noc_addr = + safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr); + noc_semaphore_inc(out_ready_sem_noc_addr, 1); + + // 3. 
wait for mcast output ready semaphore
+    while (*reinterpret_cast<volatile uint32_t*>(out_ready_sem_bank_addr) != out_ready_sem_wait_value);
+
+    // loop over mcast ranges
+    for (uint32_t i = 0; i < num_mcast_ranges; i++) {
+        // Signal the reduction workers
+        const uint64_t reduction_semaphore_recv_noc_addr = get_noc_multicast_addr(
+            mcast_dest_noc_start_x[i],
+            mcast_dest_noc_start_y[i],
+            mcast_dest_noc_end_x[i],
+            mcast_dest_noc_end_y[i],
+            reduction_semaphore_send_addr);
+
+        noc_semaphore_set_multicast(
+            reduction_semaphore_send_addr,
+            reduction_semaphore_recv_noc_addr,
+            i == 0 ? num_mcast_cores : 0,
+            false,  // linked = false
+            true);  // multicast_path_reserve = true
+    }
+
+    // 4. global semaphore reset
+    *reinterpret_cast<volatile uint32_t*>(out_ready_sem_bank_addr) = 0;
+
+    if (fabric_connection.is_logically_connected()) {
+        fabric_connection.close();
+    }
+
+    noc_async_write_barrier();
+
+    // DPRINT << "writer done \n";
+}
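> Editor's note, tying the writer and the reduction compute kernel together: every device deposits its shard at offset `my_chip_id * num_tiles_per_core * tensor0_page_size` in each destination core's reduction buffer, and once the out-ready semaphore has been incremented `ring_size` times the reduction cores are released and `reduction.cpp` sums the `ring_size` blocks element-wise into the output shard. A host-side sketch of that final reduction over a plain float buffer (a "tile" is collapsed to one float here, and the sketch sums all blocks directly, whereas the kernel processes them in pairs through the DST registers and currently requires an even block count, as checked in `validate`):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the reduction performed by reduction.cpp: the buffer holds
// ring_size blocks of block_num_tiles "tiles" laid out back to back, and the
// output shard is their element-wise sum.
std::vector<float> reduce_blocks_sketch(
    const std::vector<float>& reduction_buffer, uint32_t ring_size, uint32_t block_num_tiles) {
    std::vector<float> out(block_num_tiles, 0.0f);
    for (uint32_t block = 0; block < ring_size; ++block) {
        for (uint32_t i = 0; i < block_num_tiles; ++i) {
            out[i] += reduction_buffer[block * block_num_tiles + i];  // add_tiles equivalent
        }
    }
    return out;
}

int main() {
    const uint32_t ring_size = 4, block_num_tiles = 2;
    // Pretend each "device" contributed a block equal to its chip id + 1.
    std::vector<float> buffer;
    for (uint32_t chip = 0; chip < ring_size; ++chip)
        buffer.insert(buffer.end(), block_num_tiles, float(chip + 1));
    for (float v : reduce_blocks_sketch(buffer, ring_size, block_num_tiles))
        std::cout << v << ' ';  // prints: 10 10
    std::cout << '\n';
    return 0;
}
```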