
Add an option to test-pax.sh to enable XLA cuDNN flash attention (#1045)
Provide an option to run XLA cuDNN flash attention as an alternative to
TE cuDNN flash attention.
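
A rough usage sketch (the script path and option names come from the diff below; the model type, step count, sharding, and output path are illustrative values, not part of this commit):

    bash .github/container/test-pax.sh \
        --model-type 126M \
        --steps 100 \
        --fsdp 8 \
        --output /tmp/pax-xla-fa \
        --enable-cudnn-fa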
Cjkkkk authored Sep 18, 2024
1 parent f116054 commit 056a3b0
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions .github/container/test-pax.sh
@@ -15,7 +15,8 @@ usage() {
echo " -a, --additional-args Additional fiddle args to pass to paxml/main.py"
echo " -b, --batch-per-gpu Batch size per GPU, defaults to 4."
echo " --dtype Batch size, defaults to bfloat16."
echo " --enable-te If set, will run with env var ENABLE_TE=1."
echo " --enable-te If set, will run with env var ENABLE_TE=1."
echo " --enable-cudnn-fa If set, will use cudnn fa."
echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1."
echo " --disable-fused-attn Whether disable TE fused attention."
echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M"
@@ -26,13 +27,13 @@ usage() {
echo " --data-parallel Data parallelism to use. Defaults to 1."
echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1."
echo " --tensor-parallel Tensor parallelism to use. Defaults to 1."
echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining."
echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining."
echo " -n, --nodes Number of nodes."
echo " -h, --help Print usage."
exit $1
}

args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-cudnn-fa,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
if [[ $? -ne 0 ]]; then
exit 1
fi
@@ -50,6 +51,7 @@ TP=1
PP=1
NODES=1
ENABLE_TE=0
ENABLE_CUDNN_FA=0
MODEL_TYPE=126M
NVTE_FUSED_ATTN=1
DROPOUT=0
@@ -75,6 +77,10 @@ while [ : ]; do
ENABLE_TE=1
shift 1
;;
--enable-cudnn-fa)
ENABLE_CUDNN_FA=1
shift 1
;;
--enable-dropout)
DROPOUT='0.1'
shift 1
@@ -128,7 +134,7 @@ while [ : ]; do
;;
--)
shift;
break
break
;;
*)
echo "UNKNOWN OPTION $1"
@@ -149,6 +155,7 @@ print_var NGPUS
print_var OUTPUT
print_var MULTIPROCESS
print_var ENABLE_TE
print_var ENABLE_CUDNN_FA
print_var NVTE_FUSED_ATTN
print_var EVALUATE
print_var DROPOUT
@@ -196,10 +203,10 @@ if dcn_factor > 1:
if dp % dcn_factor == 0:
dcn_dp = dcn_factor
dp = int(dp / dcn_factor)
elif fsdp % dcn_factor == 0:
elif fsdp % dcn_factor == 0:
dcn_fsdp = dcn_factor
fsdp = int(fsdp / dcn_factor)
elif pp % dcn_factor == 0:
elif pp % dcn_factor == 0:
dcn_pp = dcn_factor
pp = int(pp / dcn_factor)
@@ -209,12 +216,12 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
USE_REPEATED_LAYER = False
ICI_MESH_SHAPE = [64,1,1]
MAX_STEPS = 600000
MAX_SEQ_LEN = 2048
VOCAB_SIZE = 50304
PACKED_INPUT = True
PERCORE_BATCH_SIZE = 4
NUM_LAYERS = 12
NUM_HEADS = 12
MODEL_DIMS = 768
@@ -223,14 +230,14 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
TRAINABLE_POSITION_EMB = True
TRAINABLE_PE_MAX_SEQ_LEN = MAX_SEQ_LEN
USE_BIAS = True
LAYERNORM_EPSILON = 1e-5
ATTEN_LOGIT_CAP = -1.0
INIT_STD = 0.023
SOFTMAX_INIT_STD = 0.023
ACTIVATION_CLS = layers.GELU
## optimizer-related
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.95
@@ -255,15 +262,15 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
## disable eval to avoid including eval
## in steps/sec calculation
EVAL_INTERVAL_STEPS = 100000
def task(self):
task_p = super().task()
task_p = configure_gpt3_task(self, task_p)
task_p.train.num_train_steps = self.MAX_STEPS
model_p = task_p.model
### compute layernorm reductions in fp32. Needed for stable training on GPUs
stacked_p = model_p.lm_tpl.stacked_transformer_tpl
if stacked_p.cls == layers.PipelinedTransformer:
@@ -274,13 +281,13 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
transformer_layer_p.ln_tpl.reductions_in_fp32 = True
transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True
model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
model_p.lm_tpl.softmax_tpl.params_init = softmax_init
model_p.apply_eval_sample_weights = True
## set input, residual, attention dropout to DROPOUT_PROB, remaining dropout to 0
stacked_p.dropout_prob = 0.0
stacked_p.input_dropout_prob = self.DROPOUT_PROB
@@ -316,14 +323,14 @@ class LLaMA70BSyntheticSmall(BaseLLaMA, SyntheticDataset):
if pp > 1:
@experiment_registry.register
class Synthetic126MCI(GPT126MPP, SyntheticDataset):
ICI_MESH_SHAPE = [pp, dp, fsdp, tp]
DCN_MESH_SHAPE = [dcn_pp, dcn_dp, dcn_fsdp, 1]
MICROBATCH_SIZE = 2
NUM_STAGES = pp
PERCORE_BATCH_SIZE = percore_batch_size
FPROP_DTYPE = dtype
def task(self):
task_p = super().task()
task_p.train.always_use_train_for_model_init=False
@@ -333,7 +340,7 @@ if pp > 1:
else:
@experiment_registry.register
class Synthetic126MCI(Synthetic126M):
ICI_MESH_SHAPE = [dp, fsdp, tp]
DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1]
PERCORE_BATCH_SIZE = percore_batch_size
@@ -343,7 +350,7 @@ else:
## disable eval
EVAL_INTERVAL_STEPS = 100000
def task(self):
task_p = super().task()
@@ -374,6 +381,10 @@ export ENABLE_TE=$ENABLE_TE
export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN
export VOCAB_PATH=${VOCAB_PATH:-gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model}

if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then
ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True"
fi

if [[ ${MODEL_TYPE} == "126M" ]]; then
CONFIG=ci_configs.Synthetic126MCI
elif [[ ${MODEL_TYPE} == "5B" ]]; then
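
Condensed, the new flag's path through the script looks roughly like the sketch below (assembled from the hunks above; the surrounding option parsing and launch logic is elided, and the USE_CUDNN_FLASH_ATTENTION fiddle field is assumed to be defined by the Pax experiment configs this script targets):

    # Off by default; flipped on by the new --enable-cudnn-fa option.
    ENABLE_CUDNN_FA=0

    --enable-cudnn-fa)
        ENABLE_CUDNN_FA=1
        shift 1
        ;;

    # When enabled, pass a fiddle override to paxml/main.py so the XLA cuDNN
    # flash-attention path is used instead of the TE one.
    if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then
        ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True"
    fi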
