Skip to content

Commit

Permalink
Add ratio of hidden communication time to total communication time
Browse files Browse the repository at this point in the history
  • Loading branch information
sfvaroglu committed Jan 10, 2025
1 parent 26451c0 commit d7d5449
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
28 changes: 28 additions & 0 deletions .github/container/nsys_jax/nsys_jax/analyses/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
load_profiler_data,
)
from math import sqrt
from statistics import mean
import pathlib
from uncertainties import ufloat # type: ignore

Expand Down Expand Up @@ -95,6 +96,33 @@ def format_bandwidth(data, collective):
)
)

# Report, per collective type, the mean of the per-row "DurHiddenMsToDurMs"
# ratio, followed by one overall hidden/total ratio for all communication.
#
# Group by the scalar label rather than the one-element list ["Collective"]:
# with a list, recent pandas versions hand back 1-tuple keys such as
# ('all-gather',), which previously inflated the column width (the tuple repr
# was measured, but only its first element was printed). A scalar key also
# avoids indexing [0] into what could be a plain string.
# groupby sorts the group keys by default, so the rows come out in a stable
# order from run to run instead of following set iteration order.
summary_data = {
    str(collective): df["DurHiddenMsToDurMs"].mean()
    for collective, df in steady_state.communication.groupby("Collective")
}

collective_header = "Collective"
ratio_header = "Mean HiddenToTotalMs"
# default=0 keeps this from raising ValueError when there are no collectives
# at all; the table then consists of the header and separator only.
collective_width = max(
    len(collective_header),
    max(map(len, summary_data), default=0),
)
ratio_width = len(ratio_header)

print()
print(f"{collective_header:<{collective_width}} | {ratio_header:<{ratio_width}}")
print(f"{'-' * collective_width} | {'-' * ratio_width}")

for collective_str, mean_value in summary_data.items():
    print(f"{collective_str:<{collective_width}} | {mean_value:>{ratio_width}}")

# Ratio of sums (not the mean of the per-row ratios): total hidden time over
# the summed ProjDurMs + ProjDurHiddenMs across every row.
overall_hidden_ms_to_total_ms = (
    steady_state.communication["ProjDurHiddenMs"].sum()
    / (
        steady_state.communication["ProjDurMs"]
        + steady_state.communication["ProjDurHiddenMs"]
    ).sum()
)

print()
print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms:>{ratio_width}}")

# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions .github/container/nsys_jax/nsys_jax/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ def calculate_collective_metrics(
comm_df["BusBandwidthGBPerSec"] = (
comm_df["AlgorithmBandwidthGBPerSec"] * comm_df["BusBandwidthCorrection"]
)
# Per-row ratio of hidden time to total time, where the total is taken as
# ProjDurMs + ProjDurHiddenMs.
# NOTE(review): this presumes ProjDurMs excludes the hidden portion (i.e. the
# two columns are disjoint) -- confirm against where they are computed.
comm_df["DurHiddenMsToDurMs"] = (
comm_df["ProjDurHiddenMs"] / (comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"])
)
return comm_df.drop(columns=["BandwidthCorrection", "BusBandwidthCorrection"])


Expand Down

0 comments on commit d7d5449

Please sign in to comment.