diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index 8583dfa6b..efe4dd5d6 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -126,7 +126,7 @@ def create_orbax_emergency_replicator_checkpoint_manager( save_interval_steps=save_interval_steps, ) manager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager( - local_checkpoint_dir, + epath.Path(local_checkpoint_dir), options, global_mesh=global_mesh, ) diff --git a/MaxText/train.py b/MaxText/train.py index b52e93ee3..8c3ed1331 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -960,6 +960,7 @@ def train_loop(config, state=None): record_goodput(recorder, config, recorder.record_job_end_time if recorder else None) clear_buffered_metrics() with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + # pytype: disable=attribute-error compiled = p_train_step.lower(state, example_batch, nextrng).compile() compiled_stats = compiled.memory_analysis() max_logging.log(