Trying again: Revert Fix Minimal All Reduce for Llama shapes (#18731)
### Description
This is a duplicate of PR #18217, which was reverted due to a symlink issue. The issue is now fixed and the post-commit pipelines are passing.

### Checklist
- [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/runs/13702739699) CI passes
- [x] [TG Nightly](https://github.com/tenstorrent/tt-metal/actions/runs/13688165179) CI passes

---------

Co-authored-by: avoraTT <avora@tenstorrent.com>
Co-authored-by: yugaoTT <yugao@tenstorrent.com>
3 people authored Mar 6, 2025
1 parent 727cfa4 commit f3d8fac
Showing 24 changed files with 2,212 additions and 279 deletions.
120 changes: 120 additions & 0 deletions models/demos/llama3/tests/test_ccl_async_perf_TG_llama.py
@@ -0,0 +1,120 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from loguru import logger
import ttnn

from models.perf.benchmarking_utils import BenchmarkData, BenchmarkProfiler
from models.perf.device_perf_utils import run_device_perf_detailed


@pytest.mark.parametrize(
    "ag_type, warmup_iters, perf_target_us",
    [
        ("sdpa", 10, 11),
        ("binary_mult", 10, 12),
        ("layernorm", 10, 8),
    ],
)
@pytest.mark.models_device_performance_bare_metal
def test_ag_tg_llama_perf(
    ag_type,
    warmup_iters,
    perf_target_us,
):
    profiler = BenchmarkProfiler()
    benchmark_data = BenchmarkData()
    step_name = f"all_gather_{ag_type}"

    subdir = "llama_ccl_perf"
    command = (
        f"pytest tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py::test_all_gather_tg_llama -k {ag_type}"
    )
    cols = ["DEVICE KERNEL"]
    op_name = "AllGatherAsync"
    warmup_iters = warmup_iters * 32  # scale the per-device warmup count by the 32 devices in the TG mesh

    profiler.start("run")
    profiler.start(step_name)
    results = run_device_perf_detailed(command, subdir, cols, op_name, has_signposts=True, warmup_iters=warmup_iters)
    profiler.end(step_name)
    profiler.end("run")

    # Get the measured performance
    measured_min_us = results[cols[0]]["MIN"] / 1000
    measured_max_us = results[cols[0]]["MAX"] / 1000
    measured_avg_us = results[cols[0]]["AVG"] / 1000
    measured_std_us = results[cols[0]]["STD"] / 1000

    logger.info(f"Measured performance: {measured_avg_us:.3f} us vs. target: {perf_target_us} us")

    # Save the measurement
    benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-min-us", measured_min_us)
    benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-max-us", measured_max_us)
    benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-avg-us", measured_avg_us)
    benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-std-us", measured_std_us)
    benchmark_data.save_partial_run_json(
        profiler,
        run_type=f"all_gather",
        ml_model_name="llama70b-tg-ccl",
    )

    assert measured_avg_us < perf_target_us, f"Performance target not met: {measured_avg_us} us > {perf_target_us} us"


@pytest.mark.parametrize(
    "ar_type, warmup_iters, perf_target_us",
    [
        ("ff2", 10, 29),
        ("qkv", 10, 25),
        ("ff1", 10, 30),
        ("lm_head", 10, 70),
    ],
)
@pytest.mark.models_device_performance_bare_metal
def test_ar_tg_llama_perf(
    ar_type,
    warmup_iters,
    perf_target_us,
):
    profiler = BenchmarkProfiler()
    benchmark_data = BenchmarkData()
    step_name = f"all_reduce_{ar_type}"

    subdir = "llama_ccl_perf"
    command = (
        f"pytest tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py::test_all_reduce_tg_llama -k {ar_type}"
    )
    cols = ["DEVICE KERNEL"]
    op_name = "AllReduceAsync"
    warmup_iters = warmup_iters * 32  # scale the per-device warmup count by the 32 devices in the TG mesh

    profiler.start("run")
    profiler.start(step_name)
    results = run_device_perf_detailed(command, subdir, cols, op_name, has_signposts=True, warmup_iters=warmup_iters)
    profiler.end(step_name)
    profiler.end("run")

    # Get the measured performance
    measured_min_us = results[cols[0]]["MIN"] / 1000
    measured_max_us = results[cols[0]]["MAX"] / 1000
    measured_avg_us = results[cols[0]]["AVG"] / 1000
    measured_std_us = results[cols[0]]["STD"] / 1000

    logger.info(f"Measured performance: {measured_avg_us:.3f} us vs. target: {perf_target_us} us")

    # Save the measurement
    benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-min-us", measured_min_us)
    benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-max-us", measured_max_us)
    benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-avg-us", measured_avg_us)
    benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-std-us", measured_std_us)
    benchmark_data.save_partial_run_json(
        profiler,
        run_type=f"all_reduce",
        ml_model_name="llama70b-tg-ccl",
    )

    assert measured_avg_us < perf_target_us, f"Performance target not met: {measured_avg_us} us > {perf_target_us} us"
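Note: these tests pass `has_signposts=True` to `run_device_perf_detailed`, so the profiled pytest command is expected to bracket the measured iterations with tracy signposts (which the CCL test changes further down add). A minimal sketch of that contract — `run_measured_region` and `run_op` are hypothetical placeholders, not code from this PR:

```python
# Sketch of the signpost contract assumed by has_signposts=True (placeholders, not PR code).
from tracy import signpost


def run_measured_region(run_op, num_iters):
    # Ops issued before "start" (e.g. compile or warmup work) are excluded by the post-processing.
    signpost("start")
    for _ in range(num_iters):
        run_op()  # hypothetical callable issuing the profiled device op
    signpost("stop")
```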
77 changes: 77 additions & 0 deletions models/perf/device_perf_utils.py
@@ -4,10 +4,14 @@

import json
import time
import pandas as pd

from loguru import logger
from collections import defaultdict

from tt_metal.tools.profiler.common import clear_profiler_runtime_artifacts
from tt_metal.tools.profiler.process_model_log import (
    get_latest_ops_log_filename,
    post_process_ops_log,
    run_device_profiler,
    get_samples_per_s,
@@ -49,6 +53,79 @@ def run_device_perf(command, subdir, num_iterations, cols, batch_size, has_signp
    return post_processed_results


# TODO: Move into process_model_log.py (#18698)
def post_process_ops_log_detailed(
    output_logs_subdir, columns, sum_vals=True, op_name="", has_signposts=False, detailed=False, warmup_iters=0
):
    filename = get_latest_ops_log_filename(output_logs_subdir)
    df = pd.read_csv(filename)

    if has_signposts:
        # there are explicit start and stop points in the model we want to measure between
        markers = df[df["OP TYPE"] == "signpost"]["OP CODE"]
        start = markers[markers == "start"].index[0]
        stop = markers[markers == "stop"].index[0]
        df = df.iloc[start + 1 : stop]
    if op_name != "":
        df = df[df["OP CODE"] == op_name]

    if warmup_iters > 0:
        df = df.iloc[warmup_iters:]

    results = {}
    for col in columns:
        df_filtered = df[df[col] != "-"]
        if sum_vals:
            results[col] = df_filtered[col].astype(float).sum()
        else:
            results[col] = df_filtered[col].astype(float).to_numpy()

        if detailed:
            results[f"AVG {col}"] = df_filtered[col].astype(float).mean()
            results[f"MIN {col}"] = df_filtered[col].astype(float).min()
            results[f"MAX {col}"] = df_filtered[col].astype(float).max()
            results[f"STD {col}"] = df_filtered[col].astype(float).std()

    return results


def run_device_perf_detailed(command, subdir, cols, op_name="", has_signposts=False, warmup_iters=0):
    duration_cols = [col + " DURATION [ns]" for col in cols]

    clear_profiler_runtime_artifacts()

    results = {}
    for d_col in duration_cols:
        results[f"AVG {d_col}"] = 0
        results[f"MIN {d_col}"] = float("inf")
        results[f"MAX {d_col}"] = -float("inf")
        results[f"STD {d_col}"] = 0

    run_device_profiler(command, subdir)
    r = post_process_ops_log_detailed(
        subdir, duration_cols, op_name=op_name, has_signposts=has_signposts, detailed=True, warmup_iters=warmup_iters
    )
    for d_col in duration_cols:
        results[f"AVG {d_col}"] = r[f"AVG {d_col}"]
        results[f"MIN {d_col}"] = r[f"MIN {d_col}"]
        results[f"MAX {d_col}"] = r[f"MAX {d_col}"]
        results[f"STD {d_col}"] = r[f"STD {d_col}"]

    post_processed_results = defaultdict(dict)
    for col, d_col in zip(cols, duration_cols):
        post_processed_results[col]["AVG"] = results[f"AVG {d_col}"]
        post_processed_results[col]["MIN"] = results[f"MIN {d_col}"]
        post_processed_results[col]["MAX"] = results[f"MAX {d_col}"]
        post_processed_results[col]["STD"] = results[f"STD {d_col}"]

    logger.info(
        f"\nTest: {command}"
        f"\nPerformance statistics for op: {op_name}"
        f"\n{json.dumps(post_processed_results, indent=4)}"
    )
    return post_processed_results


def check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=False):
    expected_results = {}
    failed = False
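To make the new post-processing concrete, here is a small standalone illustration (not part of the PR) of what `post_process_ops_log_detailed` effectively does to the profiler's ops CSV: slice the rows between the `start`/`stop` signposts, keep only the op of interest, drop warmup iterations, then compute per-column statistics. The column names mirror the real log, but the values are invented.

```python
import pandas as pd

# Toy ops log with the same columns the profiler CSV uses (durations are invented).
df = pd.DataFrame(
    {
        "OP TYPE": ["signpost", "tt_dnn_device", "tt_dnn_device", "tt_dnn_device", "tt_dnn_device", "signpost"],
        "OP CODE": ["start", "AllReduceAsync", "AllReduceAsync", "AllReduceAsync", "AllReduceAsync", "stop"],
        "DEVICE KERNEL DURATION [ns]": ["-", 30000, 29500, 28900, 29100, "-"],
    }
)

# 1. Keep only the rows between the "start" and "stop" signposts.
markers = df[df["OP TYPE"] == "signpost"]["OP CODE"]
start = markers[markers == "start"].index[0]
stop = markers[markers == "stop"].index[0]
df = df.iloc[start + 1 : stop]

# 2. Keep only the op being measured, then drop warmup iterations.
df = df[df["OP CODE"] == "AllReduceAsync"]
warmup_iters = 1
df = df.iloc[warmup_iters:]

# 3. Per-column statistics, as the detailed=True path computes them (in ns).
col = "DEVICE KERNEL DURATION [ns]"
durations = df[df[col] != "-"][col].astype(float)
print(durations.mean(), durations.min(), durations.max(), durations.std())
```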
1 change: 1 addition & 0 deletions tests/nightly/tg/ccl/test_new_all_reduce.py
@@ -15,6 +15,9 @@
    create_global_semaphore_with_same_address,
)
from models.perf.benchmarking_utils import BenchmarkProfiler
from tracy import signpost

NUM_BUFFERS = 8


def report_mismatches(golden, actual, max_printable=None):
@@ -64,6 +67,7 @@ def run_with_trace(
    n_worker=None,
    n_buffer=None,
    num_iter=20,
    warmup_iters=0,
    use_all_gather_async=False,
    profiler=BenchmarkProfiler(),
):
@@ -98,47 +102,66 @@

    # Capture trace
    logger.info("Capturing trace")
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    for i in range(num_iter):
        if use_all_gather_async:
            logger.info("Running all-gather async")
            tt_out_tensor = ttnn.experimental.all_gather_async(
                input_tensor,
                dim,
                cluster_axis=cluster_axis,
                mesh_device=mesh_device,
                topology=ttnn.Topology.Linear,
                multi_device_global_semaphore=ccl_semaphore_handles[i]
                if type(ccl_semaphore_handles) == list
                else ccl_semaphore_handles,
                num_links=num_links,
                memory_config=output_mem_config,
                subdevice_id=worker_sub_device_id,
                enable_persistent_fabric_mode=enable_persistent_fabric,
            )
        else:
            tt_out_tensor = ttnn.all_gather(
                input_tensor,
                dim=dim,
                cluster_axis=cluster_axis,
                mesh_device=mesh_device,
                num_links=num_links,
                memory_config=output_mem_config,
                topology=all_gather_topology,
            )
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
    ttnn.synchronize_device(mesh_device)

    def capture_trace(n_iters):
        trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
        for i in range(n_iters):
            if use_all_gather_async:
                tt_out_tensor = ttnn.experimental.all_gather_async(
                    input_tensor,
                    dim,
                    cluster_axis=cluster_axis,
                    mesh_device=mesh_device,
                    topology=ttnn.Topology.Linear,
                    multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS]
                    if type(ccl_semaphore_handles) == list
                    else ccl_semaphore_handles,
                    num_links=num_links,
                    memory_config=output_mem_config,
                    subdevice_id=worker_sub_device_id,
                    enable_persistent_fabric_mode=enable_persistent_fabric,
                )
            else:
                tt_out_tensor = ttnn.all_gather(
                    input_tensor,
                    dim=dim,
                    cluster_axis=cluster_axis,
                    mesh_device=mesh_device,
                    num_links=num_links,
                    memory_config=output_mem_config,
                    topology=all_gather_topology,
                )
        ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
        ttnn.synchronize_device(mesh_device)
        return trace_id

    if warmup_iters > 0:
        trace_id_warmup = capture_trace(warmup_iters)
    trace_id = capture_trace(num_iter)

    # Run the op
    logger.info("Starting Trace perf test...")
    profiler.start("all-gather-async-trace-warmup")
    if warmup_iters > 0:
        ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False)
        ttnn.release_trace(mesh_device, trace_id_warmup)
        ttnn.synchronize_device(mesh_device)
    profiler.end("all-gather-async-trace-warmup")

    profiler.start("all-gather-async-trace")
    signpost("start")
    ttnn.execute_trace(mesh_device, trace_id, blocking=False)
    ttnn.release_trace(mesh_device, trace_id)
    ttnn.synchronize_device(mesh_device)
    signpost("stop")
    profiler.end("all-gather-async-trace")
    logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s")
    logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s")
    logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us")
    time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration(
        "all-gather-async-trace-warmup"
    )
    effective_iter = num_iter - warmup_iters
    logger.info(f"Time taken e2e: {time_taken} s")
    logger.info(f"Time per iter e2e: {time_taken / effective_iter} s")
    logger.info(f"Time per iter e2e: {time_taken / effective_iter * 1e6} us")

    return tt_out_tensor

@@ -160,6 +183,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
    output_shard_spec: ttnn.ShardSpec = None,
    num_all_gather_instances: int = 1,
    num_iters: int = 1,
    warmup_iters: int = 0,
    cluster_axis: int = 0,
    tile=(32, 32),
    trace_mode=False,
@@ -257,7 +281,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(

    # create global semaphore handles
    ccl_semaphore_handles = [
        create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)
        create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(NUM_BUFFERS)
    ]
    try:
        # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
@@ -274,11 +298,13 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
                enable_persistent_fabric=enable_persistent_fabric,
                all_gather_topology=ttnn.Topology.Linear,
                num_iter=num_iters,
                warmup_iters=warmup_iters,
                use_all_gather_async=use_all_gather_async,
                profiler=profiler,
            )

        else:
            signpost("start")
            for i in range(num_iters):
                if use_all_gather_async:
                    logger.info("Running all-gather async")
@@ -288,7 +314,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
                        cluster_axis=cluster_axis,
                        mesh_device=mesh_device,
                        topology=ttnn.Topology.Linear,
                        multi_device_global_semaphore=ccl_semaphore_handles[i],
                        multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS],
                        num_links=num_links,
                        memory_config=output_mem_config,
                        subdevice_id=worker_sub_device_id,
@@ -305,7 +331,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
                        topology=ttnn.Topology.Linear,
                    )
                ttnn.synchronize_device(mesh_device, sub_device_ids=sub_device_stall_group)

            signpost("stop")
    except Exception as e:
        logger.error(f"Exception: {e}")
        raise e
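One more note on the CCL test changes above: instead of creating one global semaphore per iteration, the test now allocates a fixed pool of `NUM_BUFFERS` handles and reuses them round-robin via `i % NUM_BUFFERS`, which bounds resource usage regardless of `num_iters`. A trivial standalone sketch of that indexing — `make_handle` is a placeholder, not the ttnn call:

```python
# Standalone sketch of the fixed-size semaphore pool (make_handle is a stand-in,
# not create_global_semaphore_with_same_address).
NUM_BUFFERS = 8


def make_handle(slot):
    return f"semaphore-{slot}"


# One handle per buffer slot instead of one per iteration.
ccl_semaphore_handles = [make_handle(slot) for slot in range(NUM_BUFFERS)]

num_iters = 20
for i in range(num_iters):
    handle = ccl_semaphore_handles[i % NUM_BUFFERS]  # iterations reuse the pool round-robin
    # ... launch the CCL op with `handle` ...
```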