-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
nsys-jax: re-work to be more pip-install-able (#1165)
The overarching goal of this PR is to get closer to a world where the `nsys-jax` tooling is straightforwardly `pip install`-able. While the diff looks scary, it's mostly re-organisation. Substantive changes: - `nsys-jax` no longer bundles Python code in the output archives, the `install.sh` script provided for users to run on local machines becomes, loosely, `install 'pip nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@COMMIT#subdirectory=.github/container/nsys_jax'`, where `COMMIT` corresponds to the `nsys-jax` command that produced the archive. For the `ghcr.io/nvidia/jax` containers, this is the commit of JAX-Toolbox that triggered the container build. Changes included: - Introduce `/opt/pip-tools-post-install.d`, which `pip-finalize.sh` will execute the contents of *after* installing the `pip`-managed world - Migrate `install-protoc` to use this, so `pip-finalize.sh` can forget about that detail. - Install https://github.com/brendangregg/FlameGraph/blob/master/flamegraph.pl via this. - Patch the `nvtx_gpu_proj_trace` Python code in Nsight Systems 2024.5 and 2024.6 via this. - Move `nsys-jax` installation (specifically for the containers) into `install-nsys-jax.sh` and thereby clean up `install-nsight.sh`. The new script has to be told the git commit hash of JAX-Toolbox that is being built, because `nsys-jax` bakes this into an installation script in its output `.zip` archives to ensure the local environment matches the profile-collection environment. - The CLI tools like `nsys-jax`, `nsys-jax-combine` and `install-protoc` are now handled via `[project.scripts]` in `pyproject.toml` instead of being standalone Python scripts. This is "more standard", and also makes it easier to share code between `nsys-jax` and `nsys-jax-combine`. - The Python library is renamed from `jax_nsys` to `nsys_jax` for consistency. - It's now possible to set the default data loading path via the `NSYS_JAX_DEFAULT_PREFIX` environment variable; previously the default was the current working directory, but that can be inconvenient to steer in Jupyter environments.
- Loading branch information
Showing
36 changed files
with
1,497 additions
and
1,283 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#!/bin/bash | ||
set -exo pipefail | ||
|
||
REF="$1" | ||
if [[ -z "${REF}" ]]; then | ||
echo "$0: <git ref of JAX-Toolbox>" | ||
exit 1 | ||
fi | ||
|
||
# Install extra dependencies needed for `nsys recipe ...` commands. These are | ||
# used by the nsys-jax wrapper script. | ||
NSYS_DIR=$(dirname $(realpath $(command -v nsys))) | ||
ln -s ${NSYS_DIR}/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in | ||
|
||
# Install the nsys-jax package, which includes nsys-jax, nsys-jax-combine, | ||
# install-protoc (called from pip-finalize.sh), and nsys-jax-patch-nsys as well as the | ||
# nsys_jax Python library. | ||
URL="git+https://github.com/NVIDIA/JAX-Toolbox.git@${REF}#subdirectory=.github/container/nsys_jax&egg=nsys-jax" | ||
echo "-e '${URL}'" > /opt/pip-tools.d/requirements-nsys-jax.in | ||
|
||
# protobuf will be installed at least as a dependency of nsys_jax in the base | ||
# image, but the installed version is likely to be influenced by other packages. | ||
echo "install-protoc /usr/local" > /opt/pip-tools-post-install.d/protoc | ||
chmod 755 /opt/pip-tools-post-install.d/protoc | ||
|
||
# Make sure flamegraph.pl is available | ||
echo "install-flamegraph /usr/local" > /opt/pip-tools-post-install.d/flamegraph | ||
chmod 755 /opt/pip-tools-post-install.d/flamegraph | ||
|
||
# Make sure Nsight Systems Python patches are installed if needed | ||
echo "nsys-jax-patch-nsys" > /opt/pip-tools-post-install.d/patch-nsys | ||
chmod 755 /opt/pip-tools-post-install.d/patch-nsys |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.