Merge branch 'main' into sbosisio/cuda-dl-base
Steboss authored Jan 16, 2025
2 parents bd066f1 + eb6d0d2 commit 564ec47
Showing 12 changed files with 521 additions and 262 deletions.
3 changes: 2 additions & 1 deletion .github/container/nsys_jax/nsys_jax/__init__.py
@@ -7,7 +7,7 @@
from .data_loaders import load_profiler_data
from .protobuf import xla_module_metadata
from .protobuf_utils import compile_protos, ensure_compiled_protos_are_importable
from .utils import remove_autotuning_detail, remove_child_ranges
from .utils import default_data_prefix, remove_autotuning_detail, remove_child_ranges
from .visualization import create_flamegraph, display_flamegraph

__all__ = [
@@ -16,6 +16,7 @@
"calculate_collective_metrics",
"compile_protos",
"create_flamegraph",
"default_data_prefix",
"display_flamegraph",
"ensure_compiled_protos_are_importable",
"generate_compilation_statistics",
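The newly exported helper slots into the existing loading workflow. A minimal downstream sketch, assuming an nsys-jax output directory addressed either via NSYS_JAX_DEFAULT_PREFIX or the current working directory (the calls mirror those appearing later in this commit):

# Sketch only; mirrors calls shown in the notebook diff below.
from nsys_jax import (
    default_data_prefix,
    ensure_compiled_protos_are_importable,
    load_profiler_data,
)

prefix = default_data_prefix()  # NSYS_JAX_DEFAULT_PREFIX if set, else the current working directory
ensure_compiled_protos_are_importable(prefix=prefix)  # compile/import the generated protobuf modules
all_data = load_profiler_data(prefix)  # load the runtime profile data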
29 changes: 20 additions & 9 deletions .github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb
@@ -12,6 +12,7 @@
"from nsys_jax import (\n",
" align_profiler_data_timestamps,\n",
" apply_warmup_heuristics,\n",
" default_data_prefix,\n",
" display_flamegraph,\n",
" ensure_compiled_protos_are_importable,\n",
" generate_compilation_statistics,\n",
@@ -23,6 +24,18 @@
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a91f0e7-17da-4534-8ea9-29bcf3742567",
"metadata": {},
"outputs": [],
"source": [
"# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n",
"# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n",
"prefix = default_data_prefix()"
]
},
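For orientation, the behaviour described in that cell amounts to roughly the following. This is a hedged sketch of the assumed behaviour, not the actual implementation of default_data_prefix:

import os
import pathlib


def default_data_prefix() -> pathlib.Path:
    # Assumed behaviour per the cell above: honour NSYS_JAX_DEFAULT_PREFIX if it
    # is set, otherwise fall back to the current working directory.
    env = os.environ.get("NSYS_JAX_DEFAULT_PREFIX")
    return pathlib.Path(env) if env else pathlib.Path.cwd()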
{
"cell_type": "code",
"execution_count": null,
@@ -32,7 +45,7 @@
"source": [
"# Make sure that the .proto files under protos/ have been compiled to .py, and\n",
"# that those generated .py files are importable.]\n",
"compiled_dir = ensure_compiled_protos_are_importable()"
"compiled_dir = ensure_compiled_protos_are_importable(prefix=prefix)"
]
},
{
@@ -43,7 +56,7 @@
"outputs": [],
"source": [
"# Load the runtime profile data\n",
"all_data = load_profiler_data()\n",
"all_data = load_profiler_data(prefix)\n",
"# Remove some detail from the autotuner\n",
"all_data = remove_autotuning_detail(all_data)\n",
"# Align GPU timestamps across profiles collected by different Nsight Systems processes\n",
@@ -82,16 +95,14 @@
"source": [
"This data frame has a three-level index:\n",
"- `ProgramId` is an integer ID that uniquely identifies the XLA module\n",
"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 2, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
"- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place\n",
"\n",
"The columns are as follows:\n",
"- `Name`: the name of the XLA module; this should always be the same for a given `ProgramId`\n",
"- `NumThunks`: the number of thunks executed inside this module execution\n",
"- `ProjStartMs`: the timestamp of the start of the module execution on the GPU, in milliseconds\n",
"- `ProjDurMs`: the duration of the module execution on the GPU, in milliseconds\n",
"- `OrigStartMs`: the timestamp of the start of the module launch **on the host**, in milliseconds. *i.e.* `ProjStartMs-OrigStartMs` is something like the launch latency of the first kernel\n",
"- `OrigDurMs`: the duration of the module launch **on the host**, in milliseconds\n",
"- `LocalDevice`: the index within the node/slice of the GPU on which the module execution took place\n",
"- `Process`: the global (across multiple nodes) index of the process\n",
"- `Slice`: the global index of the node/slice; devices within the same node/slice should have faster interconnects than to devices in different slices\n",
@@ -117,13 +128,13 @@
"id": "7727d800-13d3-4505-89e8-80a5fed63512",
"metadata": {},
"source": [
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `module_df`.\n",
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n",
"The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n",
"Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n",
"\n",
"The columns are as follows:\n",
"- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n",
"- `ProjStartMs`, `OrigStartMs`, `OrigDurMs`: see above, same meaning as in `module_df`.\n",
"- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n",
"- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n",
"- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n",
"\n",
@@ -299,7 +310,7 @@
"# Print out the largest entries adding up to at least this fraction of the total\n",
"threshold = 0.97\n",
"compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n",
"print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-9:.2f}s compilation time\")\n",
"print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-3:.2f}s compilation time\")\n",
"for row in compile_summary[\n",
" compile_summary[\"FracNonChild\"].cumsum() <= threshold\n",
"].itertuples():\n",
@@ -378,7 +389,7 @@
" program_id, thunk_name = thunk_row.Index\n",
" # policy=\"all\" means we may get a set of HloProto instead of a single one, if\n",
" # nsys-jax-combine was used and the dumped metadata were not bitwise identical\n",
" hlo_modules = xla_module_metadata(program_id, policy=\"all\")\n",
" hlo_modules = xla_module_metadata(program_id, policy=\"all\", prefix=prefix)\n",
" thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(\n",
" lambda proto: instructions_and_frames(proto, thunk_name)\n",
" )\n",
162 changes: 140 additions & 22 deletions .github/container/nsys_jax/nsys_jax/analyses/communication.py
100644 → 100755
@@ -1,38 +1,21 @@
#!/usr/bin/env python
import argparse
import csv
from collections import defaultdict

from nsys_jax import (
align_profiler_data_timestamps,
apply_warmup_heuristics,
ensure_compiled_protos_are_importable,
load_profiler_data,
)
from math import sqrt
from prettytable import PrettyTable
import pathlib
from uncertainties import ufloat # type: ignore


def main():
parser = argparse.ArgumentParser(
description="Summarise communication in an nsys-jax report"
)
parser.add_argument("prefix", type=pathlib.Path)
args = parser.parse_args()
# Make sure that the .proto files under protos/ have been compiled to .py, and
# that those generated .py files are importable.
ensure_compiled_protos_are_importable(prefix=args.prefix)
# Load the profiler data; the compilation part is needed for the warmup heuristics
all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
# Align timestamps
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
# TODO: make this pretty
# print(alignment_metadata)
# Partition the profile data into initialisation and steady-state running
_, steady_state = apply_warmup_heuristics(all_data)
assert len(steady_state.communication), (
"Communication summary was requested but no steady-state communication was "
"identified."
)
def process_communication_data(steady_state):
collective_types = set()
summary_data = defaultdict(dict)
for (collective, message_size), df in steady_state.communication.groupby(
@@ -52,7 +35,10 @@ def main():
summary_data[message_size][collective] = ufloat(
bandwidth.mean(), bandwidth.std() / sqrt(len(bandwidth))
)
collective_types = sorted(collective_types)
return sorted(collective_types), summary_data


def print_bandwidth_table(collective_types, summary_data):
collective_widths = {
collective: max(
len(collective),
@@ -96,5 +82,137 @@ def format_bandwidth(data, collective):
)


def process_hidden_ms_to_total_ms(steady_state):
if steady_state.communication["ProjDurHiddenMs"].sum() == 0:
return None, None

collective_types = set()
summary_data = defaultdict(dict)
for collective, df in steady_state.communication.groupby(["Collective"]):
collective_types.add(collective)
mean_dur_hidden_ms_to_total_ms = (
df["ProjDurHiddenMs"] / (df["ProjDurMs"] + df["ProjDurHiddenMs"])
).mean()
summary_data[collective] = mean_dur_hidden_ms_to_total_ms
return collective_types, summary_data


def print_hidden_ms_to_total_ms_table(
collective_types, summary_data, overall_hidden_ms_to_total_ms
):
table = PrettyTable()
table.field_names = ["Collective", "Mean HiddenToTotalMs"]

for collective in collective_types:
mean_value = summary_data[collective]
table.add_row([collective[0], mean_value])

print(table)
print("Overall HiddenMs to TotalMs:", overall_hidden_ms_to_total_ms)


def calculate_overall_hidden_ms_to_total_ms(steady_state):
if steady_state.communication["ProjDurHiddenMs"].sum() == 0:
return None

overall_hidden_ms_to_total_ms = (
steady_state.communication["ProjDurHiddenMs"].sum()
/ (
steady_state.communication["ProjDurMs"]
+ steady_state.communication["ProjDurHiddenMs"]
).sum()
)
return overall_hidden_ms_to_total_ms


def write_to_csv(
collective_types,
bandwidth_summary,
hidden_to_total_summary,
overall_hidden_ms_to_total_ms,
output_file,
):
with open(output_file, "w", newline="") as csvfile:
writer = csv.writer(csvfile)

# Write bandwidth table
writer.writerow(["Bandwidth Table"])
writer.writerow(["Size [B]"] + list(collective_types))
for message_size in sorted(bandwidth_summary.keys()):
row = [message_size]
for collective in collective_types:
if collective in bandwidth_summary[message_size]:
row.append(f"{bandwidth_summary[message_size][collective]:S}")
else:
row.append("-")
writer.writerow(row)

writer.writerow([]) # Empty row for separation

# Write hidden to total table if data is available
if hidden_to_total_summary is not None:
writer.writerow(["HiddenMs to TotalMs Table"])
writer.writerow(["Collective", "Mean HiddenToTotalMs"])
for collective in hidden_to_total_summary:
writer.writerow([collective[0], hidden_to_total_summary[collective]])

writer.writerow([]) # Empty row for separation

if overall_hidden_ms_to_total_ms is not None:
writer.writerow(
["Overall HiddenMs to TotalMs", overall_hidden_ms_to_total_ms]
)


def main():
parser = argparse.ArgumentParser(
description="Summarise communication in an nsys-jax report"
)
parser.add_argument("prefix", type=pathlib.Path)
args = parser.parse_args()

# Make sure that the .proto files under protos/ have been compiled to .py, and
# that those generated .py files are importable.
ensure_compiled_protos_are_importable(prefix=args.prefix)
# Load the profiler data; the compilation part is needed for the warmup heuristics
all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
# Align timestamps
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
# TODO: make this pretty
# print(alignment_metadata)
# Partition the profile data into initialisation and steady-state running
_, steady_state = apply_warmup_heuristics(all_data)

assert len(steady_state.communication), (
"Communication summary was requested but no steady-state communication was "
"identified."
)

collective_types, bandwidth_summary = process_communication_data(steady_state)
print_bandwidth_table(collective_types, bandwidth_summary)

hidden_to_total_collective_types, hidden_to_total_summary = (
process_hidden_ms_to_total_ms(steady_state)
)
if hidden_to_total_summary is not None:
overall_hidden_ms_to_total_ms = calculate_overall_hidden_ms_to_total_ms(
steady_state
)
print_hidden_ms_to_total_ms_table(
hidden_to_total_collective_types,
hidden_to_total_summary,
overall_hidden_ms_to_total_ms,
)

# Write all tables to a single CSV file
write_to_csv(
collective_types,
bandwidth_summary,
hidden_to_total_summary,
overall_hidden_ms_to_total_ms,
"communication_summary.csv",
)


if __name__ == "__main__":
main()
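Two reading notes on the refactored script, neither stated in the diff itself. First, the CLI is unchanged, so a direct run presumably still looks like ./communication.py /path/to/nsys-jax-output (the positional prefix argument; the path is hypothetical). Second, the two hidden-time summaries are different statistics: process_hidden_ms_to_total_ms averages the per-row ratio within each collective, while calculate_overall_hidden_ms_to_total_ms takes the ratio of sums, and the two generally disagree. A tiny sketch with made-up numbers:

# Made-up numbers, purely to show that the two statistics can differ.
hidden = [1.0, 90.0]   # ProjDurHiddenMs per row
exposed = [9.0, 10.0]  # ProjDurMs per row
per_row = [h / (h + e) for h, e in zip(hidden, exposed)]     # [0.1, 0.9]
mean_of_ratios = sum(per_row) / len(per_row)                 # 0.5
ratio_of_sums = sum(hidden) / (sum(hidden) + sum(exposed))   # ~0.83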
34 changes: 19 additions & 15 deletions .github/container/nsys_jax/nsys_jax/analysis.py
100644 → 100755
@@ -286,17 +286,8 @@ def get_message_size(
of the semantics. This implementation aims to follow the same conventions that NCCL
uses in its NVTX payloads and tests.
"""
return pd.Series(
xla_module_metadata(program_id, prefix=prefix, policy="all").unique_result(
lambda proto: _get_message_size(proto, instruction)
),
index=[
"MessageSize",
"Collective",
"CollectiveSize",
"BandwidthCorrection",
"BusBandwidthCorrection",
],
return xla_module_metadata(program_id, prefix=prefix, policy="all").unique_result(
lambda proto: _get_message_size(proto, instruction)
)


@@ -311,13 +302,26 @@ def calculate_collective_metrics(
comm_df = thunk_df[thunk_df["Communication"]].drop(columns=["Communication"])
if len(comm_df) == 0:
return comm_df

def body(tup):
idx, name = tup
return get_message_size(idx[0], name, prefix=prefix)

metrics_df = pd.DataFrame.from_records(
map(body, comm_df["Name"].items()),
columns=[
"MessageSize",
"Collective",
"CollectiveSize",
"BandwidthCorrection",
"BusBandwidthCorrection",
],
index=comm_df.index,
)
comm_df = pd.concat(
[
comm_df,
comm_df.apply(
lambda row: get_message_size(row.name[0], row.Name, prefix=prefix),
axis=1,
),
metrics_df,
],
axis=1,
)
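One way to read this refactor (an interpretation, not something the commit states): building the metrics frame once via pd.DataFrame.from_records with an explicit column list avoids constructing a labelled pd.Series per row inside DataFrame.apply, and keeps the column names in a single place. The from_records pattern in isolation, with made-up values:

import pandas as pd

# Toy records in the same shape as the metrics columns above; the values and
# index labels are invented for illustration only.
records = [
    (1 << 20, "all-reduce", 8, 1.0, 1.75),
    (1 << 16, "all-gather", 8, 1.0, 0.875),
]
metrics_df = pd.DataFrame.from_records(
    records,
    columns=[
        "MessageSize",
        "Collective",
        "CollectiveSize",
        "BandwidthCorrection",
        "BusBandwidthCorrection",
    ],
    index=pd.Index(["thunk_a", "thunk_b"]),
)
print(metrics_df)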
(Diffs for the remaining 8 of the 12 changed files were not loaded on this page.)
