Skip to content

Commit

Permalink
refine hf load funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
season0528 committed Aug 26, 2024
1 parent d712899 commit 0752f65
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
4 changes: 3 additions & 1 deletion internlm/model/modeling_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,11 @@ def load_hf_weights(folder: str, model: nn.Module):
split_size,
dim=0,
)[local_rank]
new_state_dict["norm.weight"] = state_dict["model.norm.weight"]
new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")

missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
if len(state_dict) == 0:
logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.")

if gpc.get_local_rank(ParallelMode.DATA) == 0:
pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
Expand Down
4 changes: 3 additions & 1 deletion internlm/model/modeling_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,11 @@ def load_hf_weights(folder: str, model: nn.Module):
split_size,
dim=0,
)[local_rank]
new_state_dict["norm.weight"] = state_dict["model.norm.weight"]
new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")

missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
if len(state_dict) == 0:
logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.")

if gpc.get_local_rank(ParallelMode.DATA) == 0:
pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
Expand Down
4 changes: 3 additions & 1 deletion internlm/model/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,9 +556,11 @@ def load_hf_weights(folder: str, model: nn.Module):
split_size,
dim=0,
)[local_rank]
new_state_dict["norm.weight"] = state_dict["model.norm.weight"]
new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")

missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
if len(state_dict) == 0:
logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.")

if gpc.get_local_rank(ParallelMode.DATA) == 0:
pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
Expand Down

0 comments on commit 0752f65

Please sign in to comment.