Feat/loong train mla #400 (Draft)

wants to merge 4 commits into base: develop
configs/7B_isp_sft.py (38 changes: 23 additions & 15 deletions)
@@ -2,14 +2,20 @@
 model_type = "INTERNLM2_PUBLIC"
 DO_ALERT = False
 
-VOCAB_SIZE = 103168
-SEQ_LEN = 2048
+VOCAB_SIZE = 50304
+SEQ_LEN = 1024 * 1024
 
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
-NUM_KV_ATTENTION_HEAD = 8
-MLP_RATIO = 8 / 3
+NUM_KV_ATTENTION_HEAD = 32
+MLP_RATIO = 3.5
 NUM_LAYER = 32
 
+# HIDDEN_SIZE = 8192
+# NUM_ATTENTION_HEAD = 64
+# NUM_KV_ATTENTION_HEAD = 64
+# MLP_RATIO = 3.5
+# NUM_LAYER = 80
+
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
 # Ckpt folder format:
@@ -54,15 +60,15 @@
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
-    micro_num=4,
+    micro_num=1,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=2,
+    micro_bsz=1,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50000,
+    total_steps=20,
     skip_batches="",
     # rampup_batch_size (str): A string with three space-separated integers representing the
     # starting batch size, the increment, and the number of steps between
@@ -76,7 +82,7 @@
     valid_folder=VALID_FOLDER,
     empty_cache_and_diag_interval=200,
     diag_outlier_ratio=1.1,
-    # use_packed_dataset=False,
+    use_packed_dataset=False,
 )
 
 grad_scaler = dict(
@@ -152,9 +158,11 @@
 )
 
 use_fp32_norm = False
+attention_type = "MLA"
 model = dict(
     checkpoint=False,  # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
+    num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
     embed_grad_scale=1,
@@ -222,15 +230,15 @@
 interleaved the ranks in the same window to make full use of NIC as much as possible.
 """
 parallel = dict(
-    zero1=dict(size=-1),
-    tensor=dict(size=2, mode="isp"),
+    zero1=dict(size=8),
+    tensor=dict(size=64, mode="isp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
+    weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="module"),
     sequence_2D=dict(
-        enable=False,
-        head_size=2,
-        context_size=4,
-        window_size=1,
+        enable=True,
+        head_size=8,
+        context_size=8,
+        window_size=8,
         device_placement_strategy=dict(head_first=True, interleaved=False),
     ),
 )
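Note on the new parallel layout: with sequence_2D enabled, head_size * context_size = 8 * 8 = 64 matches tensor=dict(size=64, mode="isp"), which looks like the intended constraint for the 2D (LoongTrain-style) sequence parallelism, and the data-parallel size then falls out of the world size. A standalone sanity check under those assumptions; the 512-GPU world size below is hypothetical and not taken from this PR:

# Hypothetical sanity check for the updated parallel sizes; not part of this PR.
# Assumes head_size * context_size must equal the ISP tensor size and that
# data-parallel size = world_size / (tensor_size * pipeline_size).
def check_parallel_layout(world_size: int, tensor: int, pipeline: int,
                          head_size: int, context_size: int) -> int:
    assert head_size * context_size == tensor, "2D sequence split must cover the tensor group"
    assert world_size % (tensor * pipeline) == 0, "world size must be divisible by tensor * pipeline"
    return world_size // (tensor * pipeline)  # data-parallel size

# Values from the updated config, assuming a 512-GPU job:
dp_size = check_parallel_layout(world_size=512, tensor=64, pipeline=1, head_size=8, context_size=8)
print(dp_size)  # 8, which would also line up with zero1=dict(size=8)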
internlm/core/trainer_builder.py (42 changes: 29 additions & 13 deletions)
@@ -36,6 +36,7 @@
     enable_pytorch_expandable_segments,
     get_current_device,
     get_megatron_flops,
+    get_megatron_flops_mla,
     launch_time,
 )
 from internlm.utils.gputest import empty_cache_and_diag
@@ -338,19 +339,34 @@ def _update_parameters(self):
         return success_update, grad_norm_groups
 
     def _record_metrics(self, batch_count: int, batch, start_time, loss, moe_loss, success_update, grad_norm_groups):
-        get_tflops_func = partial(
-            get_megatron_flops,
-            checkpoint=gpc.config.model.checkpoint,
-            seq_len=gpc.config.data["seq_len"],
-            hidden_size=gpc.config.model.hidden_size,
-            num_layers=gpc.config.model.num_layers,
-            vocab_size=gpc.config.model.vocab_size,
-            global_batch_size=gpc.config.data.micro_bsz
-            * gpc.config.data.micro_num
-            * gpc.get_world_size(ParallelMode.DATA),
-            global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
-            mlp_ratio=gpc.config.model["mlp_ratio"],
-        )
+        if gpc.config.get("attention_type", "GQA") == "MLA":
+            get_tflops_func = partial(
+                get_megatron_flops_mla,
+                checkpoint=gpc.config.model.checkpoint,
+                seq_len=gpc.config.data["seq_len"],
+                hidden_size=gpc.config.model.hidden_size,
+                num_layers=gpc.config.model.num_layers,
+                vocab_size=gpc.config.model.vocab_size,
+                global_batch_size=gpc.config.data.micro_bsz
+                * gpc.config.data.micro_num
+                * gpc.get_world_size(ParallelMode.DATA),
+                global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
+                mlp_ratio=gpc.config.model["mlp_ratio"],
+            )
+        else:
+            get_tflops_func = partial(
+                get_megatron_flops,
+                checkpoint=gpc.config.model.checkpoint,
+                seq_len=gpc.config.data["seq_len"],
+                hidden_size=gpc.config.model.hidden_size,
+                num_layers=gpc.config.model.num_layers,
+                vocab_size=gpc.config.model.vocab_size,
+                global_batch_size=gpc.config.data.micro_bsz
+                * gpc.config.data.micro_num
+                * gpc.get_world_size(ParallelMode.DATA),
+                global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
+                mlp_ratio=gpc.config.model["mlp_ratio"],
+            )
         record_current_batch_training_metrics(
             get_tflops_func=get_tflops_func,
             logger=logger,
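The two branches above differ only in which FLOPs formula is handed to partial; the remaining keyword arguments are repeated verbatim. If the maintainers prefer, the dispatch could be collapsed as sketched below (a fragment meant to replace the if/else inside _record_metrics, using only names already present in this diff; behaviour is intended to be identical):

# Sketch: pick the FLOPs formula first, then build a single partial with the shared kwargs.
flops_fn = (
    get_megatron_flops_mla
    if gpc.config.get("attention_type", "GQA") == "MLA"
    else get_megatron_flops
)
get_tflops_func = partial(
    flops_fn,
    checkpoint=gpc.config.model.checkpoint,
    seq_len=gpc.config.data["seq_len"],
    hidden_size=gpc.config.model.hidden_size,
    num_layers=gpc.config.model.num_layers,
    vocab_size=gpc.config.model.vocab_size,
    global_batch_size=gpc.config.data.micro_bsz
    * gpc.config.data.micro_num
    * gpc.get_world_size(ParallelMode.DATA),
    global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
    mlp_ratio=gpc.config.model["mlp_ratio"],
)

This also leaves a single place to extend if further attention variants (and their FLOPs formulas) are added later.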
internlm/data/tokenized/dummy_dataset.py (4 changes: 2 additions & 2 deletions)
@@ -4,7 +4,7 @@
 import numpy as np
 from torch.utils.data import Dataset
 
-# from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context.parallel_context import global_context as gpc
 
 
 class RandomDataset(Dataset):
@@ -30,7 +30,7 @@ def __init__(self, num_samples=10000, max_len=1024, fixed_seqlen: bool = False)
             while len(d) < max_len:
                 r *= 2
                 d = list(range(n)) * r
-            # r = r % gpc.config.model.vocab_size
+            r = r % gpc.config.model.vocab_size
             d = [n, r] + d
             d = d[:max_len]
             data.append(d)
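Re-enabling the modulo presumably keeps the prepended value r inside the vocabulary; with VOCAB_SIZE dropped to 50304, an unclamped id would fail at the embedding lookup once the dummy batch reaches the model. A standalone illustration in plain PyTorch (not InternEvo code):

# Illustration only: why the prepended value should be reduced modulo vocab_size.
import torch
import torch.nn as nn

vocab_size = 50304
embedding = nn.Embedding(vocab_size, 8)

r = 123_457                               # a value that may exceed vocab_size
ok_ids = torch.tensor([r % vocab_size])   # in range: lookup succeeds
_ = embedding(ok_ids)

bad_ids = torch.tensor([r])               # out of range: lookup would fail
# _ = embedding(bad_ids)                  # raises IndexError: index out of range in self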
internlm/model/modeling_internlm2.py (60 changes: 40 additions & 20 deletions)
@@ -22,7 +22,7 @@
 from internlm.model.base_model import BaseModel
 from internlm.model.modules.embedding import Embedding1D
 from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import GQA
+from internlm.model.modules.mha import GQA, MLA
 from internlm.model.modules.mlp import new_feed_forward
 from internlm.model.modules.norm import new_layer_norm
 from internlm.model.utils import (
@@ -133,25 +133,45 @@ def __init__(
         self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
 
         head_dim = hidden_size // num_attention_heads
-        self.attention = GQA(
-            embed_dim=hidden_size,
-            num_heads=num_attention_heads,
-            num_kv_heads=num_kv_attention_heads,
-            dropout=attn_drop_rate,
-            max_position_embeddings=max_position_embeddings,
-            softmax_scale=1 / math.sqrt(head_dim),
-            causal=True,
-            layer_idx=layer_idx,
-            use_dynamic_ntk_rope=use_dynamic_ntk_rope,
-            rotary_emb_dim=head_dim,
-            rotary_emb_scale_base=0,
-            device=device,
-            dtype=dtype,
-            qk_interleaved=qk_interleaved,
-            bias=not no_bias,
-            rope_base=rope_base,
-            enable_qkv_fusion=True,
-        )
+        if gpc.config.get("attention_type", "GQA") == "MLA":
+            self.attention = MLA(
+                embed_dim=hidden_size,
+                num_heads=num_attention_heads,
+                dropout=attn_drop_rate,
+                max_position_embeddings=max_position_embeddings,
+                softmax_scale=1 / math.sqrt(head_dim),
+                causal=True,
+                layer_idx=layer_idx,
+                use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+                rotary_emb_dim=head_dim,
+                rotary_emb_scale_base=0,
+                device=device,
+                dtype=dtype,
+                qk_interleaved=qk_interleaved,
+                bias=not no_bias,
+                rope_base=rope_base,
+                enable_qkv_fusion=True,
+            )
+        else:
+            self.attention = GQA(
+                embed_dim=hidden_size,
+                num_heads=num_attention_heads,
+                num_kv_heads=num_kv_attention_heads,
+                dropout=attn_drop_rate,
+                max_position_embeddings=max_position_embeddings,
+                softmax_scale=1 / math.sqrt(head_dim),
+                causal=True,
+                layer_idx=layer_idx,
+                use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+                rotary_emb_dim=head_dim,
+                rotary_emb_scale_base=0,
+                device=device,
+                dtype=dtype,
+                qk_interleaved=qk_interleaved,
+                bias=not no_bias,
+                rope_base=rope_base,
+                enable_qkv_fusion=True,
+            )
 
         self.dropout1 = nn.Dropout(drop_rate)
         self.dropout2 = nn.Dropout(drop_rate)
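A possible tidy-up for the branch above, since MLA and GQA are constructed with identical keyword arguments except for num_kv_heads: build the shared kwargs once and pass them to whichever class the config selects. This is a sketch using only names already in this diff; whether MLA accepts exactly this argument set is for the authors to confirm:

# Sketch: deduplicate the attention construction (same behaviour intended as the diff above).
common_kwargs = dict(
    embed_dim=hidden_size,
    num_heads=num_attention_heads,
    dropout=attn_drop_rate,
    max_position_embeddings=max_position_embeddings,
    softmax_scale=1 / math.sqrt(head_dim),
    causal=True,
    layer_idx=layer_idx,
    use_dynamic_ntk_rope=use_dynamic_ntk_rope,
    rotary_emb_dim=head_dim,
    rotary_emb_scale_base=0,
    device=device,
    dtype=dtype,
    qk_interleaved=qk_interleaved,
    bias=not no_bias,
    rope_base=rope_base,
    enable_qkv_fusion=True,
)
if gpc.config.get("attention_type", "GQA") == "MLA":
    self.attention = MLA(**common_kwargs)
else:
    self.attention = GQA(num_kv_heads=num_kv_attention_heads, **common_kwargs)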