From 1a3febb377a5a0d7ff3344b3d7844ae6f56dacc7 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 26 Sep 2024 16:05:35 -0700 Subject: [PATCH] Model XLA Flags (#1052) Moves XLA flags from model CI into their own files that can be sourced. Each file can be sourced and will print what it sets. Some files source other files, which was intentional to avoid introducing sim-links into the repo, which can sometimes have platform issues (like on windows). --------- Signed-off-by: Terry Kong --- .../maxtext/xla_flags/llama2-7b-1N8G.env | 24 ++++++++++++++++++ .../rosetta/projects/pax/xla_flags/common.env | 13 ++++++++++ .../projects/pax/xla_flags/glam-126m64e.env | 3 +++ .../projects/pax/xla_flags/glam-64b64e.env | 3 +++ .../projects/pax/xla_flags/gpt-126m.env | 14 +++++++++++ .../projects/pax/xla_flags/gpt-175b.env | 3 +++ .../rosetta/projects/pax/xla_flags/gpt-5b.env | 3 +++ .../projects/pax/xla_flags/grok-proxy.env | 25 +++++++++++++++++++ .../projects/pax/xla_flags/llama-70b.env | 3 +++ .../projects/pax/xla_flags/llama-7b-lora.env | 4 +++ .../projects/pax/xla_flags/llama-7b.env | 4 +++ rosetta/rosetta/projects/t5x/xla_flags/t5.env | 4 +++ .../vit/xla_flags/vit-base-highgbs.env | 4 +++ .../projects/vit/xla_flags/vit-base.env | 4 +++ 14 files changed, 111 insertions(+) create mode 100644 rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/common.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/llama-70b.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/llama-7b.env create mode 100644 rosetta/rosetta/projects/t5x/xla_flags/t5.env create mode 100644 rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env create mode 100644 rosetta/rosetta/projects/vit/xla_flags/vit-base.env diff --git a/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env new file mode 100644 index 000000000..d999f5b5e --- /dev/null +++ b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env @@ -0,0 +1,24 @@ +set -x +NUM_NODES=1 +NUM_GPUS=8 +THRESHOLD_BYTES=1073741824 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_gpu_enable_triton_gemm=false \ + --xla_gpu_graph_level=0 \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_all_gather_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS))) \ + --xla_gpu_reduce_scatter_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS*2))) \ + --xla_gpu_enable_pipelined_all_gather=true \ + --xla_gpu_enable_pipelined_reduce_scatter=true \ + --xla_gpu_enable_pipelined_all_reduce=true \ + --xla_gpu_enable_while_loop_double_buffering=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_enable_all_gather_combine_by_dim=false \ + --xla_gpu_enable_reduce_scatter_combine_by_dim=false \ + --xla_disable_hlo_passes=rematerialization \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +unset NUM_NODES NUM_GPUS THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/common.env b/rosetta/rosetta/projects/pax/xla_flags/common.env new file mode 100644 index 000000000..26c819143 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/common.env @@ -0,0 +1,13 @@ +set -x +THRESHOLD_BYTES=51200 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +unset THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env new file mode 100644 index 000000000..e5b97b466 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env @@ -0,0 +1,14 @@ +set -x +THRESHOLD_BYTES=33554432 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + --xla_gpu_enable_cudnn_fmha=false \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +unset THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env new file mode 100644 index 000000000..e48b76dcf --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env @@ -0,0 +1,25 @@ +set -x +ALL_REDUCE_THRESHOLD_BYTES=3221225472 +ALL_GATHER_THRESHOLD_BYTES=3221225472 +REDUCE_SCATTER_THRESHOLD_BYTES=402653184 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${ALL_REDUCE_THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + --xla_gpu_all_gather_combine_threshold_bytes=${ALL_GATHER_THRESHOLD_BYTES} \ + --xla_gpu_reduce_scatter_combine_threshold_bytes=${REDUCE_SCATTER_THRESHOLD_BYTES} \ + --xla_gpu_enable_pipelined_all_gather=true \ + --xla_gpu_enable_pipelined_reduce_scatter=true \ + --xla_gpu_enable_pipelined_all_reduce=true \ + --xla_gpu_enable_while_loop_double_buffering=true \ + --xla_gpu_enable_all_gather_combine_by_dim=false \ + --xla_gpu_enable_reduce_scatter_combine_by_dim=false \ + --xla_disable_hlo_passes=rematerialization \ + --xla_gpu_enable_custom_fusions=true + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +unset ALL_REDUCE_THRESHOLD_BYTES ALL_GATHER_THRESHOLD_BYTES REDUCE_SCATTER_THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env new file mode 100644 index 000000000..d1568e92c --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env new file mode 100644 index 000000000..bd4ae50d5 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +set +x diff --git a/rosetta/rosetta/projects/t5x/xla_flags/t5.env b/rosetta/rosetta/projects/t5x/xla_flags/t5.env new file mode 100644 index 000000000..bd4ae50d5 --- /dev/null +++ b/rosetta/rosetta/projects/t5x/xla_flags/t5.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +set +x diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env new file mode 100644 index 000000000..45140ed88 --- /dev/null +++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.75 +set +x diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env new file mode 100644 index 000000000..882c9e9e8 --- /dev/null +++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +set +x