Add JetStream to MaxText container #1058

Open · wants to merge 11 commits into base: main
10 changes: 10 additions & 0 deletions .github/container/Dockerfile.maxtext
@@ -3,6 +3,8 @@
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
ARG URLREF_MAXTEXT=https://github.com/google/maxtext.git#main
ARG SRC_PATH_MAXTEXT=/opt/maxtext
ARG URLREF_JETSTREAM=https://github.com/AI-Hypercomputer/JetStream.git#main
ARG SRC_PATH_JETSTREAM=/opt/jetstream

###############################################################################
## Download source and add auxiliary scripts
@@ -11,6 +13,8 @@ ARG SRC_PATH_MAXTEXT=/opt/maxtext
FROM ${BASE_IMAGE} as mealkit
ARG URLREF_MAXTEXT
ARG SRC_PATH_MAXTEXT
ARG URLREF_JETSTREAM
ARG SRC_PATH_JETSTREAM

RUN <<"EOF" bash -ex
git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
@@ -26,6 +30,12 @@ for pattern in \
sed -i "${pattern}" ${SRC_PATH_MAXTEXT}/requirements.txt;
done
echo "tensorflow-metadata>=1.15.0" >> ${SRC_PATH_MAXTEXT}/requirements.txt
# remove outdated version of JetStream in favor of the upstream JetStream
sed -i "/google-jetstream/d" ${SRC_PATH_MAXTEXT}/requirements.txt

# Add the upstream JetStream
git-clone.sh ${URLREF_JETSTREAM} ${SRC_PATH_JETSTREAM}
echo "-r ${SRC_PATH_JETSTREAM}/requirements.txt" >> /opt/pip-tools.d/requirements-jetstream.in
Review comment (Collaborator):
${SRC_PATH_JETSTREAM}/requirements.txt
Is that part right? It doesn't seem right to me. This is the only case in JAX-Toolbox where we add a requirements file like this.

EOF

###############################################################################
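The two file edits in the hunk above (deleting the `google-jetstream` pin, then appending a `-r` include line) can be tried outside the container. This is a minimal sketch under stated assumptions: the scratch directory, the file names `maxtext-requirements.txt` and `requirements-jetstream.in`, and the sample package list are made up for illustration, and the sketch assumes the `.in` file is later compiled by pip-tools, which the diff does not show.

```shell
#!/usr/bin/env bash
set -eu

# Scratch directory standing in for the image filesystem (assumption).
workdir=$(mktemp -d)
maxtext_reqs="$workdir/maxtext-requirements.txt"
jetstream_in="$workdir/requirements-jetstream.in"

# Made-up requirements content, only to show the effect of the sed call.
printf '%s\n' 'jax' 'google-jetstream' 'numpy' > "$maxtext_reqs"

# Delete every line that mentions google-jetstream (same expression as the diff).
# -i.bak is used here so the command works with both GNU and BSD sed.
sed -i.bak '/google-jetstream/d' "$maxtext_reqs"

# Append an include directive that points at the cloned repo's requirements file.
echo "-r /opt/jetstream/requirements.txt" >> "$jetstream_in"

echo "--- $maxtext_reqs"
cat "$maxtext_reqs"
echo "--- $jetstream_in"
cat "$jetstream_in"
```

The resulting `.in` file holds a single include line rather than package names, which is the pattern the review comment questions.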
5 changes: 5 additions & 0 deletions .github/container/manifest.yaml
@@ -109,6 +109,11 @@ maxtext:
tracking_ref: main
latest_verified_commit: 78daad198544def8274dbd656d122fbe6a0e1129
mode: git-clone
jetstream:
url: https://github.com/AI-Hypercomputer/JetStream.git
tracking_ref: main
latest_verified_commit: a3bb30e972a344675c20d56d8f61dc309e6a5f7b
mode: git-clone
levanter:
url: https://github.com/stanford-crfm/levanter.git
tracking_ref: main
33 changes: 33 additions & 0 deletions .github/workflows/_ci.yaml
yhtang marked this conversation as resolved.
@@ -111,6 +111,7 @@ jobs:
DOCKERFILE: .github/container/Dockerfile.maxtext
EXTRA_BUILD_ARGS: |
URLREF_MAXTEXT=${{ fromJson(inputs.SOURCE_URLREFS).MAXTEXT }}
URLREF_JETSTREAM=${{ fromJson(inputs.SOURCE_URLREFS).JETSTREAM }}
secrets: inherit

build-levanter:
@@ -689,3 +690,35 @@ jobs:
with:
MAXTEXT_IMAGE: ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }}
secrets: inherit

test-jetstream:
needs: build-maxtext
if: inputs.ARCHITECTURE == 'amd64' # no arm64 gpu runners
uses: ./.github/workflows/_test_unit.yaml
with:
TEST_NAME: jetstream
EXECUTE: |
docker run --shm-size=1g --gpus all ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }} \
bash <<"EOF" |& tee test-jetstream.log
cd /opt/jetstream
pip install -r requirements.txt
export CUDA_VISIBLE_DEVICES=0
python -m unittest -v jetstream.tests.core.test_orchestrator
python -m jetstream.engine.mock_engine_test
python -m jetstream.core.orchestrator_test
python -m jetstream.core.server_test
EOF
STATISTICS_SCRIPT: |
summary_line=$(tail -n1 test-jetstream.log)
errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
failed_tests=$(echo $summary_line | grep -oE '[0-9]+ failed' | awk '{print $1} END { if (!NR) print 0}')
passed_tests=$(echo $summary_line | grep -oE '[0-9]+ passed' | awk '{print $1} END { if (!NR) print 0}')
total_tests=$((failed_tests + passed_tests))
echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
ARTIFACTS: |
test-jetstream.log
secrets: inherit
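The extraction in STATISTICS_SCRIPT can be checked in isolation. The sketch below wraps the same grep/awk pipeline in a hypothetical helper, `count_of`, and runs it on two sample lines that are illustrative rather than taken from a real log. The pattern looks for pytest-style wording ("N passed", "N failed"). Python's unittest runner normally ends with a line such as `OK` or `FAILED (failures=1)`, so if the last log line comes from unittest, every count may come out as zero.

```shell
#!/usr/bin/env bash

# Hypothetical helper mirroring the grep/awk extraction in STATISTICS_SCRIPT.
# $1 = summary line, $2 = keyword to count (error, failed, passed).
count_of() {
  # "|| true" keeps a non-matching grep from failing the pipeline under pipefail.
  { echo "$1" | grep -oE "[0-9]+ $2" || true; } \
    | awk '{print $1} END { if (!NR) print 0}'
}

# Sample inputs (assumptions, not real test output).
pytest_like='2 failed, 5 passed, 1 error in 3.21s'
unittest_like='OK'

for line in "$pytest_like" "$unittest_like"; do
  failed=$(count_of "$line" failed)
  passed=$(count_of "$line" passed)
  errors=$(count_of "$line" error)
  echo "line='$line' total=$((failed + passed)) errors=$errors passed=$passed failed=$failed"
done
```

Running the helper against the final line of an actual `test-jetstream.log` is a quick way to confirm which format the job produces before relying on the reported totals.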

2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -37,7 +37,7 @@ on:
type: string
description: |
A comma-separated PACKAGE=URL#REF list to override sources used by build.
- PACKAGE∊{JAX,XLA,Flax,transformer-engine,T5X,paxml,praxis,maxtext,levanter,haliax,mujuco,mujuco-mpc,gemma,big-vision,common-loop-utils,flaxformer,panopticapi} (case-insensitive)
+ PACKAGE∊{JAX,XLA,Flax,transformer-engine,T5X,paxml,praxis,maxtext,jetstream,levanter,haliax,mujuco,mujuco-mpc,gemma,big-vision,common-loop-utils,flaxformer,panopticapi} (case-insensitive)
default: ''
required: false
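To make the `PACKAGE=URL#REF` format concrete, here is a minimal sketch of how such a comma-separated list could be split. The function name `parse_overrides` is hypothetical and this is not the parser the workflow actually uses, which is not part of this diff. It assumes every entry contains both `=` and `#`, and it upper-cases the package name because the description calls it case-insensitive.

```shell
#!/usr/bin/env bash

# Hypothetical parser for a comma-separated PACKAGE=URL#REF list.
# Prints one "PACKAGE URL REF" line per entry.
parse_overrides() {
  local list="$1" entry pkg urlref
  local IFS=','
  for entry in $list; do
    # Everything before the first "=" is the package name; normalise its case.
    pkg=$(printf '%s' "${entry%%=*}" | tr '[:lower:]' '[:upper:]')
    # Everything after the first "=" is URL#REF.
    urlref="${entry#*=}"
    # Split URL#REF at the first "#".
    printf '%s %s %s\n' "$pkg" "${urlref%%#*}" "${urlref#*#}"
  done
}

parse_overrides 'jetstream=https://github.com/AI-Hypercomputer/JetStream.git#main,maxtext=https://github.com/google/maxtext.git#main'
```

The upper-cased names line up with keys such as `JETSTREAM` and `MAXTEXT` read via `fromJson(inputs.SOURCE_URLREFS)` in `_ci.yaml`, though how those keys are produced is an assumption here.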
