Skip to content

Commit

Permalink
#0: Revert "Minimal All Reduce for Llama shapes (#18217)"
Browse files Browse the repository at this point in the history
This reverts commit d41f968. This seems
to cause stateful errors on the runners. The root cause is related to
using a combination of our custom checkout action + a non-relative
symlink. Will blow away the _work on the runners and tell people to
rebase.
  • Loading branch information
rayraykay committed Mar 6, 2025
1 parent fbaa49a commit 9b24462
Show file tree
Hide file tree
Showing 24 changed files with 279 additions and 2,212 deletions.
120 changes: 0 additions & 120 deletions models/demos/llama3/tests/test_ccl_async_perf_TG_llama.py

This file was deleted.

77 changes: 0 additions & 77 deletions models/perf/device_perf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@

import json
import time
import pandas as pd

from loguru import logger
from collections import defaultdict

from tt_metal.tools.profiler.common import clear_profiler_runtime_artifacts
from tt_metal.tools.profiler.process_model_log import (
get_latest_ops_log_filename,
post_process_ops_log,
run_device_profiler,
get_samples_per_s,
Expand Down Expand Up @@ -53,79 +49,6 @@ def run_device_perf(command, subdir, num_iterations, cols, batch_size, has_signp
return post_processed_results


# TODO: Move into process_model_log.py (#18698)
def post_process_ops_log_detailed(
output_logs_subdir, columns, sum_vals=True, op_name="", has_signposts=False, detailed=False, warmup_iters=0
):
filename = get_latest_ops_log_filename(output_logs_subdir)
df = pd.read_csv(filename)

if has_signposts:
# there are explicit start and stop points in the model we want to measure between
markers = df[df["OP TYPE"] == "signpost"]["OP CODE"]
start = markers[markers == "start"].index[0]
stop = markers[markers == "stop"].index[0]
df = df.iloc[start + 1 : stop]
if op_name != "":
df = df[df["OP CODE"] == op_name]

if warmup_iters > 0:
df = df.iloc[warmup_iters:]

results = {}
for col in columns:
df_filtered = df[df[col] != "-"]
if sum_vals:
results[col] = df_filtered[col].astype(float).sum()
else:
results[col] = df_filtered[col].astype(float).to_numpy()

if detailed:
results[f"AVG {col}"] = df_filtered[col].astype(float).mean()
results[f"MIN {col}"] = df_filtered[col].astype(float).min()
results[f"MAX {col}"] = df_filtered[col].astype(float).max()
results[f"STD {col}"] = df_filtered[col].astype(float).std()

return results


def run_device_perf_detailed(command, subdir, cols, op_name="", has_signposts=False, warmup_iters=0):
duration_cols = [col + " DURATION [ns]" for col in cols]

clear_profiler_runtime_artifacts()

results = {}
for d_col in duration_cols:
results[f"AVG {d_col}"] = 0
results[f"MIN {d_col}"] = float("inf")
results[f"MAX {d_col}"] = -float("inf")
results[f"STD {d_col}"] = 0

run_device_profiler(command, subdir)
r = post_process_ops_log_detailed(
subdir, duration_cols, op_name=op_name, has_signposts=has_signposts, detailed=True, warmup_iters=warmup_iters
)
for d_col in duration_cols:
results[f"AVG {d_col}"] = r[f"AVG {d_col}"]
results[f"MIN {d_col}"] = r[f"MIN {d_col}"]
results[f"MAX {d_col}"] = r[f"MAX {d_col}"]
results[f"STD {d_col}"] = r[f"STD {d_col}"]

post_processed_results = defaultdict(dict)
for col, d_col in zip(cols, duration_cols):
post_processed_results[col]["AVG"] = results[f"AVG {d_col}"]
post_processed_results[col]["MIN"] = results[f"MIN {d_col}"]
post_processed_results[col]["MAX"] = results[f"MAX {d_col}"]
post_processed_results[col]["STD"] = results[f"STD {d_col}"]

logger.info(
f"\nTest: {command}"
f"\nPerformance statistics for op: {op_name}"
f"\n{json.dumps(post_processed_results, indent=4)}"
)
return post_processed_results


def check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=False):
expected_results = {}
failed = False
Expand Down
1 change: 0 additions & 1 deletion tests/nightly/tg/ccl/test_new_all_reduce.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
create_global_semaphore_with_same_address,
)
from models.perf.benchmarking_utils import BenchmarkProfiler
from tracy import signpost

NUM_BUFFERS = 8


def report_mismatches(golden, actual, max_printable=None):
Expand Down Expand Up @@ -67,7 +64,6 @@ def run_with_trace(
n_worker=None,
n_buffer=None,
num_iter=20,
warmup_iters=0,
use_all_gather_async=False,
profiler=BenchmarkProfiler(),
):
Expand Down Expand Up @@ -102,66 +98,47 @@ def run_with_trace(

# Capture trace
logger.info("Capturing trace")

def capture_trace(n_iters):
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(n_iters):
if use_all_gather_async:
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS]
if type(ccl_semaphore_handles) == list
else ccl_semaphore_handles,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
)
else:
tt_out_tensor = ttnn.all_gather(
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=all_gather_topology,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
ttnn.synchronize_device(mesh_device)
return trace_id

if warmup_iters > 0:
trace_id_warmup = capture_trace(warmup_iters)
trace_id = capture_trace(num_iter)
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(num_iter):
if use_all_gather_async:
logger.info("Running all-gather async")
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i]
if type(ccl_semaphore_handles) == list
else ccl_semaphore_handles,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
)
else:
tt_out_tensor = ttnn.all_gather(
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=all_gather_topology,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
ttnn.synchronize_device(mesh_device)

# Run the op
logger.info("Starting Trace perf test...")
profiler.start("all-gather-async-trace-warmup")
if warmup_iters > 0:
ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False)
ttnn.release_trace(mesh_device, trace_id_warmup)
ttnn.synchronize_device(mesh_device)
profiler.end("all-gather-async-trace-warmup")

profiler.start("all-gather-async-trace")
signpost("start")
ttnn.execute_trace(mesh_device, trace_id, blocking=False)
ttnn.release_trace(mesh_device, trace_id)
ttnn.synchronize_device(mesh_device)
signpost("stop")
profiler.end("all-gather-async-trace")
time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration(
"all-gather-async-trace-warmup"
)
effective_iter = num_iter - warmup_iters
logger.info(f"Time taken e2e: {time_taken} s")
logger.info(f"Time per iter e2e: {time_taken / effective_iter} s")
logger.info(f"Time per iter e2e: {time_taken / effective_iter * 1e6} us")
logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s")
logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s")
logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us")

return tt_out_tensor

Expand All @@ -183,7 +160,6 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
output_shard_spec: ttnn.ShardSpec = None,
num_all_gather_instances: int = 1,
num_iters: int = 1,
warmup_iters: int = 0,
cluster_axis: int = 0,
tile=(32, 32),
trace_mode=False,
Expand Down Expand Up @@ -281,7 +257,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(

# create global semaphore handles
ccl_semaphore_handles = [
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(NUM_BUFFERS)
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)
]
try:
# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
Expand All @@ -298,13 +274,11 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
enable_persistent_fabric=enable_persistent_fabric,
all_gather_topology=ttnn.Topology.Linear,
num_iter=num_iters,
warmup_iters=warmup_iters,
use_all_gather_async=use_all_gather_async,
profiler=profiler,
)

else:
signpost("start")
for i in range(num_iters):
if use_all_gather_async:
logger.info("Running all-gather async")
Expand All @@ -314,7 +288,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS],
multi_device_global_semaphore=ccl_semaphore_handles[i],
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
Expand All @@ -331,7 +305,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
topology=ttnn.Topology.Linear,
)
ttnn.synchronize_device(mesh_device, sub_device_ids=sub_device_stall_group)
signpost("stop")

except Exception as e:
logger.error(f"Exception: {e}")
raise e
Expand Down
Loading

0 comments on commit 9b24462

Please sign in to comment.