Skip to content

Commit

Permalink
Hotfix xla runtime (#476)
Browse files Browse the repository at this point in the history
Merging the hotfix from #472
into the 24.01-devel branch since the nccl error observed would also be
present here

@ashors1 @hemildesai

---------

Co-authored-by: Vladislav <vkozlov@nvidia.com>
Co-authored-by: Yu-Hang 'Maxin' Tang <Tang.Maxin@gmail.com>
Co-authored-by: Olli Lupton <olupton@nvidia.com>
Co-authored-by: ashors1 <71393111+ashors1@users.noreply.github.com>
  • Loading branch information
5 people authored Jan 16, 2024
1 parent bc38906 commit 5106983
Show file tree
Hide file tree
Showing 27 changed files with 88 additions and 78 deletions.
3 changes: 2 additions & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ ARG BUILD_DATE

ENV BUILD_DATE=${BUILD_DATE}
# The following environment variables tune performance
ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false"
# TODO: --xla_gpu_enable_xla_runtime_executable=true should be removed ASAP. It is a WAR for a crash observed in the default runtime leading to this error message: "nccl_all_reduce_thunk.cc:175] Check failed: reduction_kind.has_value()"
ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_xla_runtime_executable=true"
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
ENV NCCL_IB_SL=1
ENV NCCL_NVLS_ENABLE=0
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/_build_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ on:
description: 'User email in GIT to perform git pull/push'
required: false
default: 'jax@nvidia.com'
TRIAL_BRANCH:
type: string
description: 'Name of branch with bumped manifest and patches'
required: true

outputs:
DOCKER_TAG:
Expand Down Expand Up @@ -65,6 +69,9 @@ jobs:

- name: Check out the repository under ${GITHUB_WORKSPACE}
uses: actions/checkout@v3
with:
ref: ${{ inputs.TRIAL_BRANCH }}


- name: Login to GitHub Container Registry
uses: docker/login-action@v2
Expand Down
6 changes: 0 additions & 6 deletions .github/workflows/_build_jax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ on:
description: 'User email in GIT to perform git pull/push'
required: false
default: 'jax@nvidia.com'
TRIAL_BRANCH:
type: string
description: 'Name of branch with bumped manifest and patches'
required: true
outputs:
DOCKER_TAG_MEALKIT:
description: "Tags of the 'mealkit' image built"
Expand Down Expand Up @@ -87,8 +83,6 @@ jobs:
- name: Check out the repository under ${GITHUB_WORKSPACE}
uses: actions/checkout@v3
with:
ref: ${{ inputs.TRIAL_BRANCH }}

- name: Login to GitHub Container Registry
uses: docker/login-action@v2
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/_ci.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CI
name: ~CI, single-arch
run-name: CI-${{ inputs.ARCHITECTURE }}

on:
Expand Down Expand Up @@ -56,6 +56,7 @@ jobs:
with:
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
TRIAL_BRANCH: ${{ inputs.TRIAL_BRANCH }}
secrets: inherit

build-jax:
Expand All @@ -65,7 +66,6 @@ jobs:
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
BASE_IMAGE: ${{ needs.build-base.outputs.DOCKER_TAG }}
TRIAL_BRANCH: ${{ inputs.TRIAL_BRANCH }}
secrets: inherit

build-t5x:
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.00045988122777392465, 2.1979689942478824e-05, 1.3908903990037895e-06], "step_times": [9.137898763020834, 9.141323407491049, 9.186994552612305], "step_time_avg": 9.155405574374731, "e2e_time_seconds": 292.8193333333333}
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004531083977781236, 2.0986779418308288e-05, 1.31147601223347e-06], "step_times": [9.16663678487142, 9.166899998982748, 9.190437952677408], "step_time_avg": 9.174658245510523, "e2e_time_seconds": 286.86566666666664}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.00046117109013721347, 2.1980367819196545e-05, 1.3908905505862397e-06], "step_times": [9.571099281311035, 9.56772263844808, 9.569978396097818], "step_time_avg": 9.569600105285645, "e2e_time_seconds": 195.529}
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004550429875962436, 2.098800177918747e-05, 1.3114761259203078e-06], "step_times": [9.69591999053955, 9.694547653198242, 9.6983060836792], "step_time_avg": 9.696257909138998, "e2e_time_seconds": 193.862}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"start_step": 100, "end_step": 500, "step_interval": 100, "loss_values": [0.0004496203036978841, 2.1072781237307936e-05, 1.311417690885719e-06, 6.332992796842518e-08, 0.0], "step_times": [7.595651865005493, 7.599909543991089, 7.602108001708984, 7.595033645629883, 7.59737229347229], "step_time_avg": 7.598015069961548, "e2e_time_seconds": 234.1085}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004619551570309947, 2.1979879723706592e-05, 1.3908903990037895e-06], "step_times": [9.800720850626627, 9.806367556254068, 9.80089282989502], "step_time_avg": 9.802660412258573, "e2e_time_seconds": 287.28933333333333}
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004562355752568692, 2.0987012248951942e-05, 1.31147601223347e-06], "step_times": [9.896685918172201, 9.902073860168457, 9.896770795186361], "step_time_avg": 9.898510191175673, "e2e_time_seconds": 282.402}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.000438039714936167, 2.219043562945444e-05, 1.4306265256891493e-06], "step_times": [2.4758432706197104, 2.476098378499349, 2.4765774408976235], "step_time_avg": 2.4761730300055604, "e2e_time_seconds": 322.90933333333334}
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.000438039714936167, 2.219043562945444e-05, 1.4306265256891493e-06], "step_times": [2.470784823099772, 2.471130927403768, 2.471168835957845], "step_time_avg": 2.4710281954871287, "e2e_time_seconds": 316.19}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004382363404147327, 2.253897037007846e-05, 1.4306265256891493e-06], "step_times": [1.6547863880793254, 1.6562701066335042, 1.6603738069534302], "step_time_avg": 1.6571434338887532, "e2e_time_seconds": 391.36400000000003}
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004382363404147327, 2.253897037007846e-05, 1.4306265256891493e-06], "step_times": [1.6266847054163616, 1.6370126803716023, 1.6351629098256428], "step_time_avg": 1.6329534318712022, "e2e_time_seconds": 393.96866666666665}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004683640436269343, 2.1106745407450944e-05, 1.311883579546702e-06], "step_times": [6.425331115722656, 6.425313472747803, 6.422374884287517], "step_time_avg": 6.424339824252658, "e2e_time_seconds": 296.25399999999996}
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004683640436269343, 2.1106745407450944e-05, 1.311883579546702e-06], "step_times": [6.465008894602458, 6.465569972991943, 6.463742891947429], "step_time_avg": 6.464773919847276, "e2e_time_seconds": 285.65200000000004}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004595223581418395, 2.2086765966378152e-05, 1.450262061553076e-06], "step_times": [8.208735148111979, 8.21290397644043, 8.209474563598633], "step_time_avg": 8.21037122938368, "e2e_time_seconds": 296.959}
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004528434365056455, 2.122193473041989e-05, 1.3954693258710904e-06], "step_times": [8.431368192036947, 8.43382708231608, 8.431408246358236], "step_time_avg": 8.432201173570421, "e2e_time_seconds": 287.23}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004681030404753983, 2.086862332362216e-05, 1.3121162965035182e-06], "step_times": [7.2736771901448565, 7.274514834086101, 7.268196900685628], "step_time_avg": 7.272129641638862, "e2e_time_seconds": 278.86633333333333}
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004681030404753983, 2.086862332362216e-05, 1.3121162965035182e-06], "step_times": [7.275770664215088, 7.267403920491536, 7.267686367034912], "step_time_avg": 7.270286983913845, "e2e_time_seconds": 273.26233333333334}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004598816449288279, 2.197976027673576e-05, 1.3908903990037895e-06], "step_times": [9.448389371236166, 9.448740005493164, 9.44986343383789], "step_time_avg": 9.448997603522407, "e2e_time_seconds": 287.34200000000004}
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0004531083977781236, 2.0986779418308288e-05, 1.31147601223347e-06], "step_times": [9.474847793579102, 9.474331537882486, 9.4735320409139], "step_time_avg": 9.474237124125162, "e2e_time_seconds": 281.1556666666667}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"start_step": 100, "end_step": 300, "step_interval": 100, "loss_values": [0.0005329704144969583, 2.392743408563547e-05, 1.6679570080668782e-06], "step_times": [7.743874549865723, 7.7442946434021, 7.74159049987793], "step_time_avg": 7.743253231048584, "e2e_time_seconds": 307.527}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step":100,"end_step":500,"step_interval":100,"loss_values":[0.00046839181838246685,2.087271119914173e-05,1.31276870736959e-06,6.912159155233096e-11,0],"step_times":[6.357304414113362,5.979689915974935,6.376240253448486,6.373825391133626,6.355693658192952],"step_time_avg":6.288550726572673,"e2e_time_seconds":295.73600000000005,"run_urls":["https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160692471/artifacts","https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160694203/artifacts","https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160696525/artifacts"],"date":"2023-07-18"}
{"start_step": 100, "end_step": 500, "step_interval": 100, "loss_values": [0.0004681030404753983, 2.086862332362216e-05, 1.3121162965035182e-06, 5.8207657444020455e-11, 0.0], "step_times": [7.725894292195638, 7.695748964945476, 7.674357891082764, 7.7014509836832685, 7.720887184143066], "step_time_avg": 7.703667863210043, "e2e_time_seconds": 295.4576666666666}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step":100,"end_step":500,"step_interval":100,"loss_values":[0.0004682748403865844,2.090286701180351e-05,1.3127760970140419e-06,5.8207657444020455e-11,0],"step_times":[6.688000361124675,6.699192523956299,6.694862047831218,6.698123772939046,6.700749556223552],"step_time_avg":6.6961856524149574,"e2e_time_seconds":223.268,"run_urls":["https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160692471/artifacts","https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160694203/artifacts","https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160696525/artifacts"],"date":"2023-07-18"}
{"start_step": 100, "end_step": 500, "step_interval": 100, "loss_values": [0.00046809131163172424, 2.086867971229367e-05, 1.3123492408340098e-06, 5.8207657444020455e-11, 0.0], "step_times": [8.008778889973959, 8.014708836873373, 8.011429150899252, 8.013259251912435, 8.00814119974772], "step_time_avg": 8.011263465881347, "e2e_time_seconds": 204.86966666666663}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"start_step": 100, "end_step": 500, "step_interval": 100, "loss_values": [0.0004679674166254699, 2.0868097635684535e-05, 1.3118251445121132e-06, 6.335539382007482e-08, 0.0], "step_times": [6.345967451731364, 6.3443193435668945, 6.345146497090657, 6.344050407409668, 6.3422525723775225], "step_time_avg": 6.344347254435221, "e2e_time_seconds": 238.7543333333333}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"start_step": 100, "end_step": 500, "step_interval": 100, "loss_values": [0.00046810254571028054, 2.0870078515144996e-05, 1.3122327118253452e-06, 5.8207657444020455e-11, 0.0], "step_times": [8.161738077799479, 8.162349383036295, 8.15965493520101, 8.158018112182617, 8.157390912373861], "step_time_avg": 8.159830284118652, "e2e_time_seconds": 296.75933333333336}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step":100,"end_step":500,"step_interval":100,"loss_values":[0.00043803153675980866,2.2190377421793528e-05,1.4306265256891493e-06,5.8207657444020455e-11,0],"step_times":[2.357959032058716,2.3574414253234863,2.3560804526011148,2.357269843419393,2.3561060428619385],"step_time_avg":2.3569713592529298,"e2e_time_seconds":385.921,"run_urls":["https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160692471/artifacts","https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160694203/artifacts","https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160696525/artifacts"],"date":"2023-07-18"}
{"start_step": 100, "end_step": 500, "step_interval": 100, "loss_values": [0.000438039714936167, 2.219043562945444e-05, 1.4306265256891493e-06, 5.8207657444020455e-11, 0.0], "step_times": [2.5234107971191406, 2.5232578118642173, 2.5235915184020996, 2.5234344005584717, 2.5233071645100913], "step_time_avg": 2.523400338490804, "e2e_time_seconds": 388.4533333333334}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"start_step": 100, "end_step": 500, "step_interval": 100, "loss_values": [0.0004382363404147327, 2.253897037007846e-05, 1.4306265256891493e-06, 5.8207657444020455e-11, 0.0], "step_times": [1.6573359568913777, 1.654531757036845, 1.6514259179433186, 1.652300516764323, 1.6523643334706624], "step_time_avg": 1.653591696421305, "e2e_time_seconds": 507.07099999999997}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"start_step": 100, "end_step": 500, "step_interval": 100, "loss_values": [0.0004683640436269343, 2.1106745407450944e-05, 1.311883579546702e-06, 5.8207657444020455e-11, 0.0], "step_times": [6.97843599319458, 6.977321783701579, 6.975888093312581, 6.976069609324138, 6.976081212361653], "step_time_avg": 6.976759338378906, "e2e_time_seconds": 306.4506666666667}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"start_step":100,"end_step":500,"step_interval":100,"loss_values":[0.00046849539891506237,2.0879013391095214e-05,1.3132464952529215e-06,5.8207657444020455e-11,0],"step_times":[6.436240037282308,6.217730363210042,6.462920983632405,6.463934898376465,6.473924477895101],"step_time_avg":6.4109501520792636,"e2e_time_seconds":284.0213333333333,"run_urls":["https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160692471/artifacts","https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160694203/artifacts","https://api.github.com/repos/NVIDIA/JAX-Toolbox/actions/runs/5160696525/artifacts"],"date":"2023-07-18"}
{"start_step": 100, "end_step": 500, "step_interval": 100, "loss_values": [0.0004681030404753983, 2.086862332362216e-05, 1.3121162965035182e-06, 5.8207657444020455e-11, 0.0], "step_times": [7.965368588765462, 7.962141513824463, 7.961517333984375, 7.960983753204346, 7.957266171773274], "step_time_avg": 7.961455472310384, "e2e_time_seconds": 291.03866666666664}
2 changes: 1 addition & 1 deletion .github/workflows/baselines/test_pax_mgmn_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_loss(baseline_filename):
with open(baseline_filepath, "r") as baseline_file:
end_step = json.load(baseline_file)["end_step"]
loss_actual = test_utils.read_tb_tag(event_file, loss_summary_name)
assert loss_actual[end_step] < 1.6e-6, f"Loss at final step: {loss_actual[end_step]}, Expected loss < 1.6e-6"
assert 0 <= loss_actual[end_step] < 1.8e-6, f"Loss at final step: {loss_actual[end_step]}, Expected 0 <= loss < 1.8e-6"


@pytest.mark.parametrize("baseline_filename", os.listdir(baselines_dir))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
name: Weekly base container build
run-name: Weekly base container build (${{ github.event_name == 'workflow_run' && format('nightly {0}', github.event.workflow_run.created_at) || github.event_name }})
name: Nightly base container build
run-name: Nightly base container build (${{ github.event_name == 'workflow_run' && format('nightly {0}', github.event.workflow_run.created_at) || github.event_name }})

on:
schedule:
- cron: '30 13 * * 5' # 5:30 AM every Friday
- cron: '30 9 * * *' # Pacific Time 01:30 AM in UTC
workflow_dispatch:
inputs:
PUBLISH:
Expand All @@ -13,7 +13,7 @@ on:
required: false

permissions:
contents: read # to fetch code
contents: write # to push trial branch
actions: write # to cancel previous workflows
packages: write # to upload container

Expand All @@ -38,20 +38,62 @@ jobs:
run: |
echo "PUBLISH=${{ github.event_name == 'schedule' || (github.event_name == 'workflow_dispatch' && inputs.PUBLISH) }}" >> $GITHUB_OUTPUT
amd64:
bump-world-state:
needs: metadata
outputs:
TRIAL_BRANCH: ${{ steps.trial-meta.outputs.TRIAL_BRANCH }}
runs-on: ubuntu-22.04
steps:
- name: Check out the repository under ${GITHUB_WORKSPACE}
uses: actions/checkout@v3
- name: Update manifest and patches in-place - show diff
working-directory: .github/container
shell: bash -x -e {0}
run: |
bash bump.sh --input-manifest manifest.yaml
git diff
- name: Push trial branch
id: trial-meta
if: needs.metadata.outputs.PUBLISH == 'true'
shell: bash -x -e {0}
run: |
git config user.name "JAX-Toolbox CI"
git config user.email "jax@nvidia.com"
# Prepend trial branch with "z" to make it appear at the end
trial_branch=znightly-${{ needs.metadata.outputs.BUILD_DATE }}-${{ github.run_id }}
git switch -c $trial_branch
git status
git add -u
git add .github/container/patches/
git status
git commit -m "generated ${{github.run_id }}"
git push --set-upstream origin $trial_branch
echo "$trial_branch" > ./trial-branch.txt
echo TRIAL_BRANCH=$trial_branch | tee $GITHUB_OUTPUT
- name: 'Upload trial branch artifact'
if: needs.metadata.outputs.PUBLISH == 'true'
uses: actions/upload-artifact@v3
with:
name: trial-branch
path: trial-branch.txt

amd64:
needs: [metadata, bump-world-state]
uses: ./.github/workflows/_build_base.yaml
with:
ARCHITECTURE: amd64
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
TRIAL_BRANCH: ${{ needs.bump-world-state.outputs.TRIAL_BRANCH }}
secrets: inherit

arm64:
needs: metadata
needs: [metadata, bump-world-state]
uses: ./.github/workflows/_build_base.yaml
with:
ARCHITECTURE: arm64
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
TRIAL_BRANCH: ${{ needs.bump-world-state.outputs.TRIAL_BRANCH }}
secrets: inherit

publish:
Expand Down
Loading

0 comments on commit 5106983

Please sign in to comment.