Add Pathways Support to Benchmark Runner #1094

Merged · 1 commit · Jan 10, 2025
66 changes: 65 additions & 1 deletion benchmarks/benchmark_runner.py
@@ -28,6 +28,7 @@
from maxtext_xpk_runner import BenchmarkRunner
from maxtext_xpk_runner import HWConfig
from maxtext_xpk_runner import SWconfig
from maxtext_xpk_runner import PathwaysConfig
from maxtext_xpk_runner import xpk_benchmark_runner
from maxtext_xpk_runner import XpkConfig
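
PathwaysConfig comes from maxtext_xpk_runner, whose definition is outside this diff. A minimal sketch of the shape that main() below relies on, with field names inferred from the call site (the real class may carry more fields):

import dataclasses

@dataclasses.dataclass
class PathwaysConfig:
  # Sketch only: fields mirror the PathwaysConfig(...) call in main().
  use_pathways: bool = False
  server_image: str = ''
  proxy_image: str = ''
  runner_image: str = ''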

@@ -86,6 +87,11 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'llama2_7b_4096',
'llama2_70b_4096',
'llama2_70b_4096_real_data',
'llama2_70b_4096_pw_long_run',
'llama2_70b_4096_real_data_pw_long_run',
'llama2_70b_4096_pw_rd_tfds',
'llama2_70b_4096_synthetic_pw_lr',
'llama2_70b_4096_synthetic',
'llama3_70b_8192',
'llama3_1_405b_8192_fsdp_dcn',
'mixtral_8x7b_dropped',
@@ -103,6 +109,11 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'llama2_7b_4096 '
'llama2_70b_4096 '
'llama2_70b_4096_real_data '
'llama2_70b_4096_pw_long_run '
'llama2_70b_4096_real_data_pw_long_run '
'llama2_70b_4096_pw_rd_tfds '
'llama2_70b_4096_synthetic_pw_lr '
'llama2_70b_4096_synthetic '
'llama3_1_405b_8192_fsdp_dcn '
'mixtral_8x7b_dropped '
'mixtral_8x7b_dropped_int8 '
@@ -124,6 +135,51 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
default='maxtext_base_image',
help='Version of the base docker image to benchmark.',
)
custom_parser.add_argument(
'--pathways_server_image',
type=str,
default=(
'us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest'
),
help='Version of the Pathways server image to use in the benchmark.',
)
custom_parser.add_argument(
'--pathways_proxy_image',
type=str,
default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest',
help='Version of the Pathways proxy image to use in the benchmark.',
)
custom_parser.add_argument(
'--pathways_runner_image',
type=str,
default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest',
help='Version of the Pathways runner image to use in the benchmark.',
)
custom_parser.add_argument(
'--use_pathways',
# argparse's type=bool treats any non-empty string (even 'false') as True,
# so a real on/off flag needs action='store_true' instead.
action='store_true',
default=False,
help='Whether to run the benchmark through Pathways.',
)
custom_parser.add_argument(
'--xpk_path',
type=str,
default='~/xpk',
help='Path to the local xpk directory.',
)
custom_parser.add_argument(
'--priority',
type=str,
default='medium',
help='Priority the XPK workload should run with.',
)
custom_parser.add_argument(
'--max_restarts',
type=int,
default=0,
help='Number of restarts to attempt.',
)
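
A note on the --use_pathways flag above: argparse applies type=bool to the raw string, and bool('false') is True, so such a flag could never be switched off from the command line; action='store_true' avoids that. A standalone demonstration (not part of this PR):

import argparse

# The pitfall: type=bool converts the string 'false' with bool(), giving True.
p = argparse.ArgumentParser()
p.add_argument('--use_pathways', type=bool, default=False)
print(p.parse_args(['--use_pathways=false']).use_pathways)  # prints True

# The fix: a store_true flag defaults to False and flips on when present.
p2 = argparse.ArgumentParser()
p2.add_argument('--use_pathways', action='store_true')
print(p2.parse_args([]).use_pathways)                  # prints False
print(p2.parse_args(['--use_pathways']).use_pathways)  # prints True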


def main() -> None:
parser = argparse.ArgumentParser(
@@ -139,11 +195,19 @@ def main() -> None:
num_slices=options.num_slices,
device_type=options.device_type,
base_output_directory=options.base_output_directory,
priority=options.priority,
max_restarts=options.max_restarts,
)

v6e_env_configs = SWconfig(
base_docker_image=options.base_docker_image,
libtpu_version=options.libtpu_version,
pathways_config=PathwaysConfig(
use_pathways=options.use_pathways,
server_image=options.pathways_server_image,
proxy_image=options.pathways_proxy_image,
runner_image=options.pathways_runner_image,
),
)

v6e_256_configs = HWConfig(
@@ -159,7 +223,7 @@
hardware_config=v6e_256_configs,
)

xpk_benchmark_runner(cluster_config, [model_runner])
xpk_benchmark_runner(cluster_config, [model_runner], options.xpk_path)


if __name__ == '__main__':
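
main() now threads options.xpk_path through as a third argument, so xpk_benchmark_runner in maxtext_xpk_runner.py presumably gained a matching parameter. A hypothetical sketch consistent with the call site, plus a placeholder launch command (the real definition is not in this diff, and shared flags such as --device_type come from add_shared_arguments):

# Hypothetical launch with the new flags (all values are placeholders):
#
#   python3 benchmarks/benchmark_runner.py \
#       --base_output_directory=gs://my-bucket/benchmarks \
#       --device_type=v6e-256 \
#       --num_slices=2 \
#       --use_pathways \
#       --priority=high \
#       --max_restarts=3
#
# Signature sketch matching the three-argument call above:
def xpk_benchmark_runner(
    cluster_config: XpkConfig,
    benchmarks: list[BenchmarkRunner],
    xpk_path: str = '~/xpk',
) -> None:
  ...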
190 changes: 190 additions & 0 deletions benchmarks/maxtext_trillium_model_configs.py
@@ -291,6 +291,46 @@ class MaxTextModel:
),
)
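
The MaxTextModel class is defined earlier in this file and is not part of the diff. The configs below populate four fields; a minimal sketch inferred from usage (the actual definition may differ):

import dataclasses
from typing import Any

@dataclasses.dataclass
class MaxTextModel:
  model_name: str                # run/workload name, e.g. "llama2-70b-4096-rd-pw-lr"
  model_type: str                # MaxText model family, e.g. "llama2-70b"
  tuning_params: dict[str, Any]  # forwarded as MaxText config overrides
  xla_flags: str                 # extra XLA flags for the run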


llama2_70b_4096_real_data_pw_long_run = MaxTextModel(
model_name="llama2-70b-4096-rd-pw-lr",
model_type="llama2-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "full",
"max_target_length": 4096,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"reuse_example_batch": 0,
"profiler": "xplane",
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "tfds",
"tokenizer_path": "assets/tokenizer.llama2",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
"steps": 1000000,

# Additional tuning params for pathways long running test.
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 100,
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)

# ici_fsdp_transpose_parallelism gives one TFLOP better performance.
llama2_70b_4096 = MaxTextModel(
model_name="llama2-70b-4096",
@@ -319,6 +359,151 @@ class MaxTextModel:
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)

llama2_70b_4096_synthetic = MaxTextModel(
model_name="llama2_70b_4096_synthetic",
model_type="llama2-70b",
tuning_params={
"per_device_batch_size": 2,
"ici_fsdp_parallelism": 1,
"ici_fsdp_transpose_parallelism": -1,
"ici_tensor_parallelism": 1,
"remat_policy": "qkv_proj_offloaded",
"max_target_length": 4096,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"profiler": "xplane",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)

llama2_70b_4096_synthetic_pw_lr = MaxTextModel(
model_name="llama2_70b_4096_synthetic_pw_lr",
model_type="llama2-70b",
tuning_params={
"per_device_batch_size": 2,
"ici_fsdp_parallelism": 1,
"ici_fsdp_transpose_parallelism": -1,
"ici_tensor_parallelism": 1,
"remat_policy": "qkv_proj_offloaded",
"max_target_length": 4096,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
# "enable_checkpointing": False,
"profiler": "xplane",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
"steps": 1000000,

# Additional tuning params for pathways long running test.
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 100,
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)

llama2_70b_4096_pw_long_run = MaxTextModel(
model_name="llama2-70b-4096-pw-lr",
model_type="llama2-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": 1,
"ici_fsdp_transpose_parallelism": -1,
"ici_tensor_parallelism": 1,
"remat_policy": "full",
"max_target_length": 4096,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
"profiler": "xplane",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
"steps": 1000000,

# Additional tuning params for pathways long running test.
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 100,
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)

llama2_70b_4096_pw_rd_tfds = MaxTextModel(
model_name="llama2_70b_4096_pw_rd_tfds",
model_type="llama2-70b",
tuning_params={
"per_device_batch_size": 2,
"ici_fsdp_parallelism": 1,
"ici_fsdp_transpose_parallelism": -1,
"ici_tensor_parallelism": 1,
"remat_policy": "qkv_proj_offloaded",
"max_target_length": 4096,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"dataset_path": "gs://trillium-storage-datasets-sr",
"enable_checkpointing": False,
"profiler": "xplane",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,

# Additional tuning params for pathways long running test.
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 100,
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)
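
All four pathways configs above repeat the same ten checkpointing and goodput entries. A possible follow-up refactor (hypothetical, not part of this PR) would hoist them into one shared dict and splat it into each tuning_params:

# Shared long-run Pathways settings, copied verbatim from the configs above.
PATHWAYS_LONG_RUN_PARAMS = {
    "enable_checkpointing": True,
    "async_checkpointing": True,
    "checkpoint_period": 100,
    "checkpoint_storage_use_ocdbt": False,
    "checkpoint_storage_use_zarr3": False,
    "metrics_file": "metrics.txt",
    "goodput_upload_interval_seconds": 30,
    "enable_pathways_goodput": True,
    "enable_checkpoint_cloud_logger": True,
    "enable_single_controller": True,
}

# Each config could then read: tuning_params={**base_params, **PATHWAYS_LONG_RUN_PARAMS}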


llama3_8b_8192 = MaxTextModel(
model_name="llama3-8b-8192",
@@ -695,9 +880,14 @@ class MaxTextModel:
gpt_3_175b,
llama2_7b_4096,
llama2_70b_4096,
llama2_70b_4096_pw_long_run,
llama2_70b_4096_real_data,
llama2_70b_4096_real_data_pw_long_run,
llama2_70b_4096_pw_rd_tfds,
llama3_8b_8192, # Not optimized yet
llama3_70b_8192, # Not optimized yet
llama2_70b_4096_synthetic_pw_lr,
llama2_70b_4096_synthetic,
llama3_1_405b_8192_fsdp_dcn,
llama3_1_8b_8192,
llama3_1_70b_8192,