Skip to content

Commit

Permalink
Stop writing msgpack file for new checkpoints and update empty nodes …
Browse files Browse the repository at this point in the history
…handling so that it no longer depends on this file.

PiperOrigin-RevId: 649252891
  • Loading branch information
dubey authored and Flax Authors committed Jul 4, 2024
1 parent 0fb1777 commit 3171f15
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
)
from collections.abc import Callable, Iterable

from etils import epath
import jax
import orbax.checkpoint as ocp
from absl import logging
Expand Down Expand Up @@ -77,7 +76,7 @@

# Orbax main checkpoint file name.
ORBAX_CKPT_FILENAME = 'checkpoint'
ORBAX_METADATA_FILENAME = '_METADATA'
ORBAX_MANIFEST_OCDBT = 'manifest.ocdbt'

PyTree = Any

Expand Down Expand Up @@ -124,8 +123,7 @@ def _safe_remove(path: str):

def _is_orbax_checkpoint(path: str) -> bool:
return io.exists(os.path.join(path, ORBAX_CKPT_FILENAME)) or io.exists(
os.path.join(path, ORBAX_METADATA_FILENAME)
or ocp.type_handlers.is_ocdbt_checkpoint(epath.Path(path))
os.path.join(path, ORBAX_MANIFEST_OCDBT)
)


Expand Down

0 comments on commit 3171f15

Please sign in to comment.