fix(mlp.py): swap mlp w1w2w3 init order to w1w3w2 and fix QA (#384)
li126com authored Dec 6, 2024
1 parent 902b632 commit f6c66bd
Showing 12 changed files with 392 additions and 115 deletions.
24 changes: 2 additions & 22 deletions .github/workflows/e2e_test.yaml
@@ -1,5 +1,5 @@
name: e2e-tests
on:
on:
pull_request:
branches:
- "develop"
@@ -232,24 +232,4 @@ jobs:
jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT}
srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_llama2" ./tests/test_training/test_loss.py
exit_code=$?
sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname
training_internlm2:
strategy:
matrix:
runner: [t_cluster]
runs-on: ${{ matrix.runner }}
timeout-minutes: 20
steps:
- name: mask env
run: |
echo "::add-mask::${{env.WORKSPACE_PREFIX}}"
echo "::add-mask::$path_prefix"
- uses: actions/checkout@v3
- name: training_internlm2_T
run: |
source activate ${evo_env_torch21_flash2}
jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT}
srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_internlm2" ./tests/test_training/test_loss.py
exit_code=$?
sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname
sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname
12 changes: 6 additions & 6 deletions configs/7B_isp_sft.py
@@ -1,5 +1,5 @@
JOB_NAME = "7b_train"
# model_type = "INTERNLM2_PUBLIC"
model_type = "INTERNLM2_PUBLIC"
DO_ALERT = False

VOCAB_SIZE = 103168
@@ -31,7 +31,7 @@
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
# 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
# load function such as "llama"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
@@ -145,7 +145,7 @@
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
# no_bias=True,
no_bias=True,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
@@ -188,17 +188,17 @@
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
sequence_2D (dict):
1. enable: bool, whether enable the 2D sequence parallel or not.
2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses).
2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses).
head_size * context_size should be equal to the tensor parallel size.
3. context_size: int, the parallel degree of context parallelism.
head_size * context_size should be equal to the tensor parallel size.
4. window_size: int, the sliding window size in context parallelism.
5. device_placement_strategy: dict,
head_first: bool, if `True`, ranks of the same head parallel group are
head_first: bool, if `True`, ranks of the same head parallel group are
given high priority for colocation on the same node;
if `False`, ranks of the same context parallel group are
given high priority for colocation on the same node;
interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config could
interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config could
interleave the ranks in the same window to make full use of the NIC.
"""
parallel = dict(
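The remainder of the parallel dict is collapsed in this diff view. For reference, a minimal sketch of how the sequence_2D block described in the docstring above could be filled in; the concrete sizes are illustrative assumptions rather than values from this commit, and other parallel keys (zero1, pipeline, weight) are omitted.

# Illustrative sketch only, not part of the diff: one consistent way to set the
# sequence_2D knobs documented above, assuming a tensor parallel size of 8 so
# that head_size * context_size == 8.
parallel = dict(
    tensor=dict(size=8, mode="isp"),
    sequence_2D=dict(
        enable=True,
        head_size=2,      # head parallelism degree (DeepSpeed-Ulysses style)
        context_size=4,   # context parallelism degree; 2 * 4 == tensor size
        window_size=1,    # sliding window size in context parallelism
        device_placement_strategy=dict(
            head_first=True,    # colocate ranks of the same head parallel group on a node
            interleaved=False,  # only meaningful when head_first=False and window_size > 1
        ),
    ),
)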
2 changes: 2 additions & 0 deletions internlm/checkpoint/load_funcs.py
@@ -1,6 +1,7 @@
# Copyright (c) InternLM. All rights reserved.

from internlm.model.modeling_internlm import InternLM1
from internlm.model.modeling_internlm2 import InternLM2
from internlm.model.modeling_llama import Llama2
from internlm.utils.logger import get_logger

@@ -9,4 +10,5 @@
LOAD_FUNC_DICT = {
"llama": Llama2.load_llama_pretrained_weights,
"internlm_test": InternLM1.load_internlm_with_dynamic_parallel_size,
"internlm2_test": InternLM2.load_internlm2_with_dynamic_parallel_size,
}
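With this registration, a training config can point ckpt_type at the new loader by name. A hedged sketch of the lookup follows; the dispatch shown here is an assumption for illustration, and only the dictionary entries themselves come from this diff.

# Sketch only: how a ckpt_type string resolves to a loader through LOAD_FUNC_DICT.
# The path below is hypothetical; the call signature load_func(folder, model)
# follows the loader added in this commit.
from internlm.checkpoint.load_funcs import LOAD_FUNC_DICT

load_ckpt_info = dict(
    path="local:llm_ckpts/internlm2_pretrained",  # hypothetical checkpoint folder
    content=("model",),
    ckpt_type="internlm2_test",  # selects InternLM2.load_internlm2_with_dynamic_parallel_size
)

load_func = LOAD_FUNC_DICT[load_ckpt_info["ckpt_type"]]
# load_func(folder, model) re-shards the saved model_tp{i}_pp{j}.pt files to the
# current tensor/pipeline parallel layout before calling load_state_dict.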
193 changes: 193 additions & 0 deletions internlm/model/modeling_internlm2.py
@@ -1,6 +1,7 @@
# Copyright (c) InternLM. All rights reserved.
import math
import os
from functools import reduce
from typing import Optional

import torch
@@ -11,6 +12,7 @@
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.parallel.shard import partition_uniform
from internlm.initialize.initialize_tensor import (
normal_,
scaled_init_method_normal,
@@ -26,6 +28,7 @@
from internlm.model.utils import (
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
get_parallel_size_from_file,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
@@ -576,6 +579,196 @@ def load_hf_weights(folder: str, model: nn.Module) -> None:

internlm_accelerator.empty_cache()

@staticmethod
def load_internlm2_with_dynamic_parallel_size(folder, model):
"""Load InternLM2 with dynamic parallel size."""
assert folder is not None, "Please specify the folder of the pretrained model"
assert gpc.config.model_type in ["INTERNLM2_PUBLIC"], "dynamic_parallel is only for INTERNLM2_PUBLIC"

fns = get_fns(folder)
if gpc.is_rank_for_log():
logger.info(f"Loading pretrained model from {folder}")
model_fns, old_tp, old_pp = get_parallel_size_from_file(fns) # pylint: disable=W0612

tp = gpc.get_world_size(ParallelMode.TENSOR)
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
assert old_tp % tp == 0 or tp % old_tp == 0, (
f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in "
f"checkpoint and {tp} in current config"
)

correspond_tps = []

if old_tp <= tp:
correspond_tps.append(tp_rank // (tp // old_tp))
ratio = tp // old_tp
rank = tp_rank % ratio
else:
for i in range(old_tp // tp):
correspond_tps.append(tp_rank * (old_tp // tp) + i)
rank = 0
ratio = 1

current_states = {}

pp = gpc.get_world_size(ParallelMode.PIPELINE) # noqa: F841 # pylint: disable=W0612

assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary"

old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1)

for idx, parts in enumerate(old_pp_partition):
start, end = parts[0]
if model.last_layer <= start or model.first_layer >= end:
continue
tmp_states = {}

for correspond_tp in correspond_tps:
model_name = f"model_tp{correspond_tp}_pp{idx}.pt"
states = llm_load(os.path.join(folder, model_name), map_location="cpu")
states = {k.replace("model.", ""): v for k, v in states.items()}
for i in range(start, end):
if i >= model.last_layer:
break
if i < model.first_layer:
continue

for name in list(states.keys()):
if f".{i-start}." in name:
to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.")

if gpc.config.model_type == "INTERNLM2_PUBLIC":
if "norm" in name:
tmp_states[to_name] = [states.pop(name)]
elif any(x in name for x in ("wo", "w2")):
tmp_states[to_name] = tmp_states.get(to_name, [])
tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=1)[rank])
elif any(x in name for x in ("w1", "w3")):
tmp_states[to_name] = tmp_states.get(to_name, [])
tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
elif any(x in name for x in ("wqkv",)):
tmp_states[to_name] = tmp_states.get(to_name, [])
if tp > gpc.config.model.num_kv_attention_heads:
assert old_tp <= gpc.config.model.num_kv_attention_heads, (
f"`old_tp ({old_tp}) => tp ({tp})` is not supported. "
"At least one of `tp` and `old_tp` should be less than or "
"equal to `num_kv_attention_heads`"
)
# Suitable for cases where num_kv_attention_heads is small
# but a large TP size is desired
q_per_kv = (
gpc.config.model.num_attention_heads
// gpc.config.model.num_kv_attention_heads
)
head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
index = torch.concat(
(
torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio],
torch.tensor([q_per_kv, q_per_kv + 1]),
)
)
index = index + (q_per_kv + 2) * (tp_rank // ratio)
index = index % (
(q_per_kv + 2) * (gpc.config.model.num_kv_attention_heads / old_tp)
)
index = index * head_dim
index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
index.shape[0]
)
tmp_states[to_name].append(
torch.index_select(states.pop(name), 0, index.to(torch.int32))
)
else:
tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
else:
raise KeyError(f"Unknown key {name}.")

else:
assert False, "unsupported model type"

if "tok_embeddings.weight" in states and model.first_layer == 0:
tmp_states["tok_embeddings.weight"] = tmp_states.get("tok_embeddings.weight", [])
tmp_states["tok_embeddings.weight"].append(
states["tok_embeddings.weight"].chunk(ratio, dim=1)[rank]
)
if "output.weight" in states and model.last_layer == gpc.config.model.num_layers:
tmp_states["norm.weight"] = [states["norm.weight"]]
tmp_states["output.weight"] = tmp_states.get("output.weight", [])
tmp_states["output.weight"].append(states["output.weight"].chunk(ratio, dim=0)[rank])

states = {}

for name in list(tmp_states.keys()):
data = tmp_states.pop(name)
if len(data) == 1:
current_states[name] = data[0]
else:
current_states[name] = torch.concat(
data, dim=1 if name == "tok_embeddings.weight" or any(x in name for x in ("wo", "w2")) else 0
)
# Merge copied kv heads
if "wqkv" in name and old_tp > gpc.config.model.num_kv_attention_heads:
assert (
tp <= gpc.config.model.num_kv_attention_heads
), "new_tp should be less than or equal to num_kv_attention_heads"
head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
q_per_kv = gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads
copied_times = old_tp // gpc.config.model.num_kv_attention_heads
cur_q_per_kv = q_per_kv // copied_times

# pylint: disable=all
def duplicate_kv_index(i):
if i % (cur_q_per_kv + 2) >= cur_q_per_kv:
return i
else:
return -100

def unique_kv_index(i):
if i // (cur_q_per_kv + 2) == copied_times - 1 or i % (cur_q_per_kv + 2) < cur_q_per_kv:
return i
else:
return -100

# pylint: enable=all

# Verify
duplicate_index = [duplicate_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
duplicate_index = [i for i in duplicate_index if i != -100]
duplicate_index = _duplicate_index = torch.tensor(duplicate_index)
for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
duplicate_index = torch.concat(
(duplicate_index, _duplicate_index + duplicate_index.max() + 1), dim=0
)
duplicate_kv = []
for index in duplicate_index.reshape(-1, copied_times * 2).chunk(copied_times, dim=-1):
index = index.reshape(-1) * head_dim
index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0])
duplicate_kv.append(torch.index_select(current_states[name], 0, index))
assert reduce(
lambda x, y: x and y,
[torch.allclose(duplicate_kv[0], x, atol=1e-5) for x in duplicate_kv[1:]],
), "Copied kv heads are not equal after training!"

# Merge
unique_index = [unique_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
unique_index = [i for i in unique_index if i != -100]
unique_index = _unique_index = torch.tensor(unique_index)
for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
unique_index = torch.concat((unique_index, _unique_index + unique_index.max() + 1), dim=0)
unique_index = unique_index * head_dim
unique_index = unique_index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
unique_index.shape[0]
)
current_states[name] = torch.index_select(current_states[name], 0, unique_index)
missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)

if gpc.get_local_rank(ParallelMode.DATA) == 0:
pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(
f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
)

@staticmethod
def convert_internevo2hf_weights(src: str, tgt: str) -> None:
def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True):
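The rest of convert_internevo2hf_weights is collapsed in this view. The wqkv branch of the new loader above assumes the fused projection stores each GQA group as q_per_kv query heads followed by one key head and one value head, each occupying head_dim rows; the index arithmetic then gives a new rank its slice of the query heads plus the shared key/value heads. A toy sketch of that layout and selection with made-up sizes follows (it is not the loader itself).

# Toy sketch only (not the loader): the fused-wqkv row layout and the kind of
# row selection the index arithmetic above performs when splitting one old TP
# shard across ratio new ranks. All sizes are made up.
import torch

num_attention_heads = 8
num_kv_attention_heads = 2            # GQA: 4 query heads share each key/value head
head_dim = 4
hidden_size = num_attention_heads * head_dim

q_per_kv = num_attention_heads // num_kv_attention_heads   # 4
rows_per_group = (q_per_kv + 2) * head_dim                  # q heads + k + v

# Fused wqkv weight: kv groups stored back to back, each group holding q_per_kv
# query heads followed by one key head and one value head.
wqkv = torch.arange(
    num_kv_attention_heads * rows_per_group * hidden_size, dtype=torch.float32
).reshape(num_kv_attention_heads * rows_per_group, hidden_size)

def rows_for_head(group: int, head_in_group: int) -> torch.Tensor:
    """Row indices of a single head (query, key, or value) inside the fused weight."""
    start = group * rows_per_group + head_in_group * head_dim
    return torch.arange(start, start + head_dim)

# Split group 0 across ratio = 2 new TP ranks: rank 0 keeps query heads 0-1,
# rank 1 keeps query heads 2-3, and both keep the shared key and value heads.
ratio = 2
for rank in range(ratio):
    q_heads = torch.arange(q_per_kv).chunk(ratio)[rank]
    kept = torch.cat(
        [rows_for_head(0, int(h)) for h in q_heads]
        + [rows_for_head(0, q_per_kv), rows_for_head(0, q_per_kv + 1)]
    )
    shard = torch.index_select(wqkv, 0, kept)
    print(rank, shard.shape)  # each rank keeps (q_per_kv // ratio + 2) * head_dim rows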
18 changes: 9 additions & 9 deletions internlm/model/modules/mlp.py
@@ -99,12 +99,12 @@ def __init__(
self.w1 = new_linear(
"w1", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert
)
self.w2 = new_linear(
"w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert
)
self.w3 = new_linear(
"w3", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert
)
self.w2 = new_linear(
"w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert
)

def forward(self, x):
if not self.mlp_layer_fusion:
@@ -177,21 +177,21 @@ def __init__(
backend=backend,
is_expert=is_expert,
)
self.w2 = new_linear(
"grouped_w2",
self.w3 = new_linear(
"grouped_w3",
in_features,
hidden_features,
out_features,
bias,
device=device,
dtype=dtype,
num_groups=num_groups,
backend=backend,
is_expert=is_expert,
)
self.w3 = new_linear(
"grouped_w3",
in_features,
self.w2 = new_linear(
"grouped_w2",
hidden_features,
out_features,
bias,
device=device,
dtype=dtype,
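The remaining arguments of the grouped w2 projection are collapsed in this view. The change in this file only swaps the registration order of the three projections to w1, w3, w2; the SwiGLU computation itself (conventionally w2(SiLU(w1(x)) * w3(x))) is not touched by these hunks. A minimal sketch of the unfused path follows, using plain torch.nn.Linear as an illustrative stand-in for the repo's new_linear factory; it is not the actual FeedForward class.

# Minimal sketch only: a plain SwiGLU feed-forward with the projections
# registered in the new w1, w3, w2 order. torch.nn.Linear stands in for
# new_linear, and the mlp_layer_fusion branch is omitted.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUFeedForwardSketch(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, bias: bool = False):
        super().__init__()
        # Only the registration order changes; forward() computes the same thing either way.
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)   # gate projection
        self.w3 = nn.Linear(in_features, hidden_features, bias=bias)   # up projection
        self.w2 = nn.Linear(hidden_features, out_features, bias=bias)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


mlp = SwiGLUFeedForwardSketch(16, 64, 16)
print([name for name, _ in mlp.named_parameters()])  # ['w1.weight', 'w3.weight', 'w2.weight']

Registration order is what named_parameters() and state_dict() iterate over, so any code that partitions, initializes, or converts parameters by position now sees w1, w3, w2; the forward output itself is unaffected.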
Diffs for the remaining changed files are not shown here.
