Skip to content

Commit

Permalink
enable loading older TE checkpoints (#11930)
Browse files Browse the repository at this point in the history
* set strict to False

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* add load_legacy flag

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>

* revert config changes

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* rename variable

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* remove extra flag usage

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>

* change error type

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix warning

Signed-off-by: dimapihtar <dpihtar@gmail.com>

---------

Signed-off-by: dimapihtar <dpihtar@gmail.com>
Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>
Co-authored-by: dimapihtar <dimapihtar@users.noreply.github.com>
  • Loading branch information
dimapihtar and dimapihtar authored Jan 24, 2025
1 parent cc365b6 commit 21e94fc
Showing 1 changed file with 10 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1998,7 +1998,16 @@ def on_load_checkpoint(self, checkpoint) -> None:
key.replace('model.', ''): checkpoint_state_dict.pop(key)
for key in list(checkpoint_state_dict.keys())
}
module.load_state_dict(checkpoint_state_dict, strict=True)
try:
module.load_state_dict(checkpoint_state_dict, strict=True)
except RuntimeError as e:
missing_keys, expected_keys = module.load_state_dict(checkpoint_state_dict, strict=False)
if all(s.endswith('_extra_state') for s in missing_keys):
logging.warning(
f'Loding checkpoint created with Transformer Engine version lower than 1.13. Missing layers {missing_keys} will be ignored.'
)
else:
raise e
else:
# when restoring a distributed checkpoint from a ptl checkpoint we need to defer loading the state_dict
# see NLPModel.on_load_checkpoint
Expand Down

0 comments on commit 21e94fc

Please sign in to comment.