diff --git a/.github/container/nsys_jax/nsys_jax/__init__.py b/.github/container/nsys_jax/nsys_jax/__init__.py index e89395d8a..93fde24ee 100644 --- a/.github/container/nsys_jax/nsys_jax/__init__.py +++ b/.github/container/nsys_jax/nsys_jax/__init__.py @@ -7,7 +7,7 @@ from .data_loaders import load_profiler_data from .protobuf import xla_module_metadata from .protobuf_utils import compile_protos, ensure_compiled_protos_are_importable -from .utils import remove_autotuning_detail, remove_child_ranges +from .utils import default_data_prefix, remove_autotuning_detail, remove_child_ranges from .visualization import create_flamegraph, display_flamegraph __all__ = [ @@ -16,6 +16,7 @@ "calculate_collective_metrics", "compile_protos", "create_flamegraph", + "default_data_prefix", "display_flamegraph", "ensure_compiled_protos_are_importable", "generate_compilation_statistics", diff --git a/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb b/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb index 8159f2d28..14339d2d6 100644 --- a/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb +++ b/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb @@ -12,6 +12,7 @@ "from nsys_jax import (\n", " align_profiler_data_timestamps,\n", " apply_warmup_heuristics,\n", + " default_data_prefix,\n", " display_flamegraph,\n", " ensure_compiled_protos_are_importable,\n", " generate_compilation_statistics,\n", @@ -23,6 +24,18 @@ "import numpy as np" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a91f0e7-17da-4534-8ea9-29bcf3742567", + "metadata": {}, + "outputs": [], + "source": [ + "# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n", + "# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n", + "prefix = default_data_prefix()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -32,7 +45,7 @@ "source": [ "# Make sure that the .proto files under protos/ have been compiled to .py, and\n", "# that those generated .py files are importable.]\n", - "compiled_dir = ensure_compiled_protos_are_importable()" + "compiled_dir = ensure_compiled_protos_are_importable(prefix=prefix)" ] }, { @@ -43,7 +56,7 @@ "outputs": [], "source": [ "# Load the runtime profile data\n", - "all_data = load_profiler_data()\n", + "all_data = load_profiler_data(prefix)\n", "# Remove some detail from the autotuner\n", "all_data = remove_autotuning_detail(all_data)\n", "# Align GPU timestamps across profiles collected by different Nsight Systems processes\n", @@ -82,7 +95,7 @@ "source": [ "This data frame has a three-level index:\n", "- `ProgramId` is an integer ID that uniquely identifies the XLA module\n", - "- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n", + "- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 2, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n", "- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place\n", "\n", "The columns are as follows:\n", @@ -90,8 +103,6 @@ "- `NumThunks`: the number of thunks executed inside this module execution\n", "- `ProjStartMs`: the timestamp of the start of the module execution on the GPU, in milliseconds\n", "- `ProjDurMs`: the duration of the module execution on the GPU, in milliseconds\n", - "- `OrigStartMs`: the timestamp of the start of the module launch **on the host**, in milliseconds. *i.e.* `ProjStartMs-OrigStartMs` is something like the launch latency of the first kernel\n", - "- `OrigDurMs`: the duration of the module launch **on the host**, in milliseconds\n", "- `LocalDevice`: the index within the node/slice of the GPU on which the module execution took place\n", "- `Process`: the global (across multiple nodes) index of the process\n", "- `Slice`: the global index of the node/slice; devices within the same node/slice should have faster interconnects than to devices in different slices\n", @@ -117,13 +128,13 @@ "id": "7727d800-13d3-4505-89e8-80a5fed63512", "metadata": {}, "source": [ - "Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `module_df`.\n", + "Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n", "The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n", "Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n", "\n", "The columns are as follows:\n", "- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n", - "- `ProjStartMs`, `OrigStartMs`, `OrigDurMs`: see above, same meaning as in `module_df`.\n", + "- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n", "- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n", "- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n", "\n", @@ -299,7 +310,7 @@ "# Print out the largest entries adding up to at least this fraction of the total\n", "threshold = 0.97\n", "compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n", - "print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-9:.2f}s compilation time\")\n", + "print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-3:.2f}s compilation time\")\n", "for row in compile_summary[\n", " compile_summary[\"FracNonChild\"].cumsum() <= threshold\n", "].itertuples():\n", @@ -378,7 +389,7 @@ " program_id, thunk_name = thunk_row.Index\n", " # policy=\"all\" means we may get a set of HloProto instead of a single one, if\n", " # nsys-jax-combine was used and the dumped metadata were not bitwise identical\n", - " hlo_modules = xla_module_metadata(program_id, policy=\"all\")\n", + " hlo_modules = xla_module_metadata(program_id, policy=\"all\", prefix=prefix)\n", " thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(\n", " lambda proto: instructions_and_frames(proto, thunk_name)\n", " )\n", diff --git a/.github/container/nsys_jax/nsys_jax/analysis.py b/.github/container/nsys_jax/nsys_jax/analysis.py index c4e37fdf9..12c4e0fb7 100644 --- a/.github/container/nsys_jax/nsys_jax/analysis.py +++ b/.github/container/nsys_jax/nsys_jax/analysis.py @@ -286,17 +286,8 @@ def get_message_size( of the semantics. This implementation aims to follow the same conventions that NCCL uses in its NVTX payloads and tests. """ - return pd.Series( - xla_module_metadata(program_id, prefix=prefix, policy="all").unique_result( - lambda proto: _get_message_size(proto, instruction) - ), - index=[ - "MessageSize", - "Collective", - "CollectiveSize", - "BandwidthCorrection", - "BusBandwidthCorrection", - ], + return xla_module_metadata(program_id, prefix=prefix, policy="all").unique_result( + lambda proto: _get_message_size(proto, instruction) ) @@ -311,13 +302,26 @@ def calculate_collective_metrics( comm_df = thunk_df[thunk_df["Communication"]].drop(columns=["Communication"]) if len(comm_df) == 0: return comm_df + + def body(tup): + idx, name = tup + return get_message_size(idx[0], name, prefix=prefix) + + metrics_df = pd.DataFrame.from_records( + map(body, comm_df["Name"].items()), + columns=[ + "MessageSize", + "Collective", + "CollectiveSize", + "BandwidthCorrection", + "BusBandwidthCorrection", + ], + index=comm_df.index, + ) comm_df = pd.concat( [ comm_df, - comm_df.apply( - lambda row: get_message_size(row.name[0], row.Name, prefix=prefix), - axis=1, - ), + metrics_df, ], axis=1, ) diff --git a/.github/container/nsys_jax/nsys_jax/data_loaders.py b/.github/container/nsys_jax/nsys_jax/data_loaders.py index 608f5a5d6..57e52403d 100644 --- a/.github/container/nsys_jax/nsys_jax/data_loaders.py +++ b/.github/container/nsys_jax/nsys_jax/data_loaders.py @@ -35,72 +35,83 @@ def _is_communication( raise +def _find_overlapped(start: pd.Series, end: pd.Series) -> pd.Index: + """ + Given a start/end series representing a set of possibly-overlapping ranges + start = [s0, s1, ...] + end = [e0, e1, ...] + which might overlap like: + [s0 e0] + [s1 e1] + [s2 e2] + [s3 e3] + [s4 e4] + return the index values of ranges that overlap with other ranges, i.e. in the + example (1, 2, 3) but not (0, 4). + """ + n = len(start) + assert n == len(end), (n, len(end)) + # Earliest start value of a later row, +inf for the last entry (which doesn't have any later rows) + next_start = np.full((n,), float("+inf")) + next_start[:-1] = start[::-1].cummin()[-2::-1] # reverse + drop 0th + # Latest end value of an earlier thunk, -inf for the first entry + prev_end = np.full((n,), float("-inf")) + prev_end[1:] = end.cummax()[:-1] + # Find rows that have overlap + mask = (next_start < end) | (prev_end > start) + return mask[mask].index + + def _calculate_overlap(thunk_df: pd.DataFrame) -> pd.DataFrame: thunk_df["ProjDurHiddenMs"] = 0.0 # For convenience when calculating unhidden comms thunk_df["ProjEndMs"] = thunk_df["ProjStartMs"] + thunk_df["ProjDurMs"] - for (program_id, device), module_exec_df in thunk_df.groupby( - ["ProgramId", "Device"] - ): - # If *everything* is serialised, communication *and* computation, there's no - # work to be done - serial_mask = ( - module_exec_df["ProjStartMs"].array[1:] - >= module_exec_df["ProjEndMs"].array[:-1] + for _, module_exec_df in thunk_df.groupby(["ProgramId", "Device"]): + # Identify overlap points that need more investigation + overlap_ids = _find_overlapped( + module_exec_df["ProjStartMs"], module_exec_df["ProjEndMs"] ) - if serial_mask.all(): + if len(overlap_ids) == 0: continue - # At least expect all computation to be serialized - comm_df = module_exec_df[module_exec_df["Communication"]] - compute_df = module_exec_df[~module_exec_df["Communication"]] - serial_mask = ( - compute_df["ProjStartMs"].array[1:] >= compute_df["ProjEndMs"].array[:-1] - ) - assert serial_mask.all(), ( - f"Only {serial_mask.sum()}/{len(serial_mask)} compute kernel pairs failed to overlap on device {device} and program #{program_id}" - ) - # Update the projected duration of each communication kernel to only - # include the non-overlapped time - for comm_thunk in comm_df.itertuples(): - # This is a range annotating a communication operation, i.e. NCCL kernel - # That kernel was active from comm_thunk.ProjStartMs until comm_thunk.ProjEndMs - # but during that time then other computation was going on. We want to - # find how much of the time did not overlap with other computation. - overlap_df = compute_df.loc[ + # All overlapping thunks in `module_exec_df` + overlap_df = module_exec_df.loc[ + overlap_ids, ("Communication", "ProjStartMs", "ProjEndMs") + ] + # Just the subset that are communication + compute_df = overlap_df[~overlap_df["Communication"]] + # Narrow down to overlapped communication thunks + for comm_thunk in overlap_df.loc[ + overlap_df["Communication"], ("ProjStartMs", "ProjEndMs") + ].itertuples(): + local_df = compute_df.loc[ (compute_df["ProjEndMs"] > comm_thunk.ProjStartMs) - & (compute_df["ProjStartMs"] < comm_thunk.ProjEndMs), - ("ProjStartMs", "ProjEndMs"), + & (compute_df["ProjStartMs"] < comm_thunk.ProjEndMs) ] compute_time = np.sum( - np.minimum(overlap_df["ProjEndMs"], comm_thunk.ProjEndMs) - - np.maximum(overlap_df["ProjStartMs"], comm_thunk.ProjStartMs) + np.minimum(local_df["ProjEndMs"], comm_thunk.ProjEndMs) + - np.maximum(local_df["ProjStartMs"], comm_thunk.ProjStartMs) ) # Update the projected duration of communication kernels to just be the # time that is not hidden. thunk_df.loc[comm_thunk.Index, "ProjDurMs"] -= compute_time thunk_df.loc[comm_thunk.Index, "ProjDurHiddenMs"] = compute_time - return thunk_df.drop(columns=["ProjEndMs"]) def _classify_comms(thunk_df: pd.DataFrame, prefix: pathlib.Path) -> pd.DataFrame: # Classify each thunk as either communication or computation, as we only # want to attribute non-overlapped communication time to those operations. - # Use HloInstructionProto.channel_id as a proxy for whether an operation is - # communication. - def is_communication(row): - assert thunk_df.index.names == [ - "ProgramId", - "ProgramExecution", - "ThunkIndex", - "Device", - ] + assert thunk_df.index.names[0] == "ProgramId" + + def is_communication(tup): + idx, name = tup return _is_communication( - program_id=row.name[0], prefix=prefix, instruction_name=row["Name"] + program_id=idx[0], prefix=prefix, instruction_name=name ) - thunk_df["Communication"] = thunk_df.loc[:, ("Name",)].apply( - is_communication, axis=1 + thunk_df["Communication"] = pd.Series( + data=map(is_communication, thunk_df["Name"].items()), + index=thunk_df.index, ) return _calculate_overlap(thunk_df) @@ -108,14 +119,19 @@ def is_communication(row): compile_prefix = "XlaCompile:#module=" +def _load_parquet_file(file: pathlib.Path) -> pd.DataFrame: + # Separate function to make profiles of this Python code easier to read + return pd.read_parquet(file) + + def _load_nvtx_gpu_proj_trace_single( prefix: pathlib.Path, file: pathlib.Path, meta_file: pathlib.Path, frames: set[str], -): +) -> dict[str, pd.DataFrame]: # Load the thread metadata used to map module/thunk executions to global device IDs - meta_df = pd.read_parquet(meta_file) + meta_df = _load_parquet_file(meta_file) # Match XLA's launcher thread name. These threads launch work if >1 GPU is being # driven by the process. device_by_pid_tid = ( @@ -127,7 +143,7 @@ def _load_nvtx_gpu_proj_trace_single( .astype(np.int32) ) # Load input data; rename some columns for convenience with `.itertuples()`; use RangeId as the index - df = pd.read_parquet(file).drop(columns=["Rank"]) + df = _load_parquet_file(file).drop(columns=["Rank"]) # Alternative trace.parquet format alt_rename_map = { "Text": "Name", @@ -165,8 +181,8 @@ def _load_nvtx_gpu_proj_trace_single( except ValueError: print( "A duplicate key related error may indicate that you are using " - "Nsight Systems 2024.5 and have CUDA graphs enabled; as noted on " - "https://github.com/NVIDIA/JAX-Toolbox/blob/main/docs/profiling.md " + "Nsight Systems 2024.5 or 2024.6 and have CUDA graphs enabled; as noted " + "on https://github.com/NVIDIA/JAX-Toolbox/blob/main/docs/profiling.md " "you may want to disable CUDA graphs by adding " "--xla_gpu_enable_command_buffer= to the XLA_FLAGS environment " "variable." @@ -355,14 +371,17 @@ def clean_data_frame(d): output = {} if "thunk" in frames: - # At this point there should be no need to look beyond the rows for individual thunks + the protobuf data, and we can further clean up the data + # At this point there should be no need to look beyond the rows for individual + # thunks + the protobuf data, and we can further clean up the data. thunk_df = clean_data_frame(df[all_thunks]) - thunk_df["Name"] = thunk_df["Name"].replace( - to_replace=f"^{tsl_prefix}Thunk:#(?:name=(.*?),|)hlo_op=([a-z0-9._-]+)#$", - value=r"\2", + thunk_df["Name"] = thunk_df["Name"].str.replace( + pat=f"^{tsl_prefix}Thunk:#(?:name=.*?,|)hlo_op=([a-z0-9._-]+)#$", + n=1, + repl=lambda m: m.group(1), regex=True, ) # Add an index of the thunk within the module + # TODO: the ordering is potentially inconsistent across module executions in case of multiple streams/overlap thunk_df["ThunkIndex"] = thunk_df.groupby( ["ProgramId", "ProgramExecution", "Device"] ).cumcount() @@ -375,7 +394,6 @@ def clean_data_frame(d): ).cumcount() # Classify thunks as communication/computation and save to output - # TODO: instead of using this sorting here, use a sort bucketed by execution time to speed up overlap detection? output["thunk"] = _classify_comms( thunk_df.set_index( ["ProgramId", "ProgramExecution", "ThunkIndex", "Device"] @@ -422,22 +440,28 @@ def _load_nvtx_gpu_proj_trace( filenames = [path] meta_filenames = [meta_path] - tmp = defaultdict(list) - with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool: - for single_trace in pool.starmap( - _load_nvtx_gpu_proj_trace_single, - zip( - itertools.repeat(prefix), - filenames, - meta_filenames, - itertools.repeat(frames), - ), - ): - for k, v in single_trace.items(): - tmp[k].append(v) - output = {} - for k, v in tmp.items(): - output[k] = pd.concat(v, verify_integrity=True).sort_index() + if len(filenames) > 1: + tmp = defaultdict(list) + with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool: + for single_trace in pool.starmap( + _load_nvtx_gpu_proj_trace_single, + zip( + itertools.repeat(prefix), + filenames, + meta_filenames, + itertools.repeat(frames), + ), + ): + for k, v in single_trace.items(): + tmp[k].append(v) + output = {} + for k, v in tmp.items(): + output[k] = pd.concat(v, verify_integrity=True).sort_index() + else: + output = _load_nvtx_gpu_proj_trace_single( + prefix, filenames[0], meta_filenames[0], frames + ) + output = {k: v.sort_index() for k, v in output.items()} return output @@ -568,16 +592,30 @@ def _drop_non_tsl(compile_df: pd.DataFrame) -> pd.DataFrame: return compile_df[tsl_mask] +def _read_nvtx_pushpop_trace_file(file: pathlib.Path) -> pd.DataFrame: + # `file` follows one of two patterns, depending on whether we are loading the + # results from a single profile or from multiple merged profiles: + # - nsys-jax: /path/to/report_nvtx_pushpop_trace.parquet + # - nsys-jax-combine: /path/to/report_nvtx_pushpop_trace.parquet/rank5 + new_name = "report_nvtx_pushpop_trace.parquet" + if file.name == new_name or file.parent.name == new_name: + # New mode; the .csv to .parquet conversion is done in nsys-jax + return pd.read_parquet(file) + else: + + def keep_column(name): + return name not in {"PID", "Lvl", "NameTree"} + + return pd.read_csv( + lzma.open(file, "rt", newline=""), + dtype={"RangeId": np.int32}, + index_col="RangeId", + usecols=keep_column, + ) + + def _load_nvtx_pushpop_trace_single(name: pathlib.Path) -> pd.DataFrame: - def keep_column(name): - return name not in {"PID", "Lvl", "NameTree"} - - compile_df = pd.read_csv( - lzma.open(name, "rt", newline=""), - dtype={"RangeId": np.int32}, - index_col="RangeId", - usecols=keep_column, - ) + compile_df = _read_nvtx_pushpop_trace_file(name) compile_df["StartMs"] = 1e-6 * compile_df.pop("Start (ns)") compile_df["EndMs"] = 1e-6 * compile_df.pop("End (ns)") compile_df["DurMs"] = 1e-6 * compile_df.pop("Duration (ns)") @@ -596,23 +634,25 @@ def keep_column(name): # Because the ProgramId and ProgramName ranges provide the same information, # remove those fields from the compilation range names. def remove_program_id_and_name(row): - row.Name = ( + return ( row.Name.removeprefix("TSL:") .replace(f",program_id={row.ProgramId}", "") .replace(f",module={row.ProgramName}", "") .replace(f":#module={row.ProgramName}#", "") ) - return row - return ( - compile_df.drop(columns=["EndMs"]) - .astype({"ProgramId": np.int32}) - .transform(remove_program_id_and_name, axis="columns") - ) + compile_df = compile_df.drop(columns=["EndMs"]).astype({"ProgramId": np.int32}) + if len(compile_df): + compile_df["Name"] = compile_df.apply( + remove_program_id_and_name, axis="columns" + ) + return compile_df def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataFrame: - path = prefix / "report_nvtx_pushpop_trace.csv.xz" + new_path = prefix / "report_nvtx_pushpop_trace.parquet" + legacy_path = prefix / "report_nvtx_pushpop_trace.csv.xz" + path = new_path if new_path.exists() else legacy_path if path.is_dir(): # We're looking at the output of nsys-jax-combine filenames = sorted(path.iterdir()) @@ -622,12 +662,16 @@ def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataF filenames = [path] keys = [prefix.name] - with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool: - return pd.concat( - pool.map(_load_nvtx_pushpop_trace_single, filenames), - keys=keys, - names=["ProfileName", "RangeId"], - ) + if len(filenames) > 1: + with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool: + chunks = pool.map(_load_nvtx_pushpop_trace_single, filenames) + else: + chunks = [_load_nvtx_pushpop_trace_single(filenames[0])] + return pd.concat( + chunks, + keys=keys, + names=["ProfileName", "RangeId"], + ) def load_profiler_data( diff --git a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py index 136843999..6154d36f8 100644 --- a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py +++ b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from glob import glob, iglob import lzma +import numpy as np import os import os.path as osp import pandas as pd # type: ignore @@ -369,7 +370,9 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue): if osp.isdir(full_path) or not osp.exists(full_path): continue output_queue.put((ofile, full_path, COMPRESS_NONE)) - print(f"{archive_name}: post-processing finished in {time.time() - start:.2f}s") + print( + f"{archive_name}: recipe post-processing finished in {time.time() - start:.2f}s" + ) def compress_and_archive(prefix, file, output_queue): """ @@ -401,9 +404,29 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue): ], check=True, ) - for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir): - compress_and_archive(tmp_dir, ofile, output_queue) - print(f"{archive_name}: post-processing finished in {time.time() - start:.2f}s") + output_path = osp.join(tmp_dir, f"report_{report}.csv") + + # TODO: avoid the .csv indirection + def keep_column(name): + return name not in {"PID", "Lvl", "NameTree"} + + try: + df = pd.read_csv( + output_path, + dtype={"RangeId": np.int32}, + index_col="RangeId", + usecols=keep_column, + ) + parquet_name = f"report_{report}.parquet" + parquet_path = osp.join(tmp_dir, parquet_name) + df.to_parquet(parquet_path) + output_queue.put((parquet_name, parquet_path, COMPRESS_NONE)) + except pd.errors.EmptyDataError: + # If there's no data, don't write a file to the output at all + pass + print( + f"{archive_name}: stats post-processing finished in {time.time() - start:.2f}s" + ) def save_device_stream_thread_names(tmp_dir, report, output_queue): """ diff --git a/.github/container/nsys_jax/nsys_jax/visualization.py b/.github/container/nsys_jax/nsys_jax/visualization.py index 99530b3c7..5a5fc52fc 100644 --- a/.github/container/nsys_jax/nsys_jax/visualization.py +++ b/.github/container/nsys_jax/nsys_jax/visualization.py @@ -1,6 +1,5 @@ -from IPython.display import display, IFrame +from typing import Any, Iterable import subprocess -from typing import Iterable import xml.etree.ElementTree from .protobuf_utils import which @@ -8,7 +7,7 @@ def create_flamegraph( data: dict[Iterable[str], float], title: str, filename: str, width: int = 1200 -) -> tuple[str, IFrame]: +) -> tuple[str, Any]: """ Given a data structure of the form { @@ -26,6 +25,8 @@ def create_flamegraph( Returns a tuple (svg_data, InlineIFrame), where the latter can be passed to IPython.display.display(...) to be rendered inline in a Jupyter notebook. """ + from IPython.display import IFrame + flat_data = "" for loc, value in data.items(): assert not any(";" in x for x in loc) @@ -50,5 +51,7 @@ def create_flamegraph( def display_flamegraph(**kwargs): + from IPython.display import display + svg, iframe = create_flamegraph(**kwargs) display(iframe) diff --git a/.github/container/nsys_jax/tests/test_df_helpers.py b/.github/container/nsys_jax/tests/test_df_helpers.py new file mode 100644 index 000000000..704c5db46 --- /dev/null +++ b/.github/container/nsys_jax/tests/test_df_helpers.py @@ -0,0 +1,25 @@ +from nsys_jax.data_loaders import _find_overlapped as find_overlapped +import pandas as pd # type: ignore +import pytest # type: ignore + + +@pytest.mark.parametrize( + "records,expected", + [ + # no overlap + ([], []), + [[(0, 1)], []], + [[(0, 1), (1, 2)], []], + # overlap + [[(0, 1), (0.5, 1.5)], [0, 1]], + ([(0, 1), (0.5, 1.5), (8, 9), (9, 10)], [0, 1]), + ([(0, 1), (2, 3), (2.5, 3.5), (3, 4), (5, 6)], [1, 2, 3]), + ([(0, 1), (0.5, 1.5), (2, 3), (4.5, 5.5), (5, 6)], [0, 1, 3, 4]), + # overlap between non-neighbouring ranges + ([(0, 3), (0.1, 1), (2, 4), (5, 6)], [0, 1, 2]), # 0-1, 0-2 overlap but not 1-2 + ], +) +def test_find_overlapped(records, expected): + df = pd.DataFrame.from_records(records, columns=["start", "end"]) + result = find_overlapped(df["start"], df["end"]) + assert list(result) == expected diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index b4f3b8143..7ad146ccd 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -322,7 +322,9 @@ jobs: set -o pipefail num_tests=0 num_failures=0 - # Run the pytest-driven tests + # Run the pytest-driven tests; failure is explicitly handled below so set +e to + # avoid an early abort here. + set +e docker run -i --shm-size=1g --gpus all \ -v $PWD:/opt/output \ ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \ @@ -333,6 +335,7 @@ jobs: test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())') pytest --report-log=/opt/output/pytest-report.jsonl "${test_path}" EOF + set -e GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU') for mode in 1-process 2-process process-per-gpu; do DOCKER="docker run --shm-size=1g --gpus all --env XLA_FLAGS=--xla_gpu_enable_command_buffer= --env XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 -v ${PWD}:/opt/output ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}"