add load_hf_weight for internlm and internlm2
season0528 committed Aug 26, 2024
1 parent 06901e8 commit 83eb3a6
Showing 3 changed files with 191 additions and 2 deletions.
98 changes: 97 additions & 1 deletion internlm/model/modeling_internlm.py
@@ -365,7 +365,103 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs):

    @staticmethod
    def load_hf_weights(folder: str, model: nn.Module):
-        raise NotImplementedError
"""NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config."""
        assert folder is not None, "Please specify the folder of the pretrained model"
        if gpc.is_rank_for_log():
            logger.info(f"Loading pretrained model from {folder}")

        fns = get_fns(folder)
        model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") or fn.endswith(".safetensors")]
        model_fns.sort()

        states = {}

        for model_fn in model_fns:
            states.update(llm_load(model_fn, map_location="cpu"))

        current_states = {}
        for idx, i in enumerate(range(model.first_layer, model.last_layer)):
            layer_ids = i

            # attn
            q_proj_weight = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=0,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]
            k_proj_weight = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=0,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]
            v_proj_weight = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=0,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]
            states[f"blocks.{i}.mixer.wqkv.weight"] = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)

            states[f"blocks.{i}.mixer.out_proj.weight"] = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=1,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]

            # mlp
            states[f"blocks.{i}.mlp.w1.weight"] = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=0,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]
            states[f"blocks.{i}.mlp.w3.weight"] = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=0,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]
            states[f"blocks.{i}.mlp.w2.weight"] = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=1,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]

            # attn norm
            states[f"blocks.{i}.norm1.weight"] = states.pop(f"model.layers.{layer_ids}.input_layernorm.weight")
            # mlp norm
            states[f"blocks.{i}.norm2.weight"] = states.pop(f"model.layers.{layer_ids}.post_attention_layernorm.weight")

            for name in list(states.keys()):
                if name.startswith(f"blocks.{i}"):
                    current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)

        model_state_keys = set(list(model.state_dict().keys()))

        if "embedding.weight" in model_state_keys or "embedding.word_embeddings.weight" in model_state_keys:
            if gpc.config.model.get("embed_split_hidden", True):
                current_states["embedding.weight"] = torch.chunk(
                    states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
                )[gpc.get_local_rank(ParallelMode.TENSOR)]
            else:
                current_states["embedding.word_embeddings.weight"] = torch.chunk(
                    states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
                )[gpc.get_local_rank(ParallelMode.TENSOR)]
            assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"

        if "head.weight" in model_state_keys:
            current_states["norm.weight"] = states["model.norm.weight"]
            current_states["head.weight"] = torch.chunk(
                states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
            )[gpc.get_local_rank(ParallelMode.TENSOR)]

        missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)

        if gpc.get_local_rank(ParallelMode.DATA) == 0:
            pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
            logger.info(
                f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
                f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
            )

        internlm_accelerator.empty_cache()

    @staticmethod
    def load_internlm_with_dynamic_parallel_size(folder: str, model: nn.Module):
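
For readers skimming the diff: the InternLM1 loader above has to fuse HuggingFace's separate q_proj/k_proj/v_proj matrices into the single `wqkv` weight that each tensor-parallel rank expects, chunking along dim 0 for column-parallel layers and dim 1 for row-parallel ones. The following standalone sketch reproduces just that fusion step under assumed sizes (hidden_size=64, 4 TP ranks, plain multi-head attention); it is an illustration, not part of the commit.

# Standalone sketch (not part of the commit) of the q/k/v fusion + tensor-parallel split.
# Assumptions: hidden_size=64, tp_world_size=4, plain MHA so q/k/v all have shape
# (hidden_size, hidden_size).
import torch

hidden_size, tp_world_size = 64, 4
hf_states = {
    f"model.layers.0.self_attn.{proj}.weight": torch.randn(hidden_size, hidden_size)
    for proj in ("q_proj", "k_proj", "v_proj")
}

def shard_wqkv(states, layer_id, tp_rank):
    """Take this rank's row-chunk of each projection, then stack them into one wqkv weight."""
    shards = [
        torch.chunk(states[f"model.layers.{layer_id}.self_attn.{proj}.weight"], tp_world_size, dim=0)[tp_rank]
        for proj in ("q_proj", "k_proj", "v_proj")
    ]
    return torch.cat(shards, dim=0)

for rank in range(tp_world_size):
    wqkv = shard_wqkv(hf_states, layer_id=0, tp_rank=rank)
    # Each rank holds 3 * hidden_size / tp_world_size rows of the fused projection.
    assert wqkv.shape == (3 * hidden_size // tp_world_size, hidden_size)
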
91 changes: 90 additions & 1 deletion internlm/model/modeling_internlm2.py
@@ -1,10 +1,12 @@
# Copyright (c) InternLM. All rights reserved.
import math
import os
from typing import Optional

import torch
from torch import nn

from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import (
@@ -25,7 +27,9 @@
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import get_fns, llm_load

internlm_accelerator = get_accelerator()
logger = get_logger(__file__)


@@ -460,4 +464,89 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs):

    @staticmethod
    def load_hf_weights(folder: str, model: nn.Module):
-        raise NotImplementedError
"""NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config."""
        assert folder is not None, "Please specify the folder of the pretrained model"
        if gpc.is_rank_for_log():
            logger.info(f"Loading pretrained model from {folder}")

        fns = get_fns(folder)
        model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") or fn.endswith(".safetensors")]
        model_fns.sort()

        states = {}

        for model_fn in model_fns:
            states.update(llm_load(model_fn, map_location="cpu"))

        current_states = {}
        for idx, i in enumerate(range(model.first_layer, model.last_layer)):
            layer_ids = i

            # attn
            states[f"layers.{i}.attention.wqkv.weight"] = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.attention.wqkv.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=0,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]
            states[f"layers.{i}.attention.wo.weight"] = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.attention.wo.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=1,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]

            # ffn
            states[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.feed_forward.w1.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=0,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]
            states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.feed_forward.w3.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=0,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]
            states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
                states.pop(f"model.layers.{layer_ids}.feed_forward.w2.weight"),
                gpc.get_world_size(ParallelMode.TENSOR),
                dim=1,
            )[gpc.get_local_rank(ParallelMode.TENSOR)]

            # attn norm
            states[f"layers.{i}.attention_norm.weight"] = states.pop(f"model.layers.{layer_ids}.attention_norm.weight")
            # ffn norm
            states[f"layers.{i}.ffn_norm.weight"] = states.pop(f"model.layers.{layer_ids}.ffn_norm.weight")

            # replace value within decoder layer
            for name in list(states.keys()):
                if name.startswith(f"layers.{i}"):
                    current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)

        model_state_keys = set(list(model.state_dict().keys()))

        if "tok_embeddings.weight" in model_state_keys or "tok_embeddings.word_embeddings.weight" in model_state_keys:
            if gpc.config.model.get("embed_split_hidden", True):
                current_states["tok_embeddings.weight"] = torch.chunk(
                    states["model.tok_embeddings.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
                )[gpc.get_local_rank(ParallelMode.TENSOR)]
            else:
                current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk(
                    states["model.tok_embeddings.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
                )[gpc.get_local_rank(ParallelMode.TENSOR)]
            assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"

        if "output.weight" in model_state_keys:
            current_states["norm.weight"] = states["model.norm.weight"]
            current_states["output.weight"] = torch.chunk(
                states["output.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
            )[gpc.get_local_rank(ParallelMode.TENSOR)]

        missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)

        if gpc.get_local_rank(ParallelMode.DATA) == 0:
            pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
            logger.info(
                f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
                f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
            )

        internlm_accelerator.empty_cache()
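
A note on the InternLM2 path: the HuggingFace checkpoint already stores the fused `attention.wqkv` matrix, so the loader above only renames keys and takes the local tensor-parallel chunk. The remaining wrinkle is remapping the global layer index `i` to the pipeline-stage-local index `idx`. Below is a minimal sketch of that renaming with made-up stage boundaries (global layers 8..15 on this stage); it is an illustration, not part of the commit.

# Hypothetical illustration of the global-to-local layer index remapping done above.
# Assumption (not from the commit): this pipeline stage owns global layers 8..15.
first_layer, last_layer = 8, 16

def to_local_key(hf_key: str, global_idx: int, local_idx: int) -> str:
    """Drop the leading 'model.' prefix and rewrite the global layer index to the local one."""
    key = hf_key.replace("model.", "", 1)
    return key.replace(f".{global_idx}.", f".{local_idx}.")

for local_idx, global_idx in enumerate(range(first_layer, last_layer)):
    hf_key = f"model.layers.{global_idx}.attention.wqkv.weight"
    local_key = to_local_key(hf_key, global_idx, local_idx)
    # e.g. global layer 8 becomes "layers.0.attention.wqkv.weight" on this stage
    print(local_key)
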
4 changes: 4 additions & 0 deletions internlm/utils/storage_manager.py
@@ -4,6 +4,8 @@
import multiprocessing
import os

from safetensors.torch import load_file

from internlm.utils.common import SingletonMeta

if "USE_DILL_PICKLE" in os.environ:
@@ -823,6 +825,8 @@ def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs):
    @staticmethod
    def load(load_path: str, **kwargs):
        assert os.path.exists(load_path), f"{load_path} is not found!"
        if load_path.endswith(".safetensors"):
            return load_file(load_path)
        with open(load_path, "rb") as f:
            states = torch.load(f, **kwargs)
        return states
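
The storage_manager change is what lets both loaders above treat `.bin` and `.safetensors` shards uniformly: `.safetensors` files are read with `safetensors.torch.load_file`, everything else still goes through `torch.load`. A hedged usage sketch follows, mirroring how the loaders merge a sharded checkpoint folder; the checkpoint path is made up.

# Usage sketch mirroring the loaders above; the checkpoint path is hypothetical.
import os

from internlm.utils.storage_manager import get_fns, llm_load

folder = "/path/to/internlm2-hf-checkpoint"  # assumed local folder of HF shards
states = {}
for fn in sorted(get_fns(folder)):
    if fn.endswith(".bin") or fn.endswith(".safetensors"):
        # After this commit, llm_load dispatches on the extension; map_location is
        # only meaningful for the torch.load (.bin) path.
        states.update(llm_load(os.path.join(folder, fn), map_location="cpu"))
print(f"merged {len(states)} tensors from {folder}")
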
