From de7e105fa2acf7e3f2c01aa991fe9898f0a68bf6 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Tue, 10 Dec 2024 04:04:36 +0000 Subject: [PATCH] Add Pathways Support to Benchmark Runner --- benchmarks/benchmark_runner.py | 66 ++++++- benchmarks/maxtext_trillium_model_configs.py | 190 +++++++++++++++++++ benchmarks/maxtext_xpk_runner.py | 119 ++++++++---- 3 files changed, 338 insertions(+), 37 deletions(-) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index 8e629fe55..256a3e1b2 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -28,6 +28,7 @@ from maxtext_xpk_runner import BenchmarkRunner from maxtext_xpk_runner import HWConfig from maxtext_xpk_runner import SWconfig +from maxtext_xpk_runner import PathwaysConfig from maxtext_xpk_runner import xpk_benchmark_runner from maxtext_xpk_runner import XpkConfig @@ -86,6 +87,11 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_7b_4096', 'llama2_70b_4096', 'llama2_70b_4096_real_data', + 'llama2_70b_4096_pw_long_run', + 'llama2_70b_4096_real_data_pw_long_run', + 'llama2_70b_4096_pw_rd_tfds', + 'llama2_70b_4096_synthetic_pw_lr', + 'llama2_70b_4096_synthetic', 'llama3_70b_8192', 'llama3_1_405b_8192_fsdp_dcn', 'mixtral_8x7b_dropped', @@ -103,6 +109,11 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_7b_4096 ' 'llama2_70b_4096 ' 'llama2_70b_4096_real_data ' + 'llama2_70b_4096_pw_long_run ' + 'llama2_70b_4096_real_data_pw_long_run ' + 'llama2_70b_4096_pw_rd_tfds ' + 'llama2_70b_4096_synthetic_pw_lr ' + 'llama2_70b_4096_synthetic ' 'llama3_1_405b_8192_fsdp_dcn ' 'mixtral_8x7b_dropped ' 'mixtral_8x7b_dropped_int8 ' @@ -124,6 +135,51 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): default='maxtext_base_image', help='version of base docker image to be benchmarked command.', ) + custom_parser.add_argument( + '--pathways_server_image', + type=str, + default=( + 'us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest' + ), + help='version of pathways server image to be benchmarked command.', + ) + custom_parser.add_argument( + '--pathways_proxy_image', + type=str, + default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest', + help='version of pathways proxy image to be benchmarked command.', + ) + custom_parser.add_argument( + '--pathways_runner_image', + type=str, + default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest', + help='version of pathways runner image to be benchmarked command.', + ) + custom_parser.add_argument( + '--use_pathways', + type=bool, + default=False, + help='whether to use pathways or not.', + ) + custom_parser.add_argument( + '--xpk_path', + type=str, + default='~/xpk', + help='path to xpk dir.', + ) + custom_parser.add_argument( + '--priority', + type=str, + default='medium', + help='Priority the XPK workload should run with.', + ) + custom_parser.add_argument( + '--max_restarts', + type=int, + default=0, + help='Number of restarts to attempt.', + ) + def main() -> None: parser = argparse.ArgumentParser( @@ -139,11 +195,19 @@ def main() -> None: num_slices=options.num_slices, device_type=options.device_type, base_output_directory=options.base_output_directory, + priority=options.priority, + max_restarts=options.max_restarts, ) v6e_env_configs = SWconfig( base_docker_image=options.base_docker_image, libtpu_version=options.libtpu_version, + pathways_config=PathwaysConfig( + use_pathways=options.use_pathways, + server_image=options.pathways_server_image, + proxy_image=options.pathways_proxy_image, + runner_image=options.pathways_runner_image, + ), ) v6e_256_configs = HWConfig( @@ -159,7 +223,7 @@ def main() -> None: hardware_config=v6e_256_configs, ) - xpk_benchmark_runner(cluster_config, [model_runner]) + xpk_benchmark_runner(cluster_config, [model_runner], options.xpk_path) if __name__ == '__main__': diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 952d803ca..16c6d9764 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -291,6 +291,46 @@ class MaxTextModel: ), ) + +llama2_70b_4096_real_data_pw_long_run = MaxTextModel( + model_name="llama2-70b-4096-rd-pw-lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "reuse_example_batch": 0, + "profiler": "xplane", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "tfds", + "tokenizer_path": "assets/tokenizer.llama2", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "steps": 1000000, + + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + # ici_fsdp_transpose_parallelism gives one TFLOP better performance. llama2_70b_4096 = MaxTextModel( model_name="llama2-70b-4096", @@ -319,6 +359,151 @@ class MaxTextModel: + xla_flags_library.CF_FOR_ALL_GATHER ), ) +llama2_70b_4096_synthetic = MaxTextModel( + model_name="llama2_70b_4096_synthetic", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + +llama2_70b_4096_synthetic_pw_lr = MaxTextModel( + model_name="llama2_70b_4096_synthetic_pw_lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + # "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "steps": 1000000, + + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + +llama2_70b_4096_pw_long_run = MaxTextModel( + model_name="llama2-70b-4096-pw-lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "steps": 1000000, + + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + +llama2_70b_4096_pw_rd_tfds = MaxTextModel( + model_name="llama2_70b_4096_pw_rd_tfds", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + llama3_8b_8192 = MaxTextModel( model_name="llama3-8b-8192", @@ -695,9 +880,14 @@ class MaxTextModel: gpt_3_175b, llama2_7b_4096, llama2_70b_4096, + llama2_70b_4096_pw_long_run, llama2_70b_4096_real_data, + llama2_70b_4096_real_data_pw_long_run, + llama2_70b_4096_pw_rd_tfds, llama3_8b_8192, # Not Optimizied yet llama3_70b_8192, # Not Optimizied yet + llama2_70b_4096_synthetic_pw_lr, + llama2_70b_4096_synthetic, llama3_1_405b_8192_fsdp_dcn, llama3_1_8b_8192, llama3_1_70b_8192, diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 8c728da3e..00773e5e1 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -49,6 +49,16 @@ class XpkConfig: num_slices: str device_type: str base_output_directory: str + priority: str + max_restarts: int + + +@dataclasses.dataclass +class PathwaysConfig: + use_pathways: bool + server_image: str + proxy_image: str + runner_image: str @dataclasses.dataclass @@ -61,6 +71,7 @@ class HWConfig: class SWconfig: libtpu_version: str base_docker_image: str + pathways_config: PathwaysConfig @dataclasses.dataclass @@ -257,21 +268,25 @@ def run_command_with_updates(command, task, verbose=True) -> int: def build_user_command( + name: str, model: model_configs.MaxTextModel, num_slices: int, num_steps: int, libtpu_type: LibTpuType, libtpu_date: str, cluster_config: XpkConfig, - base_output_directory: str, + base_output_directory: str, buffer_size: int, + pathways_config: PathwaysConfig = None, ): config_tuning_params = '' for key, value in model.tuning_params.items(): config_tuning_params += f'{key}={value} ' install_libtpu_cmd = '' - if libtpu_type == LibTpuType.NIGHTLY: + if pathways_config.use_pathways: + pass + elif libtpu_type == LibTpuType.NIGHTLY: install_libtpu_cmd += ( f' pip install libtpu-nightly==0.1.dev{libtpu_date} -f' ' https://storage.googleapis.com/libtpu-releases/index.html &&' @@ -288,35 +303,30 @@ def build_user_command( # model.xla_flags += ' --grpc_enable_rpc_receive_coalescing=true' # model.xla_flags += ' --grpc_experiments=tcp_rcv_lowat' + # Use single quotes for LIBTPU_INIT_ARGS and escape inner single quotes libtpu_flags = f"LIBTPU_INIT_ARGS='{model.xla_flags}'" - - return ( - # f'python3 -m pip install google-cloud-aiplatform==v1.61.0 &&' - # f'pip install -U "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html &&' - # f' pip install https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-0.4.27.dev20240501-cp310-cp310-manylinux2014_x86_64.whl &&' - # f' pip install git+https://github.com/jax-ml/jax.git@57bfe81260545556ec22509347f7ced112496200 &&' - f' {install_libtpu_cmd}' - # f' mv libtpu.so /lib/ &&' - # f' export TPU_LIBRARY_PATH=$PWD/libtpu.so &&' - f' echo {libtpu_flags} &&' - # f' echo {model.tuning_params["sa_block_q"]}-q-dq-{model.tuning_params["sa_block_q_dq"]}-q-dkv-{model.tuning_params["sa_block_q_dkv"]} &&' - # f' echo {model.tuning_params["ici_fsdp_parallelism"]} {model.tuning_params["ici_tensor_parallelism"]} &&' - f' export JAX_PLATFORMS=tpu,cpu &&' - # f' export JAX_DEBUG_NANS=True &&' - # f' export TPU_MEGACORE=megachip_tccontrol &&' - # f' echo TPU MEGACORE: $TPU_MEGACORE &&' - f' export TPU_PREMAPPED_BUFFER_SIZE={buffer_size} &&' - f' echo {buffer_size} &&' - f' export ENABLE_PJRT_COMPATIBILITY=true &&' - f' export {libtpu_flags} && ' - ' python3 MaxText/train.py MaxText/configs/base.yml' - f' {config_tuning_params} steps={num_steps} enable_checkpointing=false' - f' model_name={model.model_type}' - f' base_output_directory={base_output_directory}' - f' use_vertex_tensorboard=false' - ' vertex_tensorboard_project="" vertex_tensorboard_region=""' - f' run_name="{model.model_name}-{num_slices}-{libtpu_date}"' - ) + jax_platforms = 'proxy' if pathways_config.use_pathways else 'tpu,cpu' + vertex_tensorboard = 'vertex_tensorboard_project="" vertex_tensorboard_region=""' if pathways_config.use_pathways else '' + + # Construct the command string with proper formatting and line continuations + command = ' '.join([ + f'{install_libtpu_cmd}', + f'echo {libtpu_flags} &&' if not pathways_config.use_pathways else '', + f'export {libtpu_flags} &&' if not pathways_config.use_pathways else '', + 'export ENABLE_PATHWAYS_PERSISTENCE=1 &&', + f'export JAX_PLATFORMS={jax_platforms} &&', + f'export TPU_PREMAPPED_BUFFER_SIZE={buffer_size} &&', + f'echo {buffer_size} &&', + 'export ENABLE_PJRT_COMPATIBILITY=true &&', + 'python3 MaxText/train.py MaxText/configs/base.yml', + f'{config_tuning_params}', + f'model_name={model.model_type}', + f'base_output_directory={base_output_directory}', + 'use_vertex_tensorboard=false', + f'{vertex_tensorboard}', + f'run_name={name}' + ]) + return command def generate_xpk_workload_cmd( @@ -327,11 +337,12 @@ def generate_xpk_workload_cmd( libtpu_version: str, base_output_directory: str, buffer_size: int, + xpk_path: str = '~/xpk', + pathways_config: PathwaysConfig = None, ): """Generates a command to run a maxstar model on XPK.""" num_steps = 20 time.localtime() - test_purpose_name = f'maxstar-benchmarks-{model.model_name}-{libtpu_version}' N = 3 temp_post_fix = ''.join( random.choice(string.ascii_lowercase + string.digits) for _ in range(N) @@ -340,7 +351,14 @@ def generate_xpk_workload_cmd( name = ( f"{model.model_name.replace('_', '-')}-{cluster_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}" ) + if pathways_config.use_pathways: + # Pathways run names are long and need to be shortened. + name = ( + f"pw-{model.model_name.replace('_', '-')}-{cluster_config.num_slices}-{temp_post_fix}" + ) + user_command = build_user_command( + name, model, num_slices, num_steps, @@ -349,6 +367,7 @@ def generate_xpk_workload_cmd( cluster_config, base_output_directory, buffer_size, + pathways_config, ) additional_flags = '' @@ -361,23 +380,41 @@ def generate_xpk_workload_cmd( ' https://raw.githubusercontent.com/GoogleCloudPlatform/ai-on-gke/9ff340f07f70be0130454f9e7238551587242b75/scripts/network-setup/v6e-network-optimization.yaml' ) + # pathways-related flags + pathways_specific_flags = '' + docker_image_flag = f'--base-docker-image="{BASE_DOCKER_IMAGE}"' + if pathways_config.use_pathways: + pathways_specific_flags = ( + '--use-pathways' + f' --server-image={pathways_config.server_image}' + f' --proxy-server-image={pathways_config.proxy_image}' + ' --termination-grace-period-seconds=300' + f' --pathways-gcs-location={base_output_directory}' + f' --restart-on-user-code-failure' + f' --debug-dump-gcs={base_output_directory}' + ) + docker_image_flag = ( + f'--docker-image={pathways_config.runner_image}' + ) + print(f'User command: {user_command}') return ( ( # f'{perf_optimzation_dcn} &&' - 'python3 ~/xpk/xpk.py workload create' + f'python3 {xpk_path}/xpk.py workload create' + f' {pathways_specific_flags}' f' --cluster={cluster_config.cluster_name}' f' --project={cluster_config.project}' f' --zone={cluster_config.zone}' f' --device-type={cluster_config.device_type}' f' --num-slices={cluster_config.num_slices}' f' --command="{user_command}"' - f' --base-docker-image="{BASE_DOCKER_IMAGE}"' + f' {docker_image_flag}' ' --enable-debug-logs' f' --workload={name}' - ' --priority=medium' + f' --priority={cluster_config.priority}' + f' --max-restarts={cluster_config.max_restarts}' # ' --use-vertex-tensorboard' - # f' --experiment-name={test_purpose_name}' f' {additional_flags}' ), name, @@ -406,7 +443,11 @@ def run_xpk_workload( return run_command_with_updates(command, 'Run XPK workload', cluster_config) -def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner]): +def xpk_benchmark_runner( + cluster_config: XpkConfig, + benchmarks: list[BenchmarkRunner], + xpk_path: str = '~/xpk', +): xpk_workload_names = [] xpk_workload_cmds = [] for benchmark in benchmarks: @@ -418,8 +459,14 @@ def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRu libtpu_version=benchmark.software_config.libtpu_version, base_output_directory=cluster_config.base_output_directory, buffer_size=4294967296, + xpk_path=xpk_path, + pathways_config=benchmark.software_config.pathways_config, ) + + print(f"name of the workload is: {name}") xpk_workload_names.append(name) + + print(f"XPK command to be used is: {command}") xpk_workload_cmds.append(command) returncodes = run_commands(