Skip to content

Commit

Permalink
Revert "Revert Fix Minimal All Reduce for Llama shapes" (#18755)
Browse files Browse the repository at this point in the history
Reverts #18731
Clang tidy errors and missing namespaces.
  • Loading branch information
ttmchiou authored Mar 6, 2025
1 parent 0c24e3f commit 604176b
Show file tree
Hide file tree
Showing 24 changed files with 279 additions and 2,212 deletions.
120 changes: 0 additions & 120 deletions models/demos/llama3/tests/test_ccl_async_perf_TG_llama.py

This file was deleted.

77 changes: 0 additions & 77 deletions models/perf/device_perf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@

import json
import time
import pandas as pd

from loguru import logger
from collections import defaultdict

from tt_metal.tools.profiler.common import clear_profiler_runtime_artifacts
from tt_metal.tools.profiler.process_model_log import (
get_latest_ops_log_filename,
post_process_ops_log,
run_device_profiler,
get_samples_per_s,
Expand Down Expand Up @@ -53,79 +49,6 @@ def run_device_perf(command, subdir, num_iterations, cols, batch_size, has_signp
return post_processed_results


# TODO: Move into process_model_log.py (#18698)
def post_process_ops_log_detailed(
    output_logs_subdir, columns, sum_vals=True, op_name="", has_signposts=False, detailed=False, warmup_iters=0
):
    """Reduce selected columns of the latest ops log to per-column results.

    The CSV located via ``get_latest_ops_log_filename(output_logs_subdir)`` is
    loaded, its rows are optionally narrowed (signpost window, op code,
    leading warm-up rows), and each requested column is reduced.

    Args:
        output_logs_subdir: Forwarded to ``get_latest_ops_log_filename`` to
            find the CSV to read.
        columns: Column names to reduce. Cells equal to ``"-"`` are dropped
            before the remaining cells are converted to float.
        sum_vals: If true, store the column sum under the column name;
            otherwise store the float values as a NumPy array.
        op_name: If not ``""``, keep only rows whose ``"OP CODE"`` equals it.
        has_signposts: If true, keep only the rows strictly between the first
            ``"start"`` and the first ``"stop"`` signpost rows.
        detailed: If true, additionally store ``"AVG <col>"``, ``"MIN <col>"``,
            ``"MAX <col>"`` and ``"STD <col>"`` for every column.
        warmup_iters: Number of leading rows to discard after the filters
            above have been applied.

    Returns:
        A dict keyed by column name (plus the prefixed statistic keys when
        ``detailed`` is set).

    Raises:
        IndexError: If ``has_signposts`` is set and the log has no ``"start"``
            or no ``"stop"`` signpost row.
    """
    ops = pd.read_csv(get_latest_ops_log_filename(output_logs_subdir))

    if has_signposts:
        # there are explicit start and stop points in the model we want to measure between
        signpost_codes = ops.loc[ops["OP TYPE"] == "signpost", "OP CODE"]
        # Index labels double as row positions here: the CSV is read without
        # an index column, so the frame carries the default positional index.
        first_row = signpost_codes[signpost_codes == "start"].index[0] + 1
        last_row = signpost_codes[signpost_codes == "stop"].index[0]
        ops = ops.iloc[first_row:last_row]
    if op_name != "":
        ops = ops[ops["OP CODE"] == op_name]

    if warmup_iters > 0:
        ops = ops.iloc[warmup_iters:]

    results = {}
    for col in columns:
        # Drop placeholder cells, then convert once and reuse for every reduction.
        values = ops.loc[ops[col] != "-", col].astype(float)
        results[col] = values.sum() if sum_vals else values.to_numpy()

        if detailed:
            results[f"AVG {col}"] = values.mean()
            results[f"MIN {col}"] = values.min()
            results[f"MAX {col}"] = values.max()
            results[f"STD {col}"] = values.std()

    return results


def run_device_perf_detailed(command, subdir, cols, op_name="", has_signposts=False, warmup_iters=0):
    """Profile ``command`` once and report per-column duration statistics.

    Profiler artifacts are cleared, ``command`` is run through
    ``run_device_profiler``, and the resulting ops log is reduced with
    ``post_process_ops_log_detailed`` in detailed mode. The statistics are
    regrouped per entry of ``cols`` and logged as JSON.

    Args:
        command: Passed to ``run_device_profiler`` and echoed in the log line.
        subdir: Passed to both the profiler and the log post-processing.
        cols: Column name prefixes; ``" DURATION [ns]"`` is appended to each
            to form the ops-log column that is analysed.
        op_name: Forwarded to ``post_process_ops_log_detailed``.
        has_signposts: Forwarded to ``post_process_ops_log_detailed``.
        warmup_iters: Forwarded to ``post_process_ops_log_detailed``.

    Returns:
        A ``defaultdict(dict)`` mapping each entry of ``cols`` to a dict with
        the keys ``"AVG"``, ``"MIN"``, ``"MAX"`` and ``"STD"``.
    """
    stat_names = ("AVG", "MIN", "MAX", "STD")
    duration_cols = [col + " DURATION [ns]" for col in cols]

    clear_profiler_runtime_artifacts()
    run_device_profiler(command, subdir)

    stats = post_process_ops_log_detailed(
        subdir, duration_cols, op_name=op_name, has_signposts=has_signposts, detailed=True, warmup_iters=warmup_iters
    )

    # Regroup the flat "<STAT> <column>" keys into one small dict per column.
    post_processed_results = defaultdict(dict)
    for col, d_col in zip(cols, duration_cols):
        for stat in stat_names:
            post_processed_results[col][stat] = stats[f"{stat} {d_col}"]

    logger.info(
        f"\nTest: {command}"
        f"\nPerformance statistics for op: {op_name}"
        f"\n{json.dumps(post_processed_results, indent=4)}"
    )
    return post_processed_results


def check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=False):
expected_results = {}
failed = False
Expand Down
1 change: 0 additions & 1 deletion tests/nightly/tg/ccl/test_new_all_reduce.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
create_global_semaphore_with_same_address,
)
from models.perf.benchmarking_utils import BenchmarkProfiler
from tracy import signpost

NUM_BUFFERS = 8


def report_mismatches(golden, actual, max_printable=None):
Expand Down Expand Up @@ -67,7 +64,6 @@ def run_with_trace(
n_worker=None,
n_buffer=None,
num_iter=20,
warmup_iters=0,
use_all_gather_async=False,
profiler=BenchmarkProfiler(),
):
Expand Down Expand Up @@ -102,66 +98,47 @@ def run_with_trace(

# Capture trace
logger.info("Capturing trace")

def capture_trace(n_iters):
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(n_iters):
if use_all_gather_async:
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS]
if type(ccl_semaphore_handles) == list
else ccl_semaphore_handles,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
)
else:
tt_out_tensor = ttnn.all_gather(
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=all_gather_topology,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
ttnn.synchronize_device(mesh_device)
return trace_id

if warmup_iters > 0:
trace_id_warmup = capture_trace(warmup_iters)
trace_id = capture_trace(num_iter)
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(num_iter):
if use_all_gather_async:
logger.info("Running all-gather async")
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i]
if type(ccl_semaphore_handles) == list
else ccl_semaphore_handles,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
)
else:
tt_out_tensor = ttnn.all_gather(
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=all_gather_topology,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
ttnn.synchronize_device(mesh_device)

# Run the op
logger.info("Starting Trace perf test...")
profiler.start("all-gather-async-trace-warmup")
if warmup_iters > 0:
ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False)
ttnn.release_trace(mesh_device, trace_id_warmup)
ttnn.synchronize_device(mesh_device)
profiler.end("all-gather-async-trace-warmup")

profiler.start("all-gather-async-trace")
signpost("start")
ttnn.execute_trace(mesh_device, trace_id, blocking=False)
ttnn.release_trace(mesh_device, trace_id)
ttnn.synchronize_device(mesh_device)
signpost("stop")
profiler.end("all-gather-async-trace")
time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration(
"all-gather-async-trace-warmup"
)
effective_iter = num_iter - warmup_iters
logger.info(f"Time taken e2e: {time_taken} s")
logger.info(f"Time per iter e2e: {time_taken / effective_iter} s")
logger.info(f"Time per iter e2e: {time_taken / effective_iter * 1e6} us")
logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s")
logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s")
logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us")

return tt_out_tensor

Expand All @@ -183,7 +160,6 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
output_shard_spec: ttnn.ShardSpec = None,
num_all_gather_instances: int = 1,
num_iters: int = 1,
warmup_iters: int = 0,
cluster_axis: int = 0,
tile=(32, 32),
trace_mode=False,
Expand Down Expand Up @@ -281,7 +257,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(

# create global semaphore handles
ccl_semaphore_handles = [
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(NUM_BUFFERS)
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)
]
try:
# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
Expand All @@ -298,13 +274,11 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
enable_persistent_fabric=enable_persistent_fabric,
all_gather_topology=ttnn.Topology.Linear,
num_iter=num_iters,
warmup_iters=warmup_iters,
use_all_gather_async=use_all_gather_async,
profiler=profiler,
)

else:
signpost("start")
for i in range(num_iters):
if use_all_gather_async:
logger.info("Running all-gather async")
Expand All @@ -314,7 +288,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS],
multi_device_global_semaphore=ccl_semaphore_handles[i],
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
Expand All @@ -331,7 +305,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
topology=ttnn.Topology.Linear,
)
ttnn.synchronize_device(mesh_device, sub_device_ids=sub_device_stall_group)
signpost("stop")

except Exception as e:
logger.error(f"Exception: {e}")
raise e
Expand Down
Loading

0 comments on commit 604176b

Please sign in to comment.