From 14a59d640caa073b43728ce6188c9dcf93253868 Mon Sep 17 00:00:00 2001 From: Sevin Varoglu Date: Mon, 13 Jan 2025 18:13:18 +0000 Subject: [PATCH] Switch to PrettyTable --- .../nsys_jax/nsys_jax/analyses/communication.py | 17 ++++++----------- .github/container/nsys_jax/pyproject.toml | 1 + 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/.github/container/nsys_jax/nsys_jax/analyses/communication.py b/.github/container/nsys_jax/nsys_jax/analyses/communication.py index b02e4af01..4f07124a1 100644 --- a/.github/container/nsys_jax/nsys_jax/analyses/communication.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/communication.py @@ -8,6 +8,7 @@ load_profiler_data, ) from math import sqrt +from prettytable import PrettyTable from statistics import mean import pathlib from uncertainties import ufloat # type: ignore @@ -104,25 +105,19 @@ def format_bandwidth(data, collective): collective_types.add(collective) summary_data[collective] = df["DurHiddenMsToDurMs"].mean() - collective_width = max(len("Collective"), max(len(f"{collective}") for collective in collective_types)) - ratio_width = len("Mean HiddenToTotalMs") - - print() - print(f"{'Collective':<{collective_width}} | {'Mean HiddenToTotalMs':<{ratio_width}}") - print(f"{'-' * collective_width} | {'-' * ratio_width}") + table = PrettyTable() + table.field_names = ["Collective", "Mean HiddenToTotalMs"] for collective in collective_types: mean_value = summary_data[collective] - collective_str = str(collective[0]) - print(f"{collective_str:<{collective_width}} | {mean_value:>{ratio_width}}") + table.add_row([collective[0], mean_value]) + print(table) overall_hidden_ms_to_total_ms = ( steady_state.communication["ProjDurHiddenMs"].sum() / (steady_state.communication["ProjDurMs"] + steady_state.communication["ProjDurHiddenMs"]).sum() ) - - print() - print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms:>{ratio_width}}") + print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms}") if __name__ == "__main__": main() diff --git a/.github/container/nsys_jax/pyproject.toml b/.github/container/nsys_jax/pyproject.toml index 95bdffd4c..d0c79ad43 100644 --- a/.github/container/nsys_jax/pyproject.toml +++ b/.github/container/nsys_jax/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "pyarrow", "requests", # for install-protoc "uncertainties", # communication analysis recipe + "prettytable", ] requires-python = ">= 3.10"