diff --git a/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env new file mode 100644 index 000000000..d999f5b5e --- /dev/null +++ b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env @@ -0,0 +1,24 @@ +set -x +NUM_NODES=1 +NUM_GPUS=8 +THRESHOLD_BYTES=1073741824 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_gpu_enable_triton_gemm=false \ + --xla_gpu_graph_level=0 \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_all_gather_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS))) \ + --xla_gpu_reduce_scatter_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS*2))) \ + --xla_gpu_enable_pipelined_all_gather=true \ + --xla_gpu_enable_pipelined_reduce_scatter=true \ + --xla_gpu_enable_pipelined_all_reduce=true \ + --xla_gpu_enable_while_loop_double_buffering=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_enable_all_gather_combine_by_dim=false \ + --xla_gpu_enable_reduce_scatter_combine_by_dim=false \ + --xla_disable_hlo_passes=rematerialization \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +unset NUM_NODES NUM_GPUS THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/common.env b/rosetta/rosetta/projects/pax/xla_flags/common.env new file mode 100644 index 000000000..26c819143 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/common.env @@ -0,0 +1,13 @@ +set -x +THRESHOLD_BYTES=51200 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +unset THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env new file mode 100644 index 000000000..e5b97b466 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env @@ -0,0 +1,14 @@ +set -x +THRESHOLD_BYTES=33554432 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + --xla_gpu_enable_cudnn_fmha=false \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +unset THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env new file mode 100644 index 000000000..e48b76dcf --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env @@ -0,0 +1,25 @@ +set -x +ALL_REDUCE_THRESHOLD_BYTES=3221225472 +ALL_GATHER_THRESHOLD_BYTES=3221225472 +REDUCE_SCATTER_THRESHOLD_BYTES=402653184 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${ALL_REDUCE_THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + --xla_gpu_all_gather_combine_threshold_bytes=${ALL_GATHER_THRESHOLD_BYTES} \ + --xla_gpu_reduce_scatter_combine_threshold_bytes=${REDUCE_SCATTER_THRESHOLD_BYTES} \ + --xla_gpu_enable_pipelined_all_gather=true \ + --xla_gpu_enable_pipelined_reduce_scatter=true \ + --xla_gpu_enable_pipelined_all_reduce=true \ + --xla_gpu_enable_while_loop_double_buffering=true \ + --xla_gpu_enable_all_gather_combine_by_dim=false \ + --xla_gpu_enable_reduce_scatter_combine_by_dim=false \ + --xla_disable_hlo_passes=rematerialization \ + --xla_gpu_enable_custom_fusions=true + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +unset ALL_REDUCE_THRESHOLD_BYTES ALL_GATHER_THRESHOLD_BYTES REDUCE_SCATTER_THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env new file mode 100644 index 000000000..d1568e92c --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env new file mode 100644 index 000000000..bd4ae50d5 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +set +x diff --git a/rosetta/rosetta/projects/t5x/xla_flags/t5.env b/rosetta/rosetta/projects/t5x/xla_flags/t5.env new file mode 100644 index 000000000..bd4ae50d5 --- /dev/null +++ b/rosetta/rosetta/projects/t5x/xla_flags/t5.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +set +x diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env new file mode 100644 index 000000000..45140ed88 --- /dev/null +++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.75 +set +x diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env new file mode 100644 index 000000000..882c9e9e8 --- /dev/null +++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +set +x