Customize NCCL for base container #1123

Open
wants to merge 1 commit into
base: main
12 changes: 12 additions & 0 deletions .github/container/Dockerfile.base
@@ -29,6 +29,18 @@ FROM ${BASE_IMAGE}
 ARG GIT_USER_EMAIL
 ARG GIT_USER_NAME
 ARG CLANG_VERSION
+ARG JAX_NCCL_VERSION
Collaborator

Where would a default be specified? I think the situation today is that we want the default nightly/CI behaviour to be "12.6.1 base image with NCCL 2.23 installed on top", so it should be possible to express that.

I know that for the base image there is already special treatment of `latest` to enable this.

Collaborator

So the "special treatment" is to embed a default in the Dockerfile, and then conditionally override it via an Actions conditional expression:

${{ inputs.BASE_IMAGE != 'latest' && format('BASE_IMAGE={0}', inputs.BASE_IMAGE) || '' }}
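
For illustration, a minimal shell sketch of what that expression evaluates to. The function name `base_image_build_arg` and the image tag in the example are hypothetical and not part of this PR; the sketch assumes that an empty build-arg line leaves the `ARG BASE_IMAGE` default in the Dockerfile in effect.

```shell
#!/bin/sh
# Hypothetical helper (not in the PR): prints the build-arg line that the
# Actions expression would produce for a given BASE_IMAGE input.
base_image_build_arg() {
  input="$1"
  if [ "$input" != "latest" ]; then
    # Non-default input: emit an explicit override.
    printf 'BASE_IMAGE=%s\n' "$input"
  fi
  # For "latest" nothing is printed, so the Dockerfile default applies.
}

base_image_build_arg "latest"                    # prints nothing
base_image_build_arg "my-registry/cuda:12.6.1"   # made-up tag; prints BASE_IMAGE=my-registry/cuda:12.6.1
```

The `cond && value || ''` form works here because `value` is never empty when the condition holds; with an empty `value` the expression would fall through to the `||` branch.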

+ARG JAX_LIBNCCL_PACKAGE
+
+###############################################################################
+## Update NCCL version env variables
+###############################################################################
+
+ENV NV_LIBNCCL_DEV_PACKAGE=${NV_LIBNCCL_DEV_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE}
+ENV NV_LIBNCCL_DEV_PACKAGE_VERSION=${JAX_NCCL_VERSION}
+ENV NCCL_VERSION=${JAX_NCCL_VERSION}
+ENV NV_LIBNCCL_PACKAGE=${NV_LIBNCCL_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE}
+ENV NV_LIBNCCL_PACKAGE_VERSION=${JAX_NCCL_VERSION}

 ###############################################################################
 ## Install Python and essential tools
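
To make the effect of the new `ENV` lines concrete, here is a small shell sketch that performs the same string composition. All input values are assumptions for illustration: `libnccl2` and `libnccl-dev` are assumed to be what the base image exports in `NV_LIBNCCL_PACKAGE_NAME` and `NV_LIBNCCL_DEV_PACKAGE_NAME`, and `2.19.3-1+cuda12.3` is the example format from the workflow input description.

```shell
#!/bin/sh
# Assumed inputs, for illustration only.
NV_LIBNCCL_PACKAGE_NAME=libnccl2
NV_LIBNCCL_DEV_PACKAGE_NAME=libnccl-dev
JAX_LIBNCCL_PACKAGE='2.19.3-1+cuda12.3'
JAX_NCCL_VERSION='2.19.3-1'

# Same composition as the ENV lines in Dockerfile.base.
NV_LIBNCCL_PACKAGE="${NV_LIBNCCL_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE}"
NV_LIBNCCL_DEV_PACKAGE="${NV_LIBNCCL_DEV_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE}"
NCCL_VERSION="${JAX_NCCL_VERSION}"

echo "NV_LIBNCCL_PACKAGE=${NV_LIBNCCL_PACKAGE}"
echo "NV_LIBNCCL_DEV_PACKAGE=${NV_LIBNCCL_DEV_PACKAGE}"
echo "NCCL_VERSION=${NCCL_VERSION}"
```

If `JAX_LIBNCCL_PACKAGE` were empty, the same composition would yield `libnccl2=`, so the sketch assumes both build args are always populated by the caller.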
24 changes: 18 additions & 6 deletions .github/container/install-nccl.sh
@@ -5,10 +5,19 @@ set -ex -o pipefail
 export DEBIAN_FRONTEND=noninteractive
 export TZ=America/Los_Angeles

-# If NCCL is already installed, don't reinstall it. Print a message and exit
-if dpkg -s libnccl2 libnccl-dev &> /dev/null; then
-  echo "NCCL is already installed. Skipping installation."
+# Try to get NCCL_VERSION of installed libnccl-dev
+if [[ -z $NCCL_VERSION ]]; then
+  NCCL_VERSION=$(dpkg -s libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1 | tr "+" "\n" | head -1)
+fi
+
+# Skip NCCL installation if both JAX_NCCL_VERSION (user defined) and
+# NCCL_VERSION (defined in nvidia/cuda containers) are unset.
+# This case means that the base container is built from a custom image with
+# a custom network communicator or unset NCCL_VERSION env variable.
+if [[ -z $JAX_NCCL_VERSION && -z $NCCL_VERSION ]]; then
+  echo "Skip NCCL installation"
 else
+  JAX_NCCL_VERSION=${JAX_NCCL_VERSION:-$NCCL_VERSION}
   apt-get update

   # Extract CUDA version from `nvcc --version` output line
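
The version-resolution logic in this hunk can be sketched as two small functions. This is an illustrative sketch, not the script itself: `installed_nccl_version` and `resolve_nccl_version` are hypothetical names, the `dpkg -s` text is passed on stdin instead of being queried, and the CUDA version is taken as an explicit argument (the sketch assumes it is already known at that point).

```shell
#!/bin/sh
# Reads `dpkg -s`-style text on stdin and prints the package version with its
# "+cuda..." suffix removed, or nothing if no matching Version line is found.
installed_nccl_version() {
  cuda_version="$1"
  sed -n "s/^Version: \(.*+cuda${cuda_version}\)\$/\1/p" | head -n 1 | tr '+' '\n' | head -n 1
}

# Prints the NCCL version to install, or nothing when installation is skipped.
resolve_nccl_version() {
  jax_nccl_version="$1"   # user-defined override, may be empty
  nccl_version="$2"       # value inherited from the base image, may be empty
  if [ -z "$jax_nccl_version" ] && [ -z "$nccl_version" ]; then
    return 0              # both unset: skip installation
  fi
  printf '%s\n' "${jax_nccl_version:-$nccl_version}"
}

# Made-up sample input for illustration.
printf 'Package: libnccl-dev\nVersion: 2.19.3-1+cuda12.3\n' | installed_nccl_version 12.3
resolve_nccl_version "" "2.19.3-1"
```

The user-defined value wins whenever it is set; the inherited value is only a fallback.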
@@ -18,21 +27,24 @@ else

   # Find latest NCCL version compatible with existing CUDA by matching
   # ${cuda_version} in the package version string
-  libnccl2_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
-  libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
+  libnccl2_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(${JAX_NCCL_VERSION}.*+cuda.*\)$/\1/p" | head -n 1)
+  libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(${JAX_NCCL_VERSION}.*+cuda.*\)$/\1/p" | head -n 1)
   if [[ -z "${libnccl2_version}" || -z "${libnccl_dev_version}" ]]; then
     echo "Could not find compatible NCCL version for CUDA ${cuda_version}"
     exit 1
   fi

-  apt-get install -y \
+  apt-get install -y --allow-change-held-packages \
     libnccl2=${libnccl2_version} \
     libnccl-dev=${libnccl_dev_version}

   apt-get clean
   rm -rf /var/lib/apt/lists/*
 fi

+# Smoke test of installed NCCL packages
+dpkg -s libnccl2 libnccl-dev
+
 # Create a prefix with include/ and lib/ directories containing symlinks to the NCCL
 # version installed at the system level; this is useful to pass to XLA to avoid it
 # fetching its own copy.
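
As a rough illustration of the new version match, the sketch below applies the same `sed` pattern to canned text instead of real `apt-cache show libnccl-dev` output. The function name `pick_nccl_package_version` and the sample package listing are made up for this example.

```shell
#!/bin/sh
# Hypothetical helper: reads `apt-cache show`-style text on stdin and prints the
# first package version that starts with the requested NCCL version.
pick_nccl_package_version() {
  wanted="$1"
  sed -n "s/^Version: \(${wanted}.*+cuda.*\)\$/\1/p" | head -n 1
}

# Made-up sample listing for illustration.
sample='Package: libnccl-dev
Version: 2.23.4-1+cuda12.6
Package: libnccl-dev
Version: 2.19.3-1+cuda12.3
Package: libnccl-dev
Version: 2.19.3-1+cuda12.2'

printf '%s\n' "$sample" | pick_nccl_package_version "2.19.3-1"   # prints 2.19.3-1+cuda12.3
```

Because the pattern ends in `+cuda.*`, the sketch accepts any CUDA suffix and returns whichever matching entry is listed first.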
45 changes: 44 additions & 1 deletion .github/workflows/_build_base.yaml
@@ -42,6 +42,11 @@ on:
         description: Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch
         default: ''
         required: false
+      JAX_LIBNCCL_PACKAGE:
+        type: string
+        description: NCCL lib package version to be installed (in the format `2.19.3-1+cuda12.3`)
+        default: ''
+        required: false
     outputs:
       DOCKER_TAG:
         description: "Tag of the image built"
@@ -56,8 +61,44 @@ permissions:
   packages: write # to upload container

 jobs:
+  nccl-version:
+    runs-on: ubuntu-22.04
+    outputs:
+      JAX_NCCL_VERSION: ${{ steps.get-nccl-version.outputs.JAX_NCCL_VERSION }}
+      JAX_LIBNCCL_PACKAGE: ${{ steps.get-nccl-version.outputs.JAX_LIBNCCL_PACKAGE }}
+    steps:
+      - name: Print environment variables
+        run: env
+
+      - name: Check out the repository under ${GITHUB_WORKSPACE}
+        uses: actions/checkout@v4
+
+      - name: Get NCCL version
+        id: get-nccl-version
+        shell: bash -x -e {0}
+        run: |
+          JAX_LIBNCCL_PACKAGE=${{ inputs.JAX_LIBNCCL_PACKAGE }}
+          if [[ -z $JAX_LIBNCCL_PACKAGE ]]; then
+            BASE_IMAGE=${{ inputs.BASE_IMAGE }}
+            if [[ $BASE_IMAGE == latest ]]; then
+              BASE_IMAGE=$(cat .github/container/Dockerfile.base | sed -n "s/^ARG BASE_IMAGE=\(.*\)$/\1/p")
+            fi
+            # try to get NCCL version from provided BASE_IMAGE of x86-arch
+            if [[ -z "$BASE_IMAGE" ]]; then
+              echo "Need to pass non-empty BASE_IMAGE variable"
+              exit 1
+            fi
+            source .github/workflows/scripts/get_remote_env.sh
+            JAX_LIBNCCL_PACKAGE=$(get_remote_env ${BASE_IMAGE} linux amd64 | jq -r '.[]' | egrep '^NV_LIBNCCL_PACKAGE')
+            JAX_NCCL_VERSION=$(get_remote_env ${BASE_IMAGE} linux amd64 | jq -r '.[]' | egrep '^NCCL_VERSION=' | cut -d= -f2-)
+          else
+            JAX_NCCL_VERSION=$(echo $JAX_LIBNCCL_PACKAGE | cut -d= -f2 | cut -d+ -f1)
+          fi
+          echo "JAX_NCCL_VERSION=$JAX_NCCL_VERSION" >> $GITHUB_OUTPUT
+          echo "JAX_LIBNCCL_PACKAGE=$JAX_LIBNCCL_PACKAGE" >> $GITHUB_OUTPUT

   build-base:
+    needs: nccl-version
     runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", small]
     env:
       BADGE_FILENAME_FULL: ${{ inputs.BADGE_FILENAME }}-${{ inputs.ARCHITECTURE }}.json
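
The `cut` pipeline used in the `else` branch of the "Get NCCL version" step can be tried in isolation. The sketch below is illustrative only; `nccl_version_from_package` is a hypothetical name, and the second input form (`name=version`) is an assumption about how the value might also be supplied.

```shell
#!/bin/sh
# Hypothetical helper mirroring the pipeline in the workflow step: drop an
# optional "name=" prefix, then drop the "+cuda..." suffix.
nccl_version_from_package() {
  printf '%s\n' "$1" | cut -d= -f2 | cut -d+ -f1
}

# Documented input format: there is no "=", so the first cut passes the line through.
nccl_version_from_package "2.19.3-1+cuda12.3"            # prints 2.19.3-1
# Assumed alternative form with a package-name prefix.
nccl_version_from_package "libnccl2=2.19.3-1+cuda12.3"   # prints 2.19.3-1
```

Both forms reduce to the same version because `cut` without `-s` leaves lines that contain no delimiter unchanged.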
@@ -133,7 +174,9 @@ jobs:
             GIT_USER_EMAIL=${{ inputs.GIT_USER_EMAIL }}
             BUILD_DATE=${{ inputs.BUILD_DATE }}
             ${{ inputs.BASE_IMAGE != 'latest' && format('BASE_IMAGE={0}', inputs.BASE_IMAGE) || '' }}
-
+            JAX_NCCL_VERSION=${{ needs.nccl-version.outputs.JAX_NCCL_VERSION }}
+            JAX_LIBNCCL_PACKAGE=${{ needs.nccl-version.outputs.JAX_LIBNCCL_PACKAGE }}
+
       - name: Generate sitrep
         if: "!cancelled()"
         shell: bash -x -e {0}
6 changes: 6 additions & 0 deletions .github/workflows/_ci.yaml
@@ -26,6 +26,11 @@ on:
         description: 'A JSON object containing git url+refs for softwares to be built'
         required: false
         default: '{}'
+      JAX_LIBNCCL_PACKAGE:
+        type: string
+        description: NCCL version to be installed (for example, `2.20.3-1+cuda12.4`)
+        default: ''
+        required: false
     outputs:
       DOCKER_TAGS:
         description: 'JSON object containing tags of all docker images built'
@@ -45,6 +50,7 @@ jobs:
       BASE_IMAGE: ${{ inputs.CUDA_IMAGE }}
       BUILD_DATE: ${{ inputs.BUILD_DATE }}
       MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }}
+      JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }}
     secrets: inherit

   build-jax:
7 changes: 7 additions & 0 deletions .github/workflows/ci.yaml
@@ -40,6 +40,11 @@ on:
           PACKAGE∊{JAX,XLA,Flax,transformer-engine,T5X,paxml,praxis,maxtext,levanter,haliax,mujuco,mujuco-mpc,gemma,big-vision,common-loop-utils,flaxformer,panopticapi} (case-insensitive)
         default: ''
         required: false
+      JAX_LIBNCCL_PACKAGE:
+        type: string
+        description: NCCL version to be installed (for example, 2.20.3-1+cuda12.4)
+        default: ''
+        required: false

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@@ -197,6 +202,7 @@ jobs:
       CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
       MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
       SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
+      JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }}
     secrets: inherit

   arm64:
@@ -208,6 +214,7 @@ jobs:
       CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
       MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
       SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
+      JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }}
     secrets: inherit

   # Only merge if everything succeeds