diff --git a/rosetta/rosetta/projects/maxtext/README.md b/rosetta/rosetta/projects/maxtext/README.md
index ce3e05891..28b69d5c8 100644
--- a/rosetta/rosetta/projects/maxtext/README.md
+++ b/rosetta/rosetta/projects/maxtext/README.md
@@ -2,7 +2,7 @@
 [MaxText](https://github.com/google/maxtext) is a high-performance, scalable LLM framework from Google, written in Python and JAX. We support upstream MaxText and provide containers that support the MaxText main branch out-of-the-box. For training, we strongly recommend using the proper XLA flags listed below.
 
 ## Hardware and Software Specifications
-Functionality and performance have been validated on NVIDIA DGX H100 (8x H100 80G) nodes; We provide both singlenode and multinode pre-training support. If running on a machine with less than 80G memory, some of the default configurations may run out of memory; if you run out of memory and have more GPUs available, increase your GPU count and decrease your batch size per GPU.
+Functionality and performance have been validated on NVIDIA DGX H100 (8x H100 80G) nodes; please refer to the [Configs](#configs) section below for some initial configs and performance numbers. We will continue to populate it with more models and configs. We provide both single-node and multi-node pre-training support. If running on a machine with less than 80 GB of GPU memory, some of the default configurations may run out of memory; if you run out of memory and have more GPUs available, increase your GPU count and decrease your batch size per GPU.
 
 The [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) is required to run the subsequent commands with GPU support. Ensure the NVIDIA Container Toolkit is installed before proceeding.
 
@@ -45,7 +45,7 @@ python3 MaxText/train.py \
     dataset_path=local \
     dataset_type=synthetic \
     attention=dot_product \
-    hardware=gpu
+    hardware=gpu \
     run_name=${YOUR_JOB_NAME}
 ```
 
@@ -87,6 +87,17 @@ XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_reduce_scatter_combine_by_dim=false
 --xla_disable_hlo_passes=rematerialization"
 ```
+
+# Configs
+### LLaMA
+We have run some initial performance and functionality tests with the [LLaMA2-7B](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/) model. The table below shows the current performance of the given configs. Experiments were run using NVIDIA DGX H100 80G nodes.
+
+| Size | GPU          | Precision | Sequence Length | #GPUs | BS / GPU | DP | FSDP | TP | GBS | Attention | Remat Policy  | Scan | Step Time (s) | Sequences/Sec |
+| ---- | ------------ | --------- | --------------- | ----- | -------- | -- | ---- | -- | --- | --------- | ------------- | ---- | ------------- | ------------- |
+| 7B   | H100 80G SXM | BF16      | 4096            | 8     | 2        | 1  | 8    | 1  | 16  | Flash     | minimal_flash | Off  | 0.721         | 22.19         |
+
+Please refer to the [example run script](scripts/example_slurm.sub) for more details. We will continue to add more models and associated performance metrics.
+
 # Notes
 1. The only change we need to support multiprocessing is to pin tensorflow and tensorflow-text to version 2.13.0.
 2. In order to remove extra copies introduced by DUS (dynamic update slice) when used in conjunction with custom NVIDIA kernels (like cuBLAS for GEMMs), the `--xla_gpu_enable_custom_fusions` and `--xla_gpu_enable_address_computation_fusion` flags were introduced. However, the current XLA has some limitations, and using these flags can sometimes lead to errors. So, in this release, it is advised to turn off these two flags:
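Putting the README changes above together (the corrected `hardware=gpu \` line continuation, the recommended XLA flags, and the two fusion flags from note 2 disabled), a single-node invocation could look like the sketch below. The `MaxText/configs/base.yml` positional argument, `per_device_batch_size=2` (chosen to match the 7B table entry), and `YOUR_JOB_NAME` are illustrative assumptions rather than values taken verbatim from this diff.

```bash
# Sketch: recommended XLA flags plus the single-node train command.
# The two fusion flags from note 2 are explicitly disabled here.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true \
  --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
  --xla_disable_hlo_passes=rematerialization \
  --xla_gpu_enable_custom_fusions=false \
  --xla_gpu_enable_address_computation_fusion=false"

# base.yml and per_device_batch_size are assumptions based on standard
# MaxText usage; adjust for your checkout and memory budget.
python3 MaxText/train.py \
    MaxText/configs/base.yml \
    base_output_directory=local_train \
    dataset_path=local \
    dataset_type=synthetic \
    attention=dot_product \
    hardware=gpu \
    per_device_batch_size=2 \
    run_name=${YOUR_JOB_NAME}
```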
diff --git a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
index 2980a3e44..b934b75f4 100644
--- a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
+++ b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
@@ -103,7 +103,7 @@ echo "*******STARTING********" \
     base_output_directory=local_train \
     dataset_path=local \
     dataset_type=synthetic \
-    hardware=gpu_multiprocess
+    hardware=gpu_multiprocess \
     run_name=$RUN_NAME
 EOF
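The slurm script receives the same trailing-backslash fix so that `run_name=$RUN_NAME` stays part of the multi-node `train.py` command inside the heredoc. A hypothetical submission is sketched below; `RUN_NAME` is the only variable visible in this hunk, so the node count and any partition, account, or container settings are cluster-specific assumptions.

```bash
# Hypothetical submission of the example script. sbatch exports the
# caller's environment by default, so RUN_NAME reaches the job; -N sets
# the (assumed) node count for a multi-node pre-training run.
export RUN_NAME=llama2-7b-pretrain
sbatch -N 2 rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
```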