Skip to content

Commit

Permalink
Update pax/praxis patchlists for 24.04 release (#740)
Browse files Browse the repository at this point in the history
Co-authored-by: Yu-Hang "Maxin" Tang <Tang.Maxin@gmail.com>
  • Loading branch information
ashors1 and yhtang authored Apr 19, 2024
1 parent e7b2a53 commit 4152544
Show file tree
Hide file tree
Showing 7 changed files with 1,427 additions and 41 deletions.
7 changes: 5 additions & 2 deletions .github/container/manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ flax:
latest_verified_commit: 83d118ad369a470527e8ec6cd3b988fba5d4fd3e
mode: git-clone
patches:
# pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules
# pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules
transformer-engine:
url: https://github.com/NVIDIA/TransformerEngine.git
tracking_ref: release_v1.5
Expand All @@ -29,6 +29,8 @@ paxml:
mode: git-clone
patches:
pull/46/head: file://patches/paxml/PR-46.patch # adds Transformer Engine support
mirror/patch/lora-and-sft-support: file://patches/paxml/mirror-patch-lora-and-sft-support.patch ## adds LLaMA SFT support
mirror/patch/moe-support: file://patches/paxml/mirror-patch-moe-support.patch ## adds GLaM support
praxis:
url: https://github.com/google/praxis.git
mirror_url: https://github.com/nvjax-svc-0/praxis.git
Expand All @@ -37,7 +39,8 @@ praxis:
mode: git-clone
patches:
pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas.
pull/36/head: file://patches/praxis/PR-36.patch # adds Transformer Engine support
mirror/patch/te-and-lora-support: file://patches/praxis/mirror-patch-te-and-lora-support.patch # adds Transformer Engine support, includes LoRA support
mirror/patch/glam_without_repeat_layer: file://patches/praxis/mirror-patch-glam_without_repeat_layer.patch # adds support for running GLaM models with repeat layer disabled
lingvo:
# Used only in ARM pax builds
url: https://github.com/tensorflow/lingvo.git
Expand Down
63 changes: 45 additions & 18 deletions .github/container/patches/paxml/PR-46.patch
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
From 37461f7b414c3c40c8730b5c2c9318329b8bc2d6 Mon Sep 17 00:00:00 2001
From: ashors1 <ashors@nvidia.com>
Date: Tue, 18 Jul 2023 10:27:03 -0700
Subject: [PATCH 1/9] add TE support
Subject: [PATCH 01/10] add TE support

---
paxml/contrib/gpu/scripts_gpu/configs.py | 22 +-
Expand Down Expand Up @@ -654,13 +654,13 @@ index 6ec25e8..0342328 100644

train_state_partition_specs = (
--
2.25.1
2.34.1


From 371b48043de072908aca80ba8b16f34008bd875c Mon Sep 17 00:00:00 2001
From: Ming-Xu Huang <mingh@nvidia.com>
Date: Wed, 27 Sep 2023 10:46:53 +0800
Subject: [PATCH 2/9] Adding dropout support when enabling TE.
Subject: [PATCH 02/10] Adding dropout support when enabling TE.

---
paxml/contrib/gpu/scripts_gpu/te_helper.py | 10 ++++++++++
Expand Down Expand Up @@ -688,13 +688,13 @@ index d44ca67..2b9dba4 100644
assert self.packed_input == False
assert len(self.moe_layers) == 0
--
2.25.1
2.34.1


From 272e6352128962a9b2da10133737f3e1343bd36c Mon Sep 17 00:00:00 2001
From: Ming-Xu Huang <mingh@nvidia.com>
Date: Tue, 24 Oct 2023 10:30:27 +0800
Subject: [PATCH 3/9] Set deterministic=True for inference.
Subject: [PATCH 03/10] Set deterministic=True for inference.

---
paxml/contrib/gpu/scripts_gpu/te_helper.py | 3 ++-
Expand All @@ -715,13 +715,13 @@ index 2b9dba4..ef20305 100644
return x_out

--
2.25.1
2.34.1


From dfbf3a90cc0d93aa7d8e9c55c95ccf98c67f70bb Mon Sep 17 00:00:00 2001
From: Reese Wang <rewang@nvidia.com>
Date: Thu, 2 Nov 2023 22:04:58 -0700
Subject: [PATCH 4/9] Fix the excluded list for excluded_for_learner
Subject: [PATCH 04/10] Fix the excluded list for excluded_for_learner

Signed-off-by: Reese Wang <rewang@nvidia.com>
---
Expand All @@ -742,13 +742,13 @@ index 0342328..2e9bfd6 100644
vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt(
mdl_vars, excluded_for_learner
--
2.25.1
2.34.1


From 041456a8d4eb39350349101e263cb09f80b2b88c Mon Sep 17 00:00:00 2001
From: Ming-Xu Huang <mingh@nvidia.com>
Date: Tue, 7 Nov 2023 11:21:53 +0800
Subject: [PATCH 5/9] Adapting to TE/JAX/Custom_partitioning.
Subject: [PATCH 05/10] Adapting to TE/JAX/Custom_partitioning.

---
paxml/contrib/gpu/scripts_gpu/te_helper.py | 6 ++++--
Expand Down Expand Up @@ -779,13 +779,13 @@ index ef20305..fed1601 100644
finally:
pass
--
2.25.1
2.34.1


From 7d976d6510d8d5f751fd566ed2703bfa2d0a89d0 Mon Sep 17 00:00:00 2001
From: Ming-Xu Huang <mingh@nvidia.com>
Date: Tue, 7 Nov 2023 15:14:25 +0800
Subject: [PATCH 6/9] Adding TE-compatiable PipelinedTransformer
Subject: [PATCH 06/10] Adding TE-compatible PipelinedTransformer

---
paxml/contrib/gpu/scripts_gpu/te_helper.py | 109 +++++++++++++++++++++
Expand Down Expand Up @@ -946,13 +946,13 @@ index fed1601..5914e54 100644
def update_fp8_metas_if_needed(mdl_vars, grads):
return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads)
--
2.25.1
2.34.1


From a1bb3c7d24817e1a77219f1cdfdb70b34157fda2 Mon Sep 17 00:00:00 2001
From: Ming-Xu Huang <mingh@nvidia.com>
Date: Wed, 8 Nov 2023 10:06:49 +0800
Subject: [PATCH 7/9] Apply OWG to TE's FP8 meta
Subject: [PATCH 07/10] Apply OWG to TE's FP8 meta

---
paxml/contrib/gpu/scripts_gpu/te_helper.py | 59 ----------------------
Expand Down Expand Up @@ -1107,13 +1107,13 @@ index 2e9bfd6..270fb3d 100644
grads, states.opt_states[0], vars_with_opt, wps_with_opt
)
--
2.25.1
2.34.1


From 0d07668f96ea4e106388fbfdc47ae228918ec135 Mon Sep 17 00:00:00 2001
From: Ming-Xu Huang <mingh@nvidia.com>
Date: Wed, 15 Nov 2023 14:43:17 +0800
Subject: [PATCH 8/9] Remove Praxis related setup (Moving to Praxis TE/Patch)
Subject: [PATCH 08/10] Remove Praxis related setup (Moving to Praxis TE/Patch)

---
paxml/contrib/gpu/scripts_gpu/configs.py | 9 -
Expand Down Expand Up @@ -1499,13 +1499,13 @@ index fd482df..b271258 100644
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
--
2.25.1
2.34.1


From f94e783f56454d758b8cbab81f5ce756835fd065 Mon Sep 17 00:00:00 2001
From: Ming-Xu Huang <mingh@nvidia.com>
Date: Wed, 15 Nov 2023 14:51:14 +0800
Subject: [PATCH 9/9] Fix missing DEFAULT_INIT_MUTABLE_LIST
Subject: [PATCH 09/10] Fix missing DEFAULT_INIT_MUTABLE_LIST

---
paxml/contrib/gpu/scripts_gpu/te_helper.py | 4 ++++
Expand Down Expand Up @@ -1534,5 +1534,32 @@ index b271258..cbac7cf 100644

class TransformerEngineHelperBase:
--
2.25.1
2.34.1


From ed3300f065b5dd5e02f7e30535c83bcc0cac20bc Mon Sep 17 00:00:00 2001
From: Hemil Desai <hemild@nvidia.com>
Date: Mon, 12 Feb 2024 10:22:15 -0800
Subject: [PATCH 10/10] Revert mutable kwarg in abstract_init_with_metadata in
init checkpoint rule

---
paxml/tasks_lib.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paxml/tasks_lib.py b/paxml/tasks_lib.py
index e475099..43e090c 100644
--- a/paxml/tasks_lib.py
+++ b/paxml/tasks_lib.py
@@ -1787,7 +1787,7 @@ class SingleTask(base_task.BaseTask):
)
# Initialize with a dummy seed
var_weight_hparams = ckpt_task.model.abstract_init_with_metadata(
- inputs_shape_dtype, mutable=DEFAULT_INIT_MUTABLE_LIST)
+ inputs_shape_dtype, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
ckpt_train_state = ckpt_task.create_train_state_padded_shapes(
var_weight_hparams)
train_state_pspecs = ckpt_task.create_train_state_partition_specs(
--
2.34.1

Loading

0 comments on commit 4152544

Please sign in to comment.