From 056a3b0db2e34c497f7984e54bb504d9b33efe58 Mon Sep 17 00:00:00 2001 From: Shanbin Ke Date: Wed, 18 Sep 2024 15:56:13 -0700 Subject: [PATCH] Add an option to test-pax.sh to enable XLA cuDNN flash attention (#1045) Provide an option to run XLA cuDNN flash attention as an alternative to TE cuDNN flash attention. --- .github/container/test-pax.sh | 49 +++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/.github/container/test-pax.sh b/.github/container/test-pax.sh index 2b33f53f7..46ce6ae73 100755 --- a/.github/container/test-pax.sh +++ b/.github/container/test-pax.sh @@ -15,7 +15,8 @@ usage() { echo " -a, --additional-args Additional fiddle args to pass to paxml/main.py" echo " -b, --batch-per-gpu Batch size per GPU, defaults to 4." echo " --dtype Batch size, defaults to bfloat16." - echo " --enable-te If set, will run with env var ENABLE_TE=1." + echo " --enable-te If set, will run with env var ENABLE_TE=1." + echo " --enable-cudnn-fa If set, will use cudnn fa." echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1." echo " --disable-fused-attn Whether disable TE fused attention." echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M" @@ -26,13 +27,13 @@ usage() { echo " --data-parallel Data parallelism to use. Defaults to 1." echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1." echo " --tensor-parallel Tensor parallelism to use. Defaults to 1." - echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining." + echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining." echo " -n, --nodes Number of nodes." echo " -h, --help Print usage." exit $1 } -args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") +args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-cudnn-fa,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") if [[ $? -ne 0 ]]; then exit $1 fi @@ -50,6 +51,7 @@ TP=1 PP=1 NODES=1 ENABLE_TE=0 +ENABLE_CUDNN_FA=0 MODEL_TYPE=126M NVTE_FUSED_ATTN=1 DROPOUT=0 @@ -75,6 +77,10 @@ while [ : ]; do ENABLE_TE=1 shift 1 ;; + --enable-cudnn-fa) + ENABLE_CUDNN_FA=1 + shift 1 + ;; --enable-dropout) DROPOUT='0.1' shift 1 @@ -128,7 +134,7 @@ while [ : ]; do ;; --) shift; - break + break ;; *) echo "UNKNOWN OPTION $1" @@ -149,6 +155,7 @@ print_var NGPUS print_var OUTPUT print_var MULTIPROCESS print_var ENABLE_TE +print_var ENABLE_CUDNN_FA print_var NVTE_FUSED_ATTN print_var EVALUATE print_var DROPOUT @@ -196,10 +203,10 @@ if dcn_factor > 1: if dp % dcn_factor == 0: dcn_dp = dcn_factor dp = int(dp / dcn_factor) - elif fsdp % dcn_factor == 0: + elif fsdp % dcn_factor == 0: dcn_fsdp = dcn_factor fsdp = int(fsdp / dcn_factor) - elif pp % dcn_factor == 0: + elif pp % dcn_factor == 0: dcn_pp = dcn_factor pp = int(pp / dcn_factor) @@ -209,12 +216,12 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): USE_REPEATED_LAYER = False ICI_MESH_SHAPE = [64,1,1] MAX_STEPS = 600000 - + MAX_SEQ_LEN = 2048 VOCAB_SIZE = 50304 PACKED_INPUT = True PERCORE_BATCH_SIZE = 4 - + NUM_LAYERS = 12 NUM_HEADS = 12 MODEL_DIMS = 768 @@ -223,14 +230,14 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): TRAINABLE_POSITION_EMB = True TRAINABLE_PE_MAX_SEQ_LEN = MAX_SEQ_LEN - + USE_BIAS = True LAYERNORM_EPSILON = 1e-5 ATTEN_LOGIT_CAP = -1.0 INIT_STD = 0.023 SOFTMAX_INIT_STD = 0.023 ACTIVATION_CLS = layers.GELU - + ## optimizer-related ADAM_BETA1 = 0.9 ADAM_BETA2 = 0.95 @@ -255,7 +262,7 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): ## disable eval to avoid including eval ## in steps/sec calculation EVAL_INTERVAL_STEPS = 100000 - + def task(self): task_p = super().task() task_p = configure_gpt3_task(self, task_p) @@ -263,7 +270,7 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): task_p.train.num_train_steps = self.MAX_STEPS model_p = task_p.model - + ### compute layernorm reductions in fp32. Needed for stable training on GPUs stacked_p = model_p.lm_tpl.stacked_transformer_tpl if stacked_p.cls == layers.PipelinedTransformer: @@ -274,13 +281,13 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): transformer_layer_p.ln_tpl.reductions_in_fp32 = True transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True - + model_p.params_init = WeightInit.Gaussian(self.INIT_STD) softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD) model_p.lm_tpl.softmax_tpl.params_init = softmax_init - + model_p.apply_eval_sample_weights = True - + ## set input, residual, attention dropout to DROPOUT_PROB, remaining dropout to 0 stacked_p.dropout_prob = 0.0 stacked_p.input_dropout_prob = self.DROPOUT_PROB @@ -316,14 +323,14 @@ class LLaMA70BSyntheticSmall(BaseLLaMA, SyntheticDataset): if pp > 1: @experiment_registry.register class Synthetic126MCI(GPT126MPP, SyntheticDataset): - + ICI_MESH_SHAPE = [pp, dp, fsdp, tp] DCN_MESH_SHAPE = [dcn_pp, dcn_dp, dcn_fsdp, 1] MICROBATCH_SIZE = 2 NUM_STAGES = pp PERCORE_BATCH_SIZE = percore_batch_size FRPOP_DTYPE = dtype - + def task(self): task_p = super().task() task_p.train.always_use_train_for_model_init=False @@ -333,7 +340,7 @@ if pp > 1: else: @experiment_registry.register class Synthetic126MCI(Synthetic126M): - + ICI_MESH_SHAPE = [dp, fsdp, tp] DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1] PERCORE_BATCH_SIZE = percore_batch_size @@ -343,7 +350,7 @@ else: ## disable eval EVAL_INTERVAL_STEPS = 100000 - + def task(self): task_p = super().task() @@ -374,6 +381,10 @@ export ENABLE_TE=$ENABLE_TE export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN export VOCAB_PATH=${VOCAB_PATH:-gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model} +if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True" +fi + if [[ ${MODEL_TYPE} == "126M" ]]; then CONFIG=ci_configs.Synthetic126MCI elif [[ ${MODEL_TYPE} == "5B" ]]; then