
Commit

Merge branch 'main' of github.com:AI-Hypercomputer/maxtext into sujinesh/llama2_v6e_pw_long_running_test
SujeethJinesh committed Jan 9, 2025
2 parents f3a3a92 + 5df52b1 commit 01798e3
Showing 4 changed files with 48 additions and 16 deletions.
29 changes: 26 additions & 3 deletions MaxText/checkpointing.py
@@ -27,6 +27,7 @@
 import numpy as np
 import orbax.checkpoint as ocp
 import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
+import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager

 # pylint: disable=too-many-positional-arguments

@@ -91,15 +92,15 @@ def create_orbax_emergency_checkpoint_manager(
     persistent_save_interval_steps: int,
     orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
 ):
-  """Returns an emergency checkpoint."""
+  """Returns an emergency checkpoint manager."""
   flags.FLAGS.experimental_orbax_use_distributed_process_id = True
   max_logging.log("Creating emergency checkpoint manager...")

   options = emergency_checkpoint_manager.CheckpointManagerOptions(
       local=LocalCheckpointOptions(save_interval_steps=local_save_interval_steps),
       persistent=PersistentCheckpointOptions(save_interval_steps=persistent_save_interval_steps),
   )
-  emergency_mngr = emergency_checkpoint_manager.CheckpointManager(
+  manager = emergency_checkpoint_manager.CheckpointManager(
       local_checkpoint_dir,
       epath.Path(persistent_checkpoint_dir),
       global_mesh=global_mesh,
@@ -109,7 +110,29 @@ def create_orbax_emergency_checkpoint_manager(
   )

   max_logging.log("Emergency checkpoint manager created!")
-  return emergency_mngr
+  return manager
+
+
+def create_orbax_emergency_replicator_checkpoint_manager(
+    local_checkpoint_dir: str,
+    save_interval_steps: int,
+    global_mesh: jax.sharding.Mesh,
+):
+  """Returns an emergency replicator checkpoint manager."""
+  flags.FLAGS.experimental_orbax_use_distributed_process_id = True
+  max_logging.log("Creating emergency replicator checkpoint manager...")
+
+  options = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManagerOptions(
+      save_interval_steps=save_interval_steps,
+  )
+  manager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager(
+      epath.Path(local_checkpoint_dir),
+      options,
+      global_mesh=global_mesh,
+  )
+
+  max_logging.log("Emergency replicator checkpoint manager created!")
+  return manager


 def print_save_message(step, async_checkpointing):
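For reference, a minimal usage sketch of the new create_orbax_emergency_replicator_checkpoint_manager helper. Only the function and parameter names are taken from the diff above; the directory, the save interval, the single-axis mesh construction, and the bare import of the checkpointing module are illustrative assumptions, since MaxText derives the real mesh and paths from its parsed config.

import jax
import numpy as np
from jax.sharding import Mesh

import checkpointing  # MaxText/checkpointing.py; assumes MaxText/ is on PYTHONPATH

# Hypothetical single-axis mesh over all local devices; MaxText normally
# builds its mesh from the parsed config instead.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Write a node-local checkpoint every 20 steps (placeholder values).
replicator_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager(
    local_checkpoint_dir="/tmp/local_checkpoints",
    save_interval_steps=20,
    global_mesh=mesh,
)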
27 changes: 17 additions & 10 deletions MaxText/train.py
@@ -648,16 +648,23 @@ def setup_mesh_and_model(config):
   tx = optimizers.get_optimizer(config, learning_rate_schedule)
   logger = checkpointing.setup_checkpoint_logger(config)
   if config.enable_emergency_checkpoint:
-    abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
-    checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager(
-        config.local_checkpoint_directory,
-        config.checkpoint_dir,
-        mesh,
-        abstract_state,
-        config.local_checkpoint_period,
-        config.checkpoint_period,
-        logger,
-    )
+    if config.use_replicator_service:
+      checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager(
+          config.local_checkpoint_directory,
+          config.local_checkpoint_period,
+          mesh,
+      )
+    else:
+      abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
+      checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager(
+          config.local_checkpoint_directory,
+          config.checkpoint_dir,
+          mesh,
+          abstract_state,
+          config.local_checkpoint_period,
+          config.checkpoint_period,
+          logger,
+      )
   else:
     # TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend
     use_ocdbt = config.checkpoint_storage_use_ocdbt
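The train.py change is a small dispatch on a new use_replicator_service config flag. The toy stand-in below only restates that branch in a self-contained form to show which arguments each manager variant needs; StubConfig and create_emergency_manager are hypothetical names, the paths and periods are placeholders, and only the config field and function names come from the diff.

from dataclasses import dataclass

import checkpointing  # MaxText/checkpointing.py; assumes MaxText/ is on PYTHONPATH

@dataclass
class StubConfig:
  """Toy stand-in for the fields MaxText normally reads from its YAML config."""
  use_replicator_service: bool = False
  local_checkpoint_directory: str = "/tmp/local_checkpoints"  # placeholder
  checkpoint_dir: str = "gs://some-bucket/checkpoints"        # placeholder
  local_checkpoint_period: int = 20
  checkpoint_period: int = 100

def create_emergency_manager(config, mesh, abstract_state=None, logger=None):
  """Mirrors the branch added to setup_mesh_and_model in this commit."""
  if config.use_replicator_service:
    # Replicator path: only a local directory, a save period, and the mesh.
    return checkpointing.create_orbax_emergency_replicator_checkpoint_manager(
        config.local_checkpoint_directory,
        config.local_checkpoint_period,
        mesh,
    )
  # Default emergency path: local and persistent directories, the abstract
  # train state, both save periods, and an optional Orbax logger.
  return checkpointing.create_orbax_emergency_checkpoint_manager(
      config.local_checkpoint_directory,
      config.checkpoint_dir,
      mesh,
      abstract_state,
      config.local_checkpoint_period,
      config.checkpoint_period,
      logger,
  )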
6 changes: 3 additions & 3 deletions constraints_gpu.txt
@@ -60,7 +60,7 @@ google-jetstream==0.2.2
 google-pasta==0.2.0
 google-resumable-media==2.7.2
 googleapis-common-protos==1.65.0
-grain-nightly==0.0.9
+grain-nightly==0.0.10
 grpc-google-iam-v1==0.13.1
 grpcio==1.67.0
 grpcio-status==1.48.2
@@ -123,7 +123,7 @@ opentelemetry-api==1.27.0
 opt_einsum==3.4.0
 optax==0.2.3
 optree==0.13.0
-orbax-checkpoint==0.6.4
+orbax-checkpoint==0.10.3
 packaging==24.1
 pandas==2.2.3
 parameterized==0.9.0
@@ -186,7 +186,7 @@ tensorflow-datasets==4.9.6
 tensorflow-io-gcs-filesystem==0.37.1
 tensorflow-metadata==1.16.1
 tensorflow-text==2.17.0
-tensorstore==0.1.66
+tensorstore==0.1.68
 termcolor==2.5.0
 tfds-nightly==4.9.2.dev202308090034
 tiktoken==0.8.0
2 changes: 2 additions & 0 deletions requirements_with_jax_stable_stack.txt
@@ -3,6 +3,8 @@
 absl-py
 aqtp==0.8.2
 datasets
+grain-nightly>=0.0.10
+orbax-checkpoint>=0.10.3
 pylint
 pytest
 pyink
