Skip to content

Commit

Permalink
Updating 24.04-devel branch with maxtext perf (#789)
Browse files Browse the repository at this point in the history
Co-authored-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
  • Loading branch information
terrykong and kocchop authored May 2, 2024
1 parent b024ad5 commit 00f755e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
15 changes: 13 additions & 2 deletions rosetta/rosetta/projects/maxtext/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
[MaxText](https://github.com/google/maxtext) is high performance scalable LLM framework by Google written in Python and JAX. We support the upstream maxtext and have containers that can support the MaxText main branch out-of-the-box. While training, we strongly recommend to use propoer XLA flags pointed below.

## Hardware and Software Specifications
Functionality and performance have been validated on NVIDIA DGX H100 (8x H100 80G) nodes; We provide both singlenode and multinode pre-training support. If running on a machine with less than 80G memory, some of the default configurations may run out of memory; if you run out of memory and have more GPUs available, increase your GPU count and decrease your batch size per GPU.
Functionality and performance have been validated on NVIDIA DGX H100 (8x H100 80G) nodes; please refer to the [Configs](#configs) section below for some initial configs and performance numbers. We will continue to populate it with more models and configs. We provide both singlenode and multinode pre-training support. If running on a machine with less than 80G memory, some of the default configurations may run out of memory; if you run out of memory and have more GPUs available, increase your GPU count and decrease your batch size per GPU.

The [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) is required to run the subsequent commands with GPU support. Ensure the NVIDIA Container Toolkit is installed before proceeding.

Expand Down Expand Up @@ -45,7 +45,7 @@ python3 MaxText/train.py \
dataset_path=local \
dataset_type=synthetic \
attention=dot_product \
hardware=gpu
hardware=gpu \
run_name=${YOUR_JOB_NAME}
```

Expand Down Expand Up @@ -87,6 +87,17 @@ XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"
```

# Configs
### LLaMA
We have run some intial performance and functionality tests with [LLaMA2-7B](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/) model. The table below shows the current performance of the given configs. Experiments were run using NVIDIA DGX H100 80G nodes.

| Size | GPU | Precision | Sequence Length | #GPUs | BS / GPU | DP | FSDP | TP | GBS | Attention | Remat Policy | Scan | Step Time (s) | Sequences/Sec |
| ---- | ------------ | --------- | --------------- | ----- | -------- | -- | ---- | -- | --- | --------- | ------------ | ---- | ------------- | ------------- |
| 7B | H100 80G SXM | BF16 | 4096 | 8 | 2 | 1 | 8 | 1 | 16 | Flash | minimal_flash| Off | 0.721 | 22.19 |

Please refer to the [example run script](scripts/example_slurm.sub) for more details. We will continue to add more models and associated performance metrics.

# Notes
1. The only changes we need to support multiprocessing is to pin tensorflow and tensorflow-text to 2.13.0 version.
2. In order to remove extra copies introduced by DUS (dynamic update slice) when used in conjunction with custom NVIDIA kernels (like cuBLAS for GEMMs), the `--xla_gpu_enable_custom_fusions` and `--xla_gpu_enable_address_computation_fusion` flags were introduced. However, the current XLA has some limitation and sometimes using these flags lead to error. So, in this release, it is advised to turn off these two flags:
Expand Down
2 changes: 1 addition & 1 deletion rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ echo "*******STARTING********" \
base_output_directory=local_train \
dataset_path=local \
dataset_type=synthetic \
hardware=gpu_multiprocess
hardware=gpu_multiprocess \
run_name=$RUN_NAME
EOF

Expand Down

0 comments on commit 00f755e

Please sign in to comment.