From 3e1a30bb6682f4b0e2bddfaaf6095e0a45521fec Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 26 Aug 2024 16:59:36 +0800 Subject: [PATCH] fix hf checkpoint load funcs --- internlm/model/modeling_internlm.py | 6 +- internlm/model/modeling_internlm2.py | 2 +- internlm/model/modeling_llama.py | 177 ++++++++++++++------------- 3 files changed, 99 insertions(+), 86 deletions(-) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 3fc3a422..c6599466 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -462,13 +462,17 @@ def load_hf_weights(folder: str, model: nn.Module): f"model.layers.{layer_ids}.post_attention_layernorm.weight" ) + # skip rotary_emb inv_freq + if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in state_dict: + state_dict.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq") + # replace value within decoder layer for name in list(state_dict.keys()): if name.startswith(f"blocks.{i}"): new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) # embedding - if (gpc.get_local_rank(ParallelMode.PIPELINE) - 1 == 0) or ( + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or ( not gpc.is_using_parallel_mode(ParallelMode.PIPELINE) ): new_state_dict["embedding.weight"] = torch.chunk( diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 71a58653..cd25313d 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -537,7 +537,7 @@ def load_hf_weights(folder: str, model: nn.Module): new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) # embedding - if (gpc.get_local_rank(ParallelMode.PIPELINE) - 1 == 0) or ( + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or ( not gpc.is_using_parallel_mode(ParallelMode.PIPELINE) ): new_state_dict["tok_embeddings.weight"] = torch.chunk( diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py index 63b6db4f..f614dfe5 100644 --- a/internlm/model/modeling_llama.py +++ b/internlm/model/modeling_llama.py @@ -28,7 +28,6 @@ from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load -from internlm.utils.utils import ModelType internlm_accelerator = get_accelerator() logger = get_logger(__file__) @@ -462,96 +461,106 @@ def load_hf_weights(folder: str, model: nn.Module): model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") and fn.startswith("pytorch_model")] model_fns.sort() - states = {} - + state_dict = {} for model_fn in model_fns: - states.update(llm_load(model_fn, map_location="cpu")) + state_dict.update(llm_load(model_fn, map_location="cpu")) + + tp_size = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + wp_size = gpc.get_world_size(ParallelMode.WEIGHT) + wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) + tp_mode = gpc.config.parallel.tensor["mode"] + split_size = wp_size if tp_mode == "isp" else tp_size + local_rank = wp_rank if tp_mode == "isp" else tp_rank + row_dim = 0 if tp_mode == "isp" else 1 + if gpc.config.model.get("embed_split_hidden", True): + embed_concat_dim = 1 + else: + embed_concat_dim = 0 - deep_split = getattr(model, "deep_split", False) - if deep_split: - print("using deep split when loading pretrained weights!") + new_state_dict = {} - current_states = {} for idx, i in enumerate(range(model.first_layer, model.last_layer)): - if gpc.config.model_type == ModelType.LLAMA2.name: - if deep_split: - layer_ids = i // 2 - else: - layer_ids = i - - if not deep_split or (i + 2) % 2 == 0: - states[f"layers.{i}.attention.wq.weight"] = torch.chunk( - states.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), - gpc.get_world_size(ParallelMode.TENSOR), - dim=0, - )[gpc.get_local_rank(ParallelMode.TENSOR)] - states[f"layers.{i}.attention.wk.weight"] = torch.chunk( - states.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), - gpc.get_world_size(ParallelMode.TENSOR), - dim=0, - )[gpc.get_local_rank(ParallelMode.TENSOR)] - states[f"layers.{i}.attention.wv.weight"] = torch.chunk( - states.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), - gpc.get_world_size(ParallelMode.TENSOR), - dim=0, - )[gpc.get_local_rank(ParallelMode.TENSOR)] - states[f"layers.{i}.attention.wo.weight"] = torch.chunk( - states.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), - gpc.get_world_size(ParallelMode.TENSOR), - dim=1, - )[gpc.get_local_rank(ParallelMode.TENSOR)] - states[f"layers.{i}.attention_norm.weight"] = states.pop( - f"model.layers.{layer_ids}.input_layernorm.weight" - ) - - if not deep_split or (i + 2) % 2 == 1: - states[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( - states.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), - gpc.get_world_size(ParallelMode.TENSOR), - dim=0, - )[gpc.get_local_rank(ParallelMode.TENSOR)] - states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( - states.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), - gpc.get_world_size(ParallelMode.TENSOR), - dim=0, - )[gpc.get_local_rank(ParallelMode.TENSOR)] - states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( - states.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), - gpc.get_world_size(ParallelMode.TENSOR), - dim=1, - )[gpc.get_local_rank(ParallelMode.TENSOR)] - - states[f"layers.{i}.ffn_norm.weight"] = states.pop( - f"model.layers.{layer_ids}.post_attention_layernorm.weight" - ) - - if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states: - states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq") - - for name in list(states.keys()): - if name.startswith(f"layers.{i}"): - current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name) - - model_state_keys = set(list(model.state_dict().keys())) + layer_ids = i + + # attn + state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # ffn + state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # attn norm + state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.input_layernorm.weight" + ) - if "tok_embeddings.weight" in model_state_keys or "tok_embeddings.word_embeddings.weight" in model_state_keys: - if gpc.config.model.get("embed_split_hidden", True): - current_states["tok_embeddings.weight"] = torch.chunk( - states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1 - )[gpc.get_local_rank(ParallelMode.TENSOR)] - else: - current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk( - states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1 - )[gpc.get_local_rank(ParallelMode.TENSOR)] - assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}" + # ffn norm + state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.post_attention_layernorm.weight" + ) - if "output.weight" in model_state_keys: - current_states["norm.weight"] = states["model.norm.weight"] - current_states["output.weight"] = torch.chunk( - states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0 - )[gpc.get_local_rank(ParallelMode.TENSOR)] + # skip rotary_emb inv_freq + if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in state_dict: + state_dict.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq") - missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + # replace value within decoder layer + for name in list(state_dict.keys()): + if name.startswith(f"layers.{i}"): + new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) + + # embedding + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or ( + not gpc.is_using_parallel_mode(ParallelMode.PIPELINE) + ): + new_state_dict["tok_embeddings.weight"] = torch.chunk( + state_dict.pop("model.embed_tokens.weight"), + split_size, + dim=embed_concat_dim, + )[local_rank] + + # output + if gpc.is_last_rank(ParallelMode.PIPELINE): + new_state_dict["output.weight"] = torch.chunk( + state_dict.pop("lm_head.weight"), + split_size, + dim=0, + )[local_rank] + new_state_dict["norm.weight"] = state_dict["model.norm.weight"] + + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) if gpc.get_local_rank(ParallelMode.DATA) == 0: pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)