Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nsys-jax: optimise data loading and .zip creation #1193

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

olupton
Copy link
Collaborator

@olupton olupton commented Dec 10, 2024

Some rough measurements on vanilla jax-nccl-test and 8xH100:

Profile collection, whole execution: 52s (nsys), 58s (nsys-jax with this PR), 1m5s (nsys-jax without this PR)
Profile collection, restricted range: 46s (nsys), 50s (nsys-jax with this PR), 55s (nsys-jax without this PR)
Communication analysis, whole execution: 1.1s (with this PR), 2.1s (without this PR)
Communication analysis, restricted range: 1.0s (with this PR), 1.7s (without this PR)

The differences are more pronounced on larger workloads with more activity.

The two bigger changes are:

  • Convert .csv to .parquet as part of nsys-jax to avoid compressing .csv with Python's lzma module, which is slow and single-threaded. This speeds up nsys-jax and subsequent data-loading.
  • A new algorithm for calculating the hidden/exposed time of communication kernels when loading profile data -- essentially this adds a fast pandas-friendly pass to identify [most] non-overlapping kernels and skip running the [relatively slow and pandas-unfriendly] overlap calculation on them. This also removes an assumption that there is no compute-compute overlap.

Otherwise there are some tweaks to pandas usage and minor reorganisations to make Python profiles more informative, and minor bugfixes in the example Jupyter notebook.

@olupton olupton force-pushed the olupton/nsys-jax-python-opt branch 6 times, most recently from 74a3d94 to d8056e0 Compare December 10, 2024 15:58
@olupton olupton requested a review from gspschmid December 11, 2024 09:56
@olupton olupton force-pushed the olupton/nsys-jax-python-opt branch from 61357e9 to 4bf5946 Compare January 10, 2025 11:09
Copy link
Contributor

@gspschmid gspschmid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, a few questions inline.

.transform(remove_program_id_and_name, axis="columns")
)
compile_df = compile_df.drop(columns=["EndMs"]).astype({"ProgramId": np.int32})
if len(compile_df):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks when compile_df is empty?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes...for non-zero len(compile_df) the return value is a Series of that length, but for zero-length it changes to being an empty DataFrame. There is a result_type argument that can be used to get the right behaviour, but I thought it was less clear (from the docs it seems like it might be accidental that it helps anyway).

thunk_df["Communication"] = thunk_df.loc[:, ("Name",)].apply(
is_communication, axis=1
thunk_df["Communication"] = pd.Series(
data=map(is_communication, thunk_df["Name"].items()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious why map(f, items) is preferable to df.apply(f) here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have the profile data to hand anymore, but IIRC the apply approach had non-negligible overhead constructing the temporary passed to the mapped function 🤔

for comm_thunk in overlap_df.loc[
overlap_df["Communication"], ("ProjStartMs", "ProjEndMs")
].itertuples():
local_df = compute_df.loc[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could call this "fully overlapped by"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure quite what you mean by "fully" there? local_df contains all of the compute thunks whose execution at least partially overlaps with comm_thunk.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants