diff --git a/.github/container/Dockerfile.maxtext b/.github/container/Dockerfile.maxtext
index 6694aa821..0ccc1d835 100644
--- a/.github/container/Dockerfile.maxtext
+++ b/.github/container/Dockerfile.maxtext
@@ -3,6 +3,8 @@
 ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
 ARG URLREF_MAXTEXT=https://github.com/google/maxtext.git#main
 ARG SRC_PATH_MAXTEXT=/opt/maxtext
+ARG URLREF_JETSTREAM=https://github.com/AI-Hypercomputer/JetStream.git#main
+ARG SRC_PATH_JETSTREAM=/opt/jetstream
 
 ###############################################################################
 ## Download source and add auxiliary scripts
@@ -11,6 +13,8 @@ ARG SRC_PATH_MAXTEXT=/opt/maxtext
 FROM ${BASE_IMAGE} as mealkit
 ARG URLREF_MAXTEXT
 ARG SRC_PATH_MAXTEXT
+ARG URLREF_JETSTREAM
+ARG SRC_PATH_JETSTREAM
 
 RUN <<"EOF" bash -ex
 git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
@@ -26,6 +30,12 @@ for pattern in \
     sed -i "${pattern}" ${SRC_PATH_MAXTEXT}/requirements.txt;
 done
 echo "tensorflow-metadata>=1.15.0" >> ${SRC_PATH_MAXTEXT}/requirements.txt
+# remove outdated version of JetStream in favor of the upstream JetStream
+sed -i "/google-jetstream/d" ${SRC_PATH_MAXTEXT}/requirements.txt
+
+# Add the upstream JetStream
+git-clone.sh ${URLREF_JETSTREAM} ${SRC_PATH_JETSTREAM}
+echo "-r ${SRC_PATH_JETSTREAM}/requirements.txt" >> /opt/pip-tools.d/requirements-jetstream.in
 EOF
 
 ###############################################################################
diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml
index 433871c14..36705ae3f 100644
--- a/.github/container/manifest.yaml
+++ b/.github/container/manifest.yaml
@@ -109,6 +109,11 @@ maxtext:
   tracking_ref: main
   latest_verified_commit: 78daad198544def8274dbd656d122fbe6a0e1129
   mode: git-clone
+jetstream:
+  url: https://github.com/AI-Hypercomputer/JetStream.git
+  tracking_ref: main
+  latest_verified_commit: a3bb30e972a344675c20d56d8f61dc309e6a5f7b
+  mode: git-clone
 levanter:
   url: https://github.com/stanford-crfm/levanter.git
   tracking_ref: main
diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml
index c2b4cb4ee..688c93981 100644
--- a/.github/workflows/_ci.yaml
+++ b/.github/workflows/_ci.yaml
@@ -111,6 +111,7 @@ jobs:
       DOCKERFILE: .github/container/Dockerfile.maxtext
       EXTRA_BUILD_ARGS: |
         URLREF_MAXTEXT=${{ fromJson(inputs.SOURCE_URLREFS).MAXTEXT }}
+        URLREF_JETSTREAM=${{ fromJson(inputs.SOURCE_URLREFS).JETSTREAM }}
     secrets: inherit
 
   build-levanter:
@@ -689,3 +690,36 @@ jobs:
     with:
       MAXTEXT_IMAGE: ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }}
     secrets: inherit
+
+  # Run the JetStream unit tests from /opt/jetstream inside the MaxText image.
+  test-jetstream:
+    needs: build-maxtext
+    if: inputs.ARCHITECTURE == 'amd64'  # no arm64 gpu runners
+    uses: ./.github/workflows/_test_unit.yaml
+    with:
+      TEST_NAME: jetstream
+      EXECUTE: |
+        # -i keeps stdin open; without it the heredoc below never reaches bash in the container
+        docker run -i --shm-size=1g --gpus all ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }} \
+        bash <<"EOF" |& tee test-jetstream.log
+          cd /opt/jetstream
+          pip install -r requirements.txt
+          export CUDA_VISIBLE_DEVICES=0
+          python -m unittest -v jetstream.tests.core.test_orchestrator
+          python -m jetstream.engine.mock_engine_test
+          python -m jetstream.core.orchestrator_test
+          python -m jetstream.core.server_test
+        EOF
+      STATISTICS_SCRIPT: |
+        # unittest prints "Ran N tests" and "FAILED (failures=X, errors=Y)" once per invocation; sum over the whole log
+        total_tests=$(grep -oE '^Ran [0-9]+ tests?' test-jetstream.log | awk '{s+=$2} END {print s+0}')
+        errors=$(grep -E '^FAILED \(' test-jetstream.log | grep -oE '[( ]errors=[0-9]+' | awk -F= '{s+=$2} END {print s+0}')
+        failed_tests=$(grep -oE '^FAILED \(failures=[0-9]+' test-jetstream.log | awk -F= '{s+=$2} END {print s+0}')
+        passed_tests=$((total_tests - errors - failed_tests))
+        echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
+        echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
+        echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
+        echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
+      ARTIFACTS: |
+        test-jetstream.log
+    secrets: inherit
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 0c3c8bdb0..0d0ffce05 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -37,7 +37,7 @@ on:
         type: string
         description: |
           A comma-separated PACKAGE=URL#REF list to override sources used by build.
-          PACKAGE∊{JAX,XLA,Flax,transformer-engine,T5X,paxml,praxis,maxtext,levanter,haliax,mujuco,mujuco-mpc,gemma,big-vision,common-loop-utils,flaxformer,panopticapi} (case-insensitive)
+          PACKAGE∊{JAX,XLA,Flax,transformer-engine,T5X,paxml,praxis,maxtext,jetstream,levanter,haliax,mujuco,mujuco-mpc,gemma,big-vision,common-loop-utils,flaxformer,panopticapi} (case-insensitive)
         default: ''
         required: false
 