fix(ckpt): fix load funcs when loading llama & hf_llama (#79)
gaoyang07 authored Mar 14, 2024
1 parent 7873d1c commit a6a3235
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions internlm/utils/model_checkpoint.py
@@ -370,13 +370,22 @@ def save_model_checkpoint(folder, model):


def load_llama_pretrained_weights(folder, model):
    model = model.model
    assert folder is not None, "Please specify the folder of the pretrained model"
    if gpc.is_rank_for_log():
        logger.info(f"Loading pretrained model from {folder}")

    fns = get_fns(folder)
    model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")]
    model_fns = []
    for fn in fns:
        if fn.startswith("model_t") and not fn.endswith("md5"):
            model_fns.append(os.path.join(folder, fn))

    if len(model_fns) == 0:
        model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")]

    if len(model_fns) == 0:
        raise FileNotFoundError(f"No checkpoint file found in {folder}")

    model_fns.sort()

    old_tp = len(model_fns)
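
The hunk above changes how checkpoint shards are discovered: InternLM-style shard files (names starting with "model_t", excluding their "md5" companions) are preferred, generic ".pth"/".pt" files become a fallback, and an empty match now raises FileNotFoundError instead of failing later. A minimal standalone sketch of that selection logic, assuming a plain list of file names in place of the repository's get_fns helper; the function name select_checkpoint_files is illustrative, not part of the repository:

import os


def select_checkpoint_files(folder, filenames):
    """Sketch of the patched shard lookup: prefer "model_t*" shards, then fall back."""
    model_fns = [
        os.path.join(folder, fn)
        for fn in filenames
        if fn.startswith("model_t") and not fn.endswith("md5")
    ]
    if len(model_fns) == 0:
        # Fallback for generic PyTorch checkpoint files.
        model_fns = [os.path.join(folder, fn) for fn in filenames if fn.endswith((".pth", ".pt"))]
    if len(model_fns) == 0:
        raise FileNotFoundError(f"No checkpoint file found in {folder}")
    model_fns.sort()
    return model_fns


# Example: md5 companions and unrelated files are ignored.
print(select_checkpoint_files("/ckpt", ["model_tp0_pp0.pt", "model_tp0_pp0.pt.md5", "README.md"]))
# -> ['/ckpt/model_tp0_pp0.pt']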
@@ -391,14 +400,6 @@ def load_llama_pretrained_weights(folder, model):

    current_states = {}
    for idx, i in enumerate(range(model.first_layer, model.last_layer)):
        if gpc.config.model_type == "LLAMA2":
            # LLAMA's w2 and w3 are in reverse order
            w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
            w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
            states[f"layers.{i}.feed_forward.w2.weight"] = w3
            states[f"layers.{i}.feed_forward.w3.weight"] = w2
            if "rope.freqs" in states:
                states[f"layers.{i}.attention.rotary_emb.inv_freq"] = states["rope.freqs"]
        for name in list(states.keys()):
            if f".{i}." in name:
                current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
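The LLAMA2-specific w2/w3 handling shown above is what this hunk removes from the plain-LLaMA loader (the swap reappears in the HF loader further down); what remains is the renaming of globally numbered layer keys to the local indices owned by the current pipeline stage. A small sketch of that renaming, using toy keys rather than the real LLaMA state dict:

def remap_layer_keys(states, first_layer, last_layer):
    """Sketch: rename `layers.{global_id}` entries to `layers.{local_id}` entries."""
    current_states = {}
    for idx, i in enumerate(range(first_layer, last_layer)):
        for name in list(states.keys()):
            if f".{i}." in name:
                current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
    return current_states


# A stage that owns global layers 4-5 stores them locally as layers 0-1.
toy = {
    "layers.4.attention.wq.weight": "A",
    "layers.5.feed_forward.w1.weight": "B",
    "layers.6.attention.wq.weight": "C",  # belongs to the next stage
}
print(remap_layer_keys(toy, first_layer=4, last_layer=6))
# -> {'layers.0.attention.wq.weight': 'A', 'layers.1.feed_forward.w1.weight': 'B'}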
@@ -426,7 +427,7 @@ def load_llama_pretrained_weights(folder, model):


def load_hf_llama_pretrained_weights(folder, model):
    model = model.model
    """NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config."""
    assert folder is not None, "Please specify the folder of the pretrained model"
    if gpc.is_rank_for_log():
        logger.info(f"Loading pretrained model from {folder}")
@@ -483,12 +484,12 @@ def load_hf_llama_pretrained_weights(folder, model):
            gpc.get_world_size(ParallelMode.TENSOR),
            dim=0,
        )[gpc.get_local_rank(ParallelMode.TENSOR)]
        states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
        states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
            states.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
            gpc.get_world_size(ParallelMode.TENSOR),
            dim=0,
        )[gpc.get_local_rank(ParallelMode.TENSOR)]
        states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
        states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
            states.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
            gpc.get_world_size(ParallelMode.TENSOR),
            dim=1,
@@ -500,6 +501,13 @@ def load_hf_llama_pretrained_weights(folder, model):

if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states:
states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq")

if gpc.config.model_type in ("LLAMA2",):
w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
states[f"layers.{i}.feed_forward.w2.weight"] = w3
states[f"layers.{i}.feed_forward.w3.weight"] = w2

for name in list(states.keys()):
if name.startswith(f"layers.{i}"):
current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
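
The assignments in these two hunks split the HuggingFace projections across tensor-parallel ranks with torch.chunk: up_proj along dim 0 (column parallel) and down_proj along dim 1 (row parallel), with the resulting w2/w3 entries swapped when model_type is LLAMA2. A minimal sketch of the per-rank slicing itself, with tp_world_size and tp_rank standing in for the gpc parallel-context queries:

import torch


def tp_shard(weight, tp_world_size, tp_rank, dim):
    """Sketch: return this tensor-parallel rank's slice of a full weight matrix."""
    return torch.chunk(weight, tp_world_size, dim=dim)[tp_rank]


# Toy shapes: an 8x4 up_proj is split along dim 0, a 4x8 down_proj along dim 1,
# so each of two ranks keeps half of the parameters of either matrix.
up_proj = torch.randn(8, 4)
down_proj = torch.randn(4, 8)
print(tp_shard(up_proj, tp_world_size=2, tp_rank=0, dim=0).shape)    # torch.Size([4, 4])
print(tp_shard(down_proj, tp_world_size=2, tp_rank=1, dim=1).shape)  # torch.Size([4, 4])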
@@ -516,6 +524,7 @@ def load_hf_llama_pretrained_weights(folder, model):
states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
)[gpc.get_local_rank(ParallelMode.TENSOR)]
assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"

if "output.weight" in model_state_keys:
current_states["norm.weight"] = states["model.norm.weight"]
current_states["output.weight"] = torch.chunk(
