fix hf checkpoint load funcs
season0528 committed Aug 26, 2024
1 parent f41bfc0 commit 3e1a30b
Showing 3 changed files with 99 additions and 86 deletions.
6 changes: 5 additions & 1 deletion internlm/model/modeling_internlm.py
@@ -462,13 +462,17 @@ def load_hf_weights(folder: str, model: nn.Module):
f"model.layers.{layer_ids}.post_attention_layernorm.weight"
)

# skip rotary_emb inv_freq
if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in state_dict:
state_dict.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq")

# replace value within decoder layer
for name in list(state_dict.keys()):
if name.startswith(f"blocks.{i}"):
new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)

# embedding
if (gpc.get_local_rank(ParallelMode.PIPELINE) - 1 == 0) or (
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (
not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
):
new_state_dict["embedding.weight"] = torch.chunk(
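The substantive fix in this file is the pipeline-rank guard above: the old condition gpc.get_local_rank(ParallelMode.PIPELINE) - 1 == 0 holds on pipeline rank 1 rather than rank 0, so the first stage never loaded embedding.weight. A minimal sketch of the corrected check, assuming the gpc and ParallelMode objects already in scope in this file (the helper name is illustrative, not part of the repo):

def should_load_embedding() -> bool:
    # Load the embedding only on the first pipeline stage, or whenever
    # pipeline parallelism is not used at all.
    # Old, buggy check: gpc.get_local_rank(ParallelMode.PIPELINE) - 1 == 0,
    # which is only true on rank 1, so rank 0 silently skipped the weight.
    return (
        gpc.get_local_rank(ParallelMode.PIPELINE) == 0
        or not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
    )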
2 changes: 1 addition & 1 deletion internlm/model/modeling_internlm2.py
@@ -537,7 +537,7 @@ def load_hf_weights(folder: str, model: nn.Module):
new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)

# embedding
if (gpc.get_local_rank(ParallelMode.PIPELINE) - 1 == 0) or (
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (
not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
):
new_state_dict["tok_embeddings.weight"] = torch.chunk(
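The same one-line rank fix is applied here for tok_embeddings.weight. The chunk call is truncated in this view; in the modeling_llama.py loader below, the embedding is split along the hidden dimension (dim=1) when embed_split_hidden is set, otherwise along the vocabulary dimension (dim=0). A small sketch under that assumption (the helper is illustrative, not part of the repo):

import torch

def shard_embedding(weight: torch.Tensor, tp_size: int, tp_rank: int,
                    embed_split_hidden: bool = True) -> torch.Tensor:
    # Split the HF "model.embed_tokens.weight" tensor across tensor-parallel
    # ranks and keep only this rank's shard.
    dim = 1 if embed_split_hidden else 0  # hidden dim vs. vocab dim
    return torch.chunk(weight, tp_size, dim=dim)[tp_rank]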
177 changes: 93 additions & 84 deletions internlm/model/modeling_llama.py
@@ -28,7 +28,6 @@
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import get_fns, llm_load
from internlm.utils.utils import ModelType

internlm_accelerator = get_accelerator()
logger = get_logger(__file__)
@@ -462,96 +461,106 @@ def load_hf_weights(folder: str, model: nn.Module):
model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") and fn.startswith("pytorch_model")]
model_fns.sort()

states = {}

state_dict = {}
for model_fn in model_fns:
states.update(llm_load(model_fn, map_location="cpu"))
state_dict.update(llm_load(model_fn, map_location="cpu"))

tp_size = gpc.get_world_size(ParallelMode.TENSOR)
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
tp_mode = gpc.config.parallel.tensor["mode"]
split_size = wp_size if tp_mode == "isp" else tp_size
local_rank = wp_rank if tp_mode == "isp" else tp_rank
row_dim = 0 if tp_mode == "isp" else 1
if gpc.config.model.get("embed_split_hidden", True):
embed_concat_dim = 1
else:
embed_concat_dim = 0

deep_split = getattr(model, "deep_split", False)
if deep_split:
print("using deep split when loading pretrained weights!")
new_state_dict = {}

current_states = {}
for idx, i in enumerate(range(model.first_layer, model.last_layer)):
if gpc.config.model_type == ModelType.LLAMA2.name:
if deep_split:
layer_ids = i // 2
else:
layer_ids = i

if not deep_split or (i + 2) % 2 == 0:
states[f"layers.{i}.attention.wq.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.attention.wk.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.attention.wv.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.attention.wo.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=1,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.attention_norm.weight"] = states.pop(
f"model.layers.{layer_ids}.input_layernorm.weight"
)

if not deep_split or (i + 2) % 2 == 1:
states[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=1,
)[gpc.get_local_rank(ParallelMode.TENSOR)]

states[f"layers.{i}.ffn_norm.weight"] = states.pop(
f"model.layers.{layer_ids}.post_attention_layernorm.weight"
)

if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states:
states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq")

for name in list(states.keys()):
if name.startswith(f"layers.{i}"):
current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)

model_state_keys = set(list(model.state_dict().keys()))
layer_ids = i

# attn
state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
split_size,
dim=0,
)[local_rank]
state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
split_size,
dim=0,
)[local_rank]
state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
split_size,
dim=0,
)[local_rank]
state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
split_size,
dim=row_dim,
)[local_rank]

# ffn
state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
split_size,
dim=0,
)[local_rank]
state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
split_size,
dim=0,
)[local_rank]
state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
split_size,
dim=row_dim,
)[local_rank]

# attn norm
state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop(
f"model.layers.{layer_ids}.input_layernorm.weight"
)

if "tok_embeddings.weight" in model_state_keys or "tok_embeddings.word_embeddings.weight" in model_state_keys:
if gpc.config.model.get("embed_split_hidden", True):
current_states["tok_embeddings.weight"] = torch.chunk(
states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
)[gpc.get_local_rank(ParallelMode.TENSOR)]
else:
current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk(
states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
)[gpc.get_local_rank(ParallelMode.TENSOR)]
assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
# ffn norm
state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop(
f"model.layers.{layer_ids}.post_attention_layernorm.weight"
)

if "output.weight" in model_state_keys:
current_states["norm.weight"] = states["model.norm.weight"]
current_states["output.weight"] = torch.chunk(
states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
)[gpc.get_local_rank(ParallelMode.TENSOR)]
# skip rotary_emb inv_freq
if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in state_dict:
state_dict.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq")

missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
# replace value within decoder layer
for name in list(state_dict.keys()):
if name.startswith(f"layers.{i}"):
new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)

# embedding
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (
not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
):
new_state_dict["tok_embeddings.weight"] = torch.chunk(
state_dict.pop("model.embed_tokens.weight"),
split_size,
dim=embed_concat_dim,
)[local_rank]

# output
if gpc.is_last_rank(ParallelMode.PIPELINE):
new_state_dict["output.weight"] = torch.chunk(
state_dict.pop("lm_head.weight"),
split_size,
dim=0,
)[local_rank]
new_state_dict["norm.weight"] = state_dict["model.norm.weight"]

missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)

if gpc.get_local_rank(ParallelMode.DATA) == 0:
pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
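The modeling_llama.py loader is rebuilt around one pattern: pop each HF tensor from state_dict, torch.chunk it into split_size shards, and keep only this rank's shard. Column-parallel weights (the q/k/v, gate and up projections) are split along dim=0; row-parallel weights (o_proj and down_proj) along row_dim, which is 1 under plain tensor parallelism and 0 under the "isp" weight-parallel mode, where wp_size and wp_rank replace the tensor-parallel size and rank. A compact sketch of that pattern (the shard helper is illustrative, not part of the repo):

import torch

def shard(weight: torch.Tensor, split_size: int, local_rank: int, dim: int) -> torch.Tensor:
    # Split an HF weight into split_size chunks along dim and keep only the
    # shard owned by this tensor/weight-parallel rank.
    return torch.chunk(weight, split_size, dim=dim)[local_rank]

# Mirroring two lines of the loader above (the names come from the diff):
# state_dict[f"layers.{i}.attention.wq.weight"] = shard(
#     state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
#     split_size, local_rank, dim=0)
# state_dict[f"layers.{i}.attention.wo.weight"] = shard(
#     state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
#     split_size, local_rank, dim=row_dim)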
