Skip to content

Commit

Permalink
Add review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
sfvaroglu committed Jan 15, 2025
1 parent af2aee6 commit 08364a6
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 27 deletions.
142 changes: 118 additions & 24 deletions .github/container/nsys_jax/nsys_jax/analyses/communication.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env python
import argparse
import csv
from collections import defaultdict

from nsys_jax import (
align_profiler_data_timestamps,
apply_warmup_heuristics,
Expand All @@ -13,27 +15,7 @@
from uncertainties import ufloat # type: ignore


def main():
parser = argparse.ArgumentParser(
description="Summarise communication in an nsys-jax report"
)
parser.add_argument("prefix", type=pathlib.Path)
args = parser.parse_args()
# Make sure that the .proto files under protos/ have been compiled to .py, and
# that those generated .py files are importable.
ensure_compiled_protos_are_importable(prefix=args.prefix)
# Load the profiler data; the compilation part is needed for the warmup heuristics
all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
# Align timestamps
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
# TODO: make this pretty
# print(alignment_metadata)
# Partition the profile data into initialisation and steady-state running
_, steady_state = apply_warmup_heuristics(all_data)
assert len(steady_state.communication), (
"Communication summary was requested but no steady-state communication was "
"identified."
)
def process_communication_data(steady_state):
collective_types = set()
summary_data = defaultdict(dict)
for (collective, message_size), df in steady_state.communication.groupby(
Expand All @@ -53,7 +35,10 @@ def main():
summary_data[message_size][collective] = ufloat(
bandwidth.mean(), bandwidth.std() / sqrt(len(bandwidth))
)
collective_types = sorted(collective_types)
return sorted(collective_types), summary_data


def print_bandwidth_table(collective_types, summary_data):
collective_widths = {
collective: max(
len(collective),
Expand Down Expand Up @@ -96,19 +81,39 @@ def format_bandwidth(data, collective):
)
)


def process_hidden_ms_to_total_ms(steady_state):
    """Compute the mean hidden-to-total duration ratio for each collective.

    For every row of ``steady_state.communication`` the ratio is
    ``ProjDurHiddenMs / (ProjDurMs + ProjDurHiddenMs)``. Rows are grouped by
    the ``Collective`` column and the ratios are averaged within each group.
    The ratio is derived here from the two ``Proj*`` columns; no precomputed
    ratio column is read from the frame.

    Args:
        steady_state: object whose ``communication`` attribute is a DataFrame
            with ``Collective``, ``ProjDurMs`` and ``ProjDurHiddenMs`` columns.

    Returns:
        ``(None, None)`` when the ``ProjDurHiddenMs`` column sums to zero.
        Otherwise ``(collective_types, summary_data)``: ``collective_types``
        is the sorted list of group keys and ``summary_data`` maps each group
        key to its mean ratio.
    """
    communication = steady_state.communication
    if communication["ProjDurHiddenMs"].sum() == 0:
        return None, None

    summary_data = {}
    # The one-element list passed to groupby is kept on purpose: the printing
    # and CSV code index each key with ``collective[0]``.
    # NOTE(review): recent pandas yields 1-tuples for a list-valued ``by``;
    # confirm against the pinned pandas version.
    for collective, df in communication.groupby(["Collective"]):
        summary_data[collective] = (
            df["ProjDurHiddenMs"] / (df["ProjDurMs"] + df["ProjDurHiddenMs"])
        ).mean()
    # Sorted so that the printed table has a reproducible row order.
    return sorted(summary_data), summary_data


def print_hidden_ms_to_total_ms_table(
    collective_types, summary_data, overall_hidden_ms_to_total_ms
):
    """Print the per-collective hidden-to-total ratios, then the overall one.

    Args:
        collective_types: iterable of keys into ``summary_data``; the first
            element of each key (``key[0]``) is shown as the collective name.
        summary_data: mapping from each key to the value shown in the
            "Mean HiddenToTotalMs" column.
        overall_hidden_ms_to_total_ms: value printed on the final line.
    """
    header = ["Collective", "Mean HiddenToTotalMs"]
    report = PrettyTable()
    report.field_names = header

    # One row per collective, in the order the caller supplied the keys.
    for key in collective_types:
        report.add_row([key[0], summary_data[key]])

    print(report)
    print("Overall HiddenMs to TotalMs:", overall_hidden_ms_to_total_ms)


def calculate_overall_hidden_ms_to_total_ms(steady_state):
if steady_state.communication["ProjDurHiddenMs"].sum() == 0:
return None

overall_hidden_ms_to_total_ms = (
steady_state.communication["ProjDurHiddenMs"].sum()
Expand All @@ -117,7 +122,96 @@ def format_bandwidth(data, collective):
+ steady_state.communication["ProjDurHiddenMs"]
).sum()
)
print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms}")
return overall_hidden_ms_to_total_ms


def write_to_csv(
    collective_types,
    bandwidth_summary,
    hidden_to_total_summary,
    overall_hidden_ms_to_total_ms,
    output_file,
):
    """Write the bandwidth and hidden-to-total summaries to one CSV file.

    The file starts with a "Bandwidth Table" section: a header row of
    ``collective_types`` and one row per message size in ascending order,
    followed by an empty separator row. If ``hidden_to_total_summary`` is not
    ``None`` a "HiddenMs to TotalMs Table" section follows, then another empty
    row and, when ``overall_hidden_ms_to_total_ms`` is not ``None``, a final
    row holding that overall value.

    Args:
        collective_types: iterable of collective names; defines the bandwidth
            table columns.
        bandwidth_summary: mapping ``message_size -> {collective: value}``.
            Values are rendered with the ``S`` format spec (presumably the
            shorthand notation of the ``uncertainties`` package -- confirm);
            missing entries are written as ``"-"``.
        hidden_to_total_summary: mapping from key to mean ratio, or ``None``.
            ``key[0]`` is written as the collective name.
        overall_hidden_ms_to_total_ms: overall ratio, or ``None``.
        output_file: path of the CSV file to create or overwrite.
    """
    with open(output_file, "w", newline="") as handle:
        out = csv.writer(handle)

        # Section 1: bandwidth per message size and collective.
        out.writerow(["Bandwidth Table"])
        out.writerow(["Size [B]", *collective_types])
        for size in sorted(bandwidth_summary):
            cells = [
                f"{bandwidth_summary[size][name]:S}"
                if name in bandwidth_summary[size]
                else "-"
                for name in collective_types
            ]
            out.writerow([size, *cells])

        out.writerow([])  # Blank row separating the sections

        # Section 2 is only emitted when hidden-duration data is available.
        if hidden_to_total_summary is None:
            return

        out.writerow(["HiddenMs to TotalMs Table"])
        out.writerow(["Collective", "Mean HiddenToTotalMs"])
        for key, ratio in hidden_to_total_summary.items():
            out.writerow([key[0], ratio])

        out.writerow([])  # Blank row separating the sections

        if overall_hidden_ms_to_total_ms is not None:
            out.writerow(
                ["Overall HiddenMs to TotalMs", overall_hidden_ms_to_total_ms]
            )


def main():
    """Summarise communication in an nsys-jax report and export it as CSV.

    Parses the ``prefix`` command-line argument, loads and time-aligns the
    profiler data, keeps the steady-state part, prints the bandwidth table
    and (when hidden-duration data is available) the hidden-to-total table,
    and finally writes everything to ``communication_summary.csv`` in the
    current working directory.

    Raises:
        AssertionError: if no steady-state communication was identified.
    """
    parser = argparse.ArgumentParser(
        description="Summarise communication in an nsys-jax report"
    )
    parser.add_argument("prefix", type=pathlib.Path)
    args = parser.parse_args()

    # Make sure that the .proto files under protos/ have been compiled to .py, and
    # that those generated .py files are importable.
    ensure_compiled_protos_are_importable(prefix=args.prefix)
    # Load the profiler data; the compilation part is needed for the warmup heuristics
    all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
    # Align timestamps
    all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
    # TODO: make this pretty
    # print(alignment_metadata)
    # Partition the profile data into initialisation and steady-state running
    _, steady_state = apply_warmup_heuristics(all_data)

    assert len(steady_state.communication), (
        "Communication summary was requested but no steady-state communication was "
        "identified."
    )

    collective_types, bandwidth_summary = process_communication_data(steady_state)
    print_bandwidth_table(collective_types, bandwidth_summary)

    hidden_to_total_collective_types, hidden_to_total_summary = (
        process_hidden_ms_to_total_ms(steady_state)
    )
    # Must be bound even when there is no hidden-duration data, because it is
    # passed to write_to_csv unconditionally below.
    overall_hidden_ms_to_total_ms = None
    if hidden_to_total_summary is not None:
        overall_hidden_ms_to_total_ms = calculate_overall_hidden_ms_to_total_ms(
            steady_state
        )
        print_hidden_ms_to_total_ms_table(
            hidden_to_total_collective_types,
            hidden_to_total_summary,
            overall_hidden_ms_to_total_ms,
        )

    # Write all tables to a single CSV file
    write_to_csv(
        collective_types,
        bandwidth_summary,
        hidden_to_total_summary,
        overall_hidden_ms_to_total_ms,
        "communication_summary.csv",
    )


if __name__ == "__main__":
Expand Down
3 changes: 0 additions & 3 deletions .github/container/nsys_jax/nsys_jax/analysis.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,6 @@ def calculate_collective_metrics(
comm_df["BusBandwidthGBPerSec"] = (
comm_df["AlgorithmBandwidthGBPerSec"] * comm_df["BusBandwidthCorrection"]
)
comm_df["DurHiddenMsToDurMs"] = comm_df["ProjDurHiddenMs"] / (
comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"]
)
return comm_df.drop(columns=["BandwidthCorrection", "BusBandwidthCorrection"])


Expand Down

0 comments on commit 08364a6

Please sign in to comment.