nsys-jax: optimise data loading and .zip creation #1193

Open · wants to merge 12 commits into base: main
3 changes: 2 additions & 1 deletion .github/container/nsys_jax/nsys_jax/__init__.py
@@ -7,7 +7,7 @@
 from .data_loaders import load_profiler_data
 from .protobuf import xla_module_metadata
 from .protobuf_utils import compile_protos, ensure_compiled_protos_are_importable
-from .utils import remove_autotuning_detail, remove_child_ranges
+from .utils import default_data_prefix, remove_autotuning_detail, remove_child_ranges
 from .visualization import create_flamegraph, display_flamegraph
 
 __all__ = [
@@ -16,6 +16,7 @@
     "calculate_collective_metrics",
     "compile_protos",
     "create_flamegraph",
+    "default_data_prefix",
     "display_flamegraph",
     "ensure_compiled_protos_are_importable",
     "generate_compilation_statistics",
29 changes: 20 additions & 9 deletions .github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb
@@ -12,6 +12,7 @@
 "from nsys_jax import (\n",
 "    align_profiler_data_timestamps,\n",
 "    apply_warmup_heuristics,\n",
+"    default_data_prefix,\n",
 "    display_flamegraph,\n",
 "    ensure_compiled_protos_are_importable,\n",
 "    generate_compilation_statistics,\n",
@@ -23,6 +24,18 @@
 "import numpy as np"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "7a91f0e7-17da-4534-8ea9-29bcf3742567",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n",
+"# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n",
+"prefix = default_data_prefix()"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -32,7 +45,7 @@
 "source": [
 "# Make sure that the .proto files under protos/ have been compiled to .py, and\n",
 "# that those generated .py files are importable.]\n",
-"compiled_dir = ensure_compiled_protos_are_importable()"
+"compiled_dir = ensure_compiled_protos_are_importable(prefix=prefix)"
 ]
 },
 {
@@ -43,7 +56,7 @@
 "outputs": [],
 "source": [
 "# Load the runtime profile data\n",
-"all_data = load_profiler_data()\n",
+"all_data = load_profiler_data(prefix)\n",
 "# Remove some detail from the autotuner\n",
 "all_data = remove_autotuning_detail(all_data)\n",
 "# Align GPU timestamps across profiles collected by different Nsight Systems processes\n",
@@ -82,16 +95,14 @@
 "source": [
 "This data frame has a three-level index:\n",
 "- `ProgramId` is an integer ID that uniquely identifies the XLA module\n",
-"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
+"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 2, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
 "- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place\n",
 "\n",
 "The columns are as follows:\n",
 "- `Name`: the name of the XLA module; this should always be the same for a given `ProgramId`\n",
 "- `NumThunks`: the number of thunks executed inside this module execution\n",
 "- `ProjStartMs`: the timestamp of the start of the module execution on the GPU, in milliseconds\n",
 "- `ProjDurMs`: the duration of the module execution on the GPU, in milliseconds\n",
-"- `OrigStartMs`: the timestamp of the start of the module launch **on the host**, in milliseconds. *i.e.* `ProjStartMs-OrigStartMs` is something like the launch latency of the first kernel\n",
-"- `OrigDurMs`: the duration of the module launch **on the host**, in milliseconds\n",
 "- `LocalDevice`: the index within the node/slice of the GPU on which the module execution took place\n",
 "- `Process`: the global (across multiple nodes) index of the process\n",
 "- `Slice`: the global index of the node/slice; devices within the same node/slice should have faster interconnects than to devices in different slices\n",
@@ -117,13 +128,13 @@
 "id": "7727d800-13d3-4505-89e8-80a5fed63512",
 "metadata": {},
 "source": [
-"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `module_df`.\n",
+"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n",
 "The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n",
 "Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n",
 "\n",
 "The columns are as follows:\n",
 "- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n",
-"- `ProjStartMs`, `OrigStartMs`, `OrigDurMs`: see above, same meaning as in `module_df`.\n",
+"- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n",
 "- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n",
 "- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n",
 "\n",
@@ -299,7 +310,7 @@
 "# Print out the largest entries adding up to at least this fraction of the total\n",
 "threshold = 0.97\n",
 "compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n",
-"print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-9:.2f}s compilation time\")\n",
+"print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-3:.2f}s compilation time\")\n",
 "for row in compile_summary[\n",
 "    compile_summary[\"FracNonChild\"].cumsum() <= threshold\n",
 "].itertuples():\n",
@@ -378,7 +389,7 @@
 "    program_id, thunk_name = thunk_row.Index\n",
 "    # policy=\"all\" means we may get a set of HloProto instead of a single one, if\n",
 "    # nsys-jax-combine was used and the dumped metadata were not bitwise identical\n",
-"    hlo_modules = xla_module_metadata(program_id, policy=\"all\")\n",
+"    hlo_modules = xla_module_metadata(program_id, policy=\"all\", prefix=prefix)\n",
 "    thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(\n",
 "        lambda proto: instructions_and_frames(proto, thunk_name)\n",
 "    )\n",
34 changes: 19 additions & 15 deletions .github/container/nsys_jax/nsys_jax/analysis.py
@@ -286,17 +286,8 @@ def get_message_size(
     of the semantics. This implementation aims to follow the same conventions that NCCL
     uses in its NVTX payloads and tests.
     """
-    return pd.Series(
-        xla_module_metadata(program_id, prefix=prefix, policy="all").unique_result(
-            lambda proto: _get_message_size(proto, instruction)
-        ),
-        index=[
-            "MessageSize",
-            "Collective",
-            "CollectiveSize",
-            "BandwidthCorrection",
-            "BusBandwidthCorrection",
-        ],
-    )
+    return xla_module_metadata(program_id, prefix=prefix, policy="all").unique_result(
+        lambda proto: _get_message_size(proto, instruction)
     )
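`calculate_collective_metrics` below assembles the per-thunk tuples returned by this function into one frame with `pd.DataFrame.from_records`, rather than constructing a `pd.Series` per row. A standalone sketch of that pattern, with hypothetical values:

```python
import pandas as pd

# Hypothetical per-thunk records, in the column order used below.
records = [
    (4194304, "all-reduce", 8, 1.0, 1.75),
    (1048576, "all-gather", 8, 0.875, 0.875),
]
metrics_df = pd.DataFrame.from_records(
    records,
    columns=[
        "MessageSize",
        "Collective",
        "CollectiveSize",
        "BandwidthCorrection",
        "BusBandwidthCorrection",
    ],
)
print(metrics_df)
```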


@@ -311,13 +302,13 @@ def calculate_collective_metrics(
     comm_df = thunk_df[thunk_df["Communication"]].drop(columns=["Communication"])
     if len(comm_df) == 0:
         return comm_df
+
+    def body(tup):
+        idx, name = tup
+        return get_message_size(idx[0], name, prefix=prefix)
+
+    metrics_df = pd.DataFrame.from_records(
+        map(body, comm_df["Name"].items()),
+        columns=[
+            "MessageSize",
+            "Collective",
+            "CollectiveSize",
+            "BandwidthCorrection",
+            "BusBandwidthCorrection",
+        ],
+        index=comm_df.index,
+    )
     comm_df = pd.concat(
         [
             comm_df,
-            comm_df.apply(
-                lambda row: get_message_size(row.name[0], row.Name, prefix=prefix),
-                axis=1,
-            ),
+            metrics_df,
         ],
         axis=1,
     )
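A usage sketch tying these pieces together, following the notebook's calling convention; the attribute name on the loaded data (`.thunk`) is an assumption, not confirmed by this diff:

```python
from nsys_jax import (
    calculate_collective_metrics,
    default_data_prefix,
    load_profiler_data,
)

prefix = default_data_prefix()
all_data = load_profiler_data(prefix)
# Assumption: all_data exposes a thunk-level frame with a boolean
# "Communication" column; adjust to the actual data structure.
collective_metrics = calculate_collective_metrics(all_data.thunk, prefix=prefix)
```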