-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add FMHA PAXML test #830
base: main
Are you sure you want to change the base?
Add FMHA PAXML test #830
Changes from all commits
a9a828b
69f9db9
35565ce
42566a6
495ae32
23a19d8
4be1f40
cb9d7d5
f44cdef
da69dbd
a6622c8
b67229b
f7618bf
ab22fcc
dbb999d
e0f4b4e
1898b32
c1ff8ae
0ef811a
ca6e2e9
7084812
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,12 +17,13 @@ usage() { | |
echo " --dtype Batch size, defaults to bfloat16." | ||
echo " --enable-te If set, will run with env var ENABLE_TE=1." | ||
echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1." | ||
echo " --disable-fused-attn Whether disable TE fused attention." | ||
echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M" | ||
echo " --evaluate Whether to test evaluation rather than training." | ||
echo " -s, --steps Number of steps to run, defaults to 500." | ||
echo " --multiprocess Enable the multiprocess GPU mode." | ||
echo " -o, --output NAME Name for the output folder, a temporary folder will be created if none specified." | ||
echo " --save-hlo {0, 1} 1 to save the dumped hlo, 0 to remove the hlo dumped folder" | ||
echo " --enable-fmha {0, 1} 1 to enable fmha testing, 0 to run test without fmha; default is 0" | ||
echo " --data-parallel Data parallelism to use. Defaults to 1." | ||
echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1." | ||
echo " --tensor-parallel Tensor parallelism to use. Defaults to 1." | ||
|
@@ -32,7 +33,8 @@ usage() { | |
exit $1 | ||
} | ||
|
||
args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") | ||
args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,model-type:,enable-fmha:,evaluate,steps:,help,multiprocess,output:,save-hlo:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") | ||
|
||
if [[ $? -ne 0 ]]; then | ||
exit $1 | ||
fi | ||
|
@@ -55,6 +57,8 @@ NVTE_FUSED_ATTN=1 | |
DROPOUT=0 | ||
EVALUATE=0 | ||
ADDITIONAL_ARGS="" | ||
ENABLE_FMHA=${ENABLE_FMHA:-1} | ||
SAVE_HLO=${SAVE_HLO:-0} | ||
|
||
eval set -- "$args" | ||
while [ : ]; do | ||
|
@@ -75,14 +79,15 @@ while [ : ]; do | |
ENABLE_TE=1 | ||
shift 1 | ||
;; | ||
--enable-fmha) | ||
ENABLE_FMHA="$2" | ||
NVTE_FUSED_ATTN="$2" | ||
shift 2 | ||
;; | ||
--enable-dropout) | ||
DROPOUT='0.1' | ||
shift 1 | ||
;; | ||
--disable-fused-attn) | ||
NVTE_FUSED_ATTN=0 | ||
shift 1 | ||
;; | ||
--model-type) | ||
MODEL_TYPE=$2 | ||
shift 2 | ||
|
@@ -103,6 +108,10 @@ while [ : ]; do | |
OUTPUT=$2 | ||
shift 2 | ||
;; | ||
--save-hlo) | ||
SAVE_HLO="$2" | ||
shift 2 | ||
;; | ||
--data-parallel) | ||
DP="$2" | ||
shift 2 | ||
|
@@ -136,6 +145,21 @@ while [ : ]; do | |
esac | ||
done | ||
|
||
# Set hlo dump folder after output folder is set. | ||
HLO_DIR=${OUTPUT}/hlo | ||
export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please explain the logic here: is There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Dumping the HLO is enabled by default in BASE_XLA_FLAGS, and BASE_XLA_FLAGS is prepended to the XLA_FLAGS env var. If the user wants to test FMHA, then BASE_XLA_FLAGS_FMHA is set and prepended to XLA_FLAGS as well. The idea is to preserve whatever XLA_FLAGS was set to before this script was executed. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK, let me clarify my question:
Meaning, that if Is that expected behaviour? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And why do you There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The mechanism was added as per the review comment of same PR for t5x: #442 (comment) |
||
export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}" | ||
echo "HLO will be dumped in ${HLO_DIR} dir." | ||
|
||
## Setting the env variables for FMHA | ||
if [[ "$ENABLE_FMHA" -eq "1" ]]; then | ||
echo "Setting XLA FMHA Flags"; | ||
export BASE_XLA_FLAGS_FMHA="${BASE_XLA_FLAGS_FMHA:---xla_gpu_fused_attention_use_cudnn_rng=true --xla_gpu_enable_cudnn_fmha=true}" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here as above |
||
export XLA_FLAGS="${BASE_XLA_FLAGS_FMHA} ${XLA_FLAGS:-}" | ||
fi | ||
|
||
echo "XLA FLAGS: $XLA_FLAGS" | ||
|
||
# # Set derived variables | ||
|
||
GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU') | ||
|
@@ -149,8 +173,10 @@ print_var NGPUS | |
print_var OUTPUT | ||
print_var MULTIPROCESS | ||
print_var ENABLE_TE | ||
print_var ENABLE_FMHA | ||
print_var NVTE_FUSED_ATTN | ||
print_var EVALUATE | ||
print_var SAVE_HLO | ||
print_var DROPOUT | ||
print_var DP | ||
print_var FSDP | ||
|
@@ -422,5 +448,25 @@ else | |
$([[ $MULTIPROCESS != 0 ]] && echo --multiprocess_gpu) | ||
fi | ||
|
||
echo "Checking for FMHA instructions in HLO!" | ||
|
||
if [[ "$ENABLE_FMHA" -eq "1" ]]; then | ||
## Check if fmha instructions are present in the HLO dumped file or not. | ||
fmha_regex="fmha[-bmm]?[-scale]?[-bias]?[-mask]?[-softmax]?[-dropout]?[-bmm]?[-backward]?*" | ||
result=$(grep -irlnE "$fmha_regex" "${HLO_DIR}/"*.txt) | ||
|
||
if [ -z "$result" ]; then | ||
echo "E: No FMHA instructions were found in the hlo files!" | ||
exit 1 | ||
else | ||
echo -e "Found FMHA instructions in the following HLO files: \n $result" | ||
fi | ||
fi | ||
|
||
if [[ $SAVE_HLO -eq 0 ]]; then | ||
rm -rf $HLO_DIR | ||
echo "Removed dumped HLO directory!" | ||
fi | ||
|
||
set +x | ||
echo "Output at ${OUTPUT}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: default doesn't match below