feat(mha.py): support mla
BingyangWu committed Jan 2, 2025
1 parent 141e9eb commit d421847
Showing 6 changed files with 680 additions and 50 deletions.
38 changes: 23 additions & 15 deletions configs/7B_isp_sft.py
@@ -2,14 +2,20 @@
 model_type = "INTERNLM2_PUBLIC"
 DO_ALERT = False

-VOCAB_SIZE = 103168
-SEQ_LEN = 2048
+VOCAB_SIZE = 50304
+SEQ_LEN = 1024 * 1024

 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
-NUM_KV_ATTENTION_HEAD = 8
-MLP_RATIO = 8 / 3
+NUM_KV_ATTENTION_HEAD = 32
+MLP_RATIO = 3.5
 NUM_LAYER = 32

+# HIDDEN_SIZE = 8192
+# NUM_ATTENTION_HEAD = 64
+# NUM_KV_ATTENTION_HEAD = 64
+# MLP_RATIO = 3.5
+# NUM_LAYER = 80
+
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
 # Ckpt folder format:
@@ -54,15 +60,15 @@
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
-    micro_num=4,
+    micro_num=1,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=2,
+    micro_bsz=1,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50000,
+    total_steps=20,
     skip_batches="",
     # rampup_batch_size (str): A string with three space-separated integers representing the
     # starting batch size, the increment, and the number of steps between
@@ -76,7 +82,7 @@
     valid_folder=VALID_FOLDER,
     empty_cache_and_diag_interval=200,
     diag_outlier_ratio=1.1,
-    # use_packed_dataset=False,
+    use_packed_dataset=False,
 )

 grad_scaler = dict(
@@ -152,9 +158,11 @@
 )

 use_fp32_norm = False
+attention_type = "MLA"
 model = dict(
     checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
+    num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
     embed_grad_scale=1,
@@ -222,15 +230,15 @@
 interleaved the ranks in the same window to make full use of NIC as much as possible.
 """
 parallel = dict(
-    zero1=dict(size=-1),
-    tensor=dict(size=2, mode="isp"),
+    zero1=dict(size=8),
+    tensor=dict(size=64, mode="isp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
+    weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="module"),
     sequence_2D=dict(
-        enable=False,
-        head_size=2,
-        context_size=4,
-        window_size=1,
+        enable=True,
+        head_size=8,
+        context_size=8,
+        window_size=8,
         device_placement_strategy=dict(head_first=True, interleaved=False),
     ),
 )
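
Orientation aside (not part of the commit): the snippet below is a minimal sanity-check sketch of how the 2D sequence-parallel sizes above are assumed to relate, namely that head_size times context_size tiles the ISP tensor group and that window_size divides context_size. Both relationships are assumptions made for illustration, not constraints confirmed by this diff; the constant names are illustrative only.

    # Hedged sanity check using the values from the config above.
    TENSOR_SIZE = 64    # tensor=dict(size=64, mode="isp")
    HEAD_SIZE = 8       # sequence_2D.head_size
    CONTEXT_SIZE = 8    # sequence_2D.context_size
    WINDOW_SIZE = 8     # sequence_2D.window_size

    # Assumption: head-parallel and context-parallel groups together cover the ISP group.
    assert HEAD_SIZE * CONTEXT_SIZE == TENSOR_SIZE
    # Assumption: ring windows evenly partition the context-parallel group.
    assert CONTEXT_SIZE % WINDOW_SIZE == 0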
42 changes: 29 additions & 13 deletions internlm/core/trainer_builder.py
@@ -35,6 +35,7 @@
     enable_pytorch_expandable_segments,
     get_current_device,
     get_megatron_flops,
+    get_megatron_flops_mla,
     launch_time,
 )
 from internlm.utils.gputest import empty_cache_and_diag
@@ -334,19 +335,34 @@ def _update_parameters(self):
         return success_update, grad_norm_groups

     def _record_metrics(self, batch_count: int, batch, start_time, loss, moe_loss, success_update, grad_norm_groups):
-        get_tflops_func = partial(
-            get_megatron_flops,
-            checkpoint=gpc.config.model.checkpoint,
-            seq_len=gpc.config.data["seq_len"],
-            hidden_size=gpc.config.model.hidden_size,
-            num_layers=gpc.config.model.num_layers,
-            vocab_size=gpc.config.model.vocab_size,
-            global_batch_size=gpc.config.data.micro_bsz
-            * gpc.config.data.micro_num
-            * gpc.get_world_size(ParallelMode.DATA),
-            global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
-            mlp_ratio=gpc.config.model["mlp_ratio"],
-        )
+        if gpc.config.get("attention_type", "GQA") == "MLA":
+            get_tflops_func = partial(
+                get_megatron_flops_mla,
+                checkpoint=gpc.config.model.checkpoint,
+                seq_len=gpc.config.data["seq_len"],
+                hidden_size=gpc.config.model.hidden_size,
+                num_layers=gpc.config.model.num_layers,
+                vocab_size=gpc.config.model.vocab_size,
+                global_batch_size=gpc.config.data.micro_bsz
+                * gpc.config.data.micro_num
+                * gpc.get_world_size(ParallelMode.DATA),
+                global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
+                mlp_ratio=gpc.config.model["mlp_ratio"],
+            )
+        else:
+            get_tflops_func = partial(
+                get_megatron_flops,
+                checkpoint=gpc.config.model.checkpoint,
+                seq_len=gpc.config.data["seq_len"],
+                hidden_size=gpc.config.model.hidden_size,
+                num_layers=gpc.config.model.num_layers,
+                vocab_size=gpc.config.model.vocab_size,
+                global_batch_size=gpc.config.data.micro_bsz
+                * gpc.config.data.micro_num
+                * gpc.get_world_size(ParallelMode.DATA),
+                global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
+                mlp_ratio=gpc.config.model["mlp_ratio"],
+            )
         record_current_batch_training_metrics(
             get_tflops_func=get_tflops_func,
             logger=logger,
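
Illustrative aside (not part of the diff): the two branches above differ only in which FLOPS estimator is handed to partial, so the same dispatch can be pictured as picking the function first and sharing the keyword arguments. The sketch below assumes the imports already present in trainer_builder.py (gpc, partial, ParallelMode, get_megatron_flops, get_megatron_flops_mla); flops_impl is just a local name introduced for illustration.

    # Sketch only; keyword arguments copied from the diff above.
    flops_impl = (
        get_megatron_flops_mla
        if gpc.config.get("attention_type", "GQA") == "MLA"
        else get_megatron_flops
    )
    get_tflops_func = partial(
        flops_impl,
        checkpoint=gpc.config.model.checkpoint,
        seq_len=gpc.config.data["seq_len"],
        hidden_size=gpc.config.model.hidden_size,
        num_layers=gpc.config.model.num_layers,
        vocab_size=gpc.config.model.vocab_size,
        global_batch_size=gpc.config.data.micro_bsz
        * gpc.config.data.micro_num
        * gpc.get_world_size(ParallelMode.DATA),
        global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
        mlp_ratio=gpc.config.model["mlp_ratio"],
    )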
4 changes: 2 additions & 2 deletions internlm/data/tokenized/dummy_dataset.py
@@ -4,7 +4,7 @@
 import numpy as np
 from torch.utils.data import Dataset

-# from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context.parallel_context import global_context as gpc


 class RandomDataset(Dataset):
@@ -30,7 +30,7 @@ def __init__(self, num_samples=10000, max_len=1024, fixed_seqlen: bool = False)
             while len(d) < max_len:
                 r *= 2
                 d = list(range(n)) * r
-            # r = r % gpc.config.model.vocab_size
+            r = r % gpc.config.model.vocab_size
             d = [n, r] + d
             d = d[:max_len]
             data.append(d)
60 changes: 40 additions & 20 deletions internlm/model/modeling_internlm2.py
@@ -22,7 +22,7 @@
 from internlm.model.base_model import BaseModel
 from internlm.model.modules.embedding import Embedding1D
 from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import GQA
+from internlm.model.modules.mha import GQA, MLA
 from internlm.model.modules.mlp import new_feed_forward
 from internlm.model.modules.norm import new_layer_norm
 from internlm.model.utils import (
@@ -133,25 +133,45 @@ def __init__(
         self.use_dynamic_ntk_rope = use_dynamic_ntk_rope

         head_dim = hidden_size // num_attention_heads
-        self.attention = GQA(
-            embed_dim=hidden_size,
-            num_heads=num_attention_heads,
-            num_kv_heads=num_kv_attention_heads,
-            dropout=attn_drop_rate,
-            max_position_embeddings=max_position_embeddings,
-            softmax_scale=1 / math.sqrt(head_dim),
-            causal=True,
-            layer_idx=layer_idx,
-            use_dynamic_ntk_rope=use_dynamic_ntk_rope,
-            rotary_emb_dim=head_dim,
-            rotary_emb_scale_base=0,
-            device=device,
-            dtype=dtype,
-            qk_interleaved=qk_interleaved,
-            bias=not no_bias,
-            rope_base=rope_base,
-            enable_qkv_fusion=True,
-        )
+        if gpc.config.get("attention_type", "GQA") == "MLA":
+            self.attention = MLA(
+                embed_dim=hidden_size,
+                num_heads=num_attention_heads,
+                dropout=attn_drop_rate,
+                max_position_embeddings=max_position_embeddings,
+                softmax_scale=1 / math.sqrt(head_dim),
+                causal=True,
+                layer_idx=layer_idx,
+                use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+                rotary_emb_dim=head_dim,
+                rotary_emb_scale_base=0,
+                device=device,
+                dtype=dtype,
+                qk_interleaved=qk_interleaved,
+                bias=not no_bias,
+                rope_base=rope_base,
+                enable_qkv_fusion=True,
+            )
+        else:
+            self.attention = GQA(
+                embed_dim=hidden_size,
+                num_heads=num_attention_heads,
+                num_kv_heads=num_kv_attention_heads,
+                dropout=attn_drop_rate,
+                max_position_embeddings=max_position_embeddings,
+                softmax_scale=1 / math.sqrt(head_dim),
+                causal=True,
+                layer_idx=layer_idx,
+                use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+                rotary_emb_dim=head_dim,
+                rotary_emb_scale_base=0,
+                device=device,
+                dtype=dtype,
+                qk_interleaved=qk_interleaved,
+                bias=not no_bias,
+                rope_base=rope_base,
+                enable_qkv_fusion=True,
+            )

         self.dropout1 = nn.Dropout(drop_rate)
         self.dropout2 = nn.Dropout(drop_rate)
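
Background aside (not from this diff): MLA here presumably stands for Multi-head Latent Attention, which replaces per-head key/value projections with a shared low-rank latent from which keys and values are reconstructed; that is consistent with the MLA branch above not taking a num_kv_heads argument the way GQA does. The snippet below is a generic, repo-agnostic sketch of that compression idea with hypothetical sizes; it is not the implementation added in mha.py.

    import torch

    # Generic sketch of low-rank KV compression (hypothetical sizes, illustration only).
    batch, seqlen, hidden = 2, 16, 4096
    num_heads, head_dim, kv_latent_dim = 32, 128, 512  # kv_latent_dim is a made-up latent width

    x = torch.randn(batch, seqlen, hidden)
    down = torch.nn.Linear(hidden, kv_latent_dim, bias=False)                 # compress hidden states to a shared latent
    up_k = torch.nn.Linear(kv_latent_dim, num_heads * head_dim, bias=False)   # expand latent back to per-head keys
    up_v = torch.nn.Linear(kv_latent_dim, num_heads * head_dim, bias=False)   # expand latent back to per-head values

    latent = down(x)  # only this (batch, seqlen, kv_latent_dim) tensor would need to be cached
    k = up_k(latent).view(batch, seqlen, num_heads, head_dim)
    v = up_v(latent).view(batch, seqlen, num_heads, head_dim)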
(Diffs for the remaining 2 changed files, including internlm/model/modules/mha.py, were not loaded on this page.)
