From 83eb3a692a3483df432f5e5b1fa56923af3c47dd Mon Sep 17 00:00:00 2001
From: zigzagcai
Date: Fri, 23 Aug 2024 18:53:21 +0800
Subject: [PATCH] add load_hf_weight for internlm and internlm2

---
 internlm/model/modeling_internlm.py  | 98 +++++++++++++++++++++++++++-
 internlm/model/modeling_internlm2.py | 91 +++++++++++++++++++++++++-
 internlm/utils/storage_manager.py    |  4 ++
 3 files changed, 191 insertions(+), 2 deletions(-)

diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 99b9fe4a..7ef72366 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -365,7 +365,103 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs):
 
     @staticmethod
     def load_hf_weights(folder: str, model: nn.Module):
-        raise NotImplementedError
+        """NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config."""
+        assert folder is not None, "Please specify the folder of the pretrained model"
+        if gpc.is_rank_for_log():
+            logger.info(f"Loading pretrained model from {folder}")
+
+        fns = get_fns(folder)
+        model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") or fn.endswith(".safetensors")]
+        model_fns.sort()
+
+        states = {}
+
+        for model_fn in model_fns:
+            states.update(llm_load(model_fn, map_location="cpu"))
+
+        current_states = {}
+        for idx, i in enumerate(range(model.first_layer, model.last_layer)):
+            layer_ids = i
+
+            # attn
+            q_proj_weight = torch.chunk(
+                states.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
+                gpc.get_world_size(ParallelMode.TENSOR),
+                dim=0,
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+            k_proj_weight = torch.chunk(
+                states.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
+                gpc.get_world_size(ParallelMode.TENSOR),
+                dim=0,
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+            v_proj_weight = torch.chunk(
+                states.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
+                gpc.get_world_size(ParallelMode.TENSOR),
+                dim=0,
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+            states[f"blocks.{i}.mixer.wqkv.weight"] = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)
+
+            states[f"blocks.{i}.mixer.out_proj.weight"] = torch.chunk(
+                states.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
+                gpc.get_world_size(ParallelMode.TENSOR),
+                dim=1,
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+
+            # mlp
+            states[f"blocks.{i}.mlp.w1.weight"] = torch.chunk(
+                states.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
+                gpc.get_world_size(ParallelMode.TENSOR),
+                dim=0,
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+            states[f"blocks.{i}.mlp.w3.weight"] = torch.chunk(
+                states.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
+                gpc.get_world_size(ParallelMode.TENSOR),
+                dim=0,
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+            states[f"blocks.{i}.mlp.w2.weight"] = torch.chunk(
+                states.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
+                gpc.get_world_size(ParallelMode.TENSOR),
+                dim=1,
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+
+            # attn norm
+            states[f"blocks.{i}.norm1.weight"] = states.pop(f"model.layers.{layer_ids}.input_layernorm.weight")
+            # mlp norm
+            states[f"blocks.{i}.norm2.weight"] = states.pop(f"model.layers.{layer_ids}.post_attention_layernorm.weight")
+
+            for name in list(states.keys()):
+                if name.startswith(f"blocks.{i}"):
+                    current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
+
+        model_state_keys = set(list(model.state_dict().keys()))
+
+        if "embedding.weight" in model_state_keys or "embedding.word_embeddings.weight" in model_state_keys:
"embedding.word_embeddings.weight" in model_state_keys: + if gpc.config.model.get("embed_split_hidden", True): + current_states["embedding.weight"] = torch.chunk( + states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1 + )[gpc.get_local_rank(ParallelMode.TENSOR)] + else: + current_states["embedding.word_embeddings.weight"] = torch.chunk( + states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1 + )[gpc.get_local_rank(ParallelMode.TENSOR)] + assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}" + + if "head.weight" in model_state_keys: + current_states["norm.weight"] = states["model.norm.weight"] + current_states["head.weight"] = torch.chunk( + states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0 + )[gpc.get_local_rank(ParallelMode.TENSOR)] + + missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() @staticmethod def load_internlm_with_dynamic_parallel_size(folder: str, model: nn.Module): diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 998e315c..a910d9f2 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -1,10 +1,12 @@ # Copyright (c) InternLM. All rights reserved. import math +import os from typing import Optional import torch from torch import nn +from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.initialize.initialize_tensor import ( @@ -25,7 +27,9 @@ ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger +from internlm.utils.storage_manager import get_fns, llm_load +internlm_accelerator = get_accelerator() logger = get_logger(__file__) @@ -460,4 +464,89 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): @staticmethod def load_hf_weights(folder: str, model: nn.Module): - raise NotImplementedError + """NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config.""" + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") or fn.endswith(".safetensors")] + model_fns.sort() + + states = {} + + for model_fn in model_fns: + states.update(llm_load(model_fn, map_location="cpu")) + + current_states = {} + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + layer_ids = i + + # attn + states[f"layers.{i}.attention.wqkv.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.attention.wqkv.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=0, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + states[f"layers.{i}.attention.wo.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.attention.wo.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=1, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + + # ffn + 
states[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.feed_forward.w1.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=0, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.feed_forward.w3.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=0, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.feed_forward.w2.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=1, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + + # attn norm + states[f"layers.{i}.attention_norm.weight"] = states.pop(f"model.layers.{layer_ids}.attention_norm.weight") + # ffn norm + states[f"layers.{i}.ffn_norm.weight"] = states.pop(f"model.layers.{layer_ids}.ffn_norm.weight") + + # replace value within decoder layer + for name in list(states.keys()): + if name.startswith(f"layers.{i}"): + current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name) + + model_state_keys = set(list(model.state_dict().keys())) + + if "tok_embeddings.weight" in model_state_keys or "tok_embeddings.word_embeddings.weight" in model_state_keys: + if gpc.config.model.get("embed_split_hidden", True): + current_states["tok_embeddings.weight"] = torch.chunk( + states["model.tok_embeddings.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1 + )[gpc.get_local_rank(ParallelMode.TENSOR)] + else: + current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk( + states["model.tok_embeddings.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1 + )[gpc.get_local_rank(ParallelMode.TENSOR)] + assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}" + + if "output.weight" in model_state_keys: + current_states["norm.weight"] = states["model.norm.weight"] + current_states["output.weight"] = torch.chunk( + states["output.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0 + )[gpc.get_local_rank(ParallelMode.TENSOR)] + + missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 6aa1ebd1..6f02d9ba 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -4,6 +4,8 @@ import multiprocessing import os +from safetensors.torch import load_file + from internlm.utils.common import SingletonMeta if "USE_DILL_PICKLE" in os.environ: @@ -823,6 +825,8 @@ def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs): @staticmethod def load(load_path: str, **kwargs): assert os.path.exists(load_path), f"{load_path} is not found!" + if load_path.endswith(".safetensors"): + return load_file(load_path) with open(load_path, "rb") as f: states = torch.load(f, **kwargs) return states