From d421847498acd566a0a907084dbddba592a30c1a Mon Sep 17 00:00:00 2001 From: BingyangWu Date: Thu, 2 Jan 2025 11:45:30 +0800 Subject: [PATCH] feat(mha.py): support mla --- configs/7B_isp_sft.py | 38 +- internlm/core/trainer_builder.py | 42 +- internlm/data/tokenized/dummy_dataset.py | 4 +- internlm/model/modeling_internlm2.py | 60 ++- internlm/model/modules/mha.py | 474 +++++++++++++++++++++++ internlm/utils/common.py | 112 ++++++ 6 files changed, 680 insertions(+), 50 deletions(-) diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index 39c78660..9faaa90d 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -2,14 +2,20 @@ model_type = "INTERNLM2_PUBLIC" DO_ALERT = False -VOCAB_SIZE = 103168 -SEQ_LEN = 2048 +VOCAB_SIZE = 50304 +SEQ_LEN = 1024 * 1024 + HIDDEN_SIZE = 4096 NUM_ATTENTION_HEAD = 32 -NUM_KV_ATTENTION_HEAD = 8 -MLP_RATIO = 8 / 3 +NUM_KV_ATTENTION_HEAD = 32 +MLP_RATIO = 3.5 NUM_LAYER = 32 +# HIDDEN_SIZE = 8192 +# NUM_ATTENTION_HEAD = 64 +# NUM_KV_ATTENTION_HEAD = 64 +# MLP_RATIO = 3.5 +# NUM_LAYER = 80 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" # Ckpt folder format: @@ -54,15 +60,15 @@ data = dict( seq_len=SEQ_LEN, # micro_num means the number of micro_batch contained in one gradient update - micro_num=4, + micro_num=1, # packed_length = micro_bsz * SEQ_LEN - micro_bsz=2, + micro_bsz=1, # defaults to the value of micro_num valid_micro_num=4, # defaults to 0, means disable evaluate valid_every=50, pack_sample_into_one=False, - total_steps=50000, + total_steps=20, skip_batches="", # rampup_batch_size (str): A string with three space-separated integers representing the # starting batch size, the increment, and the number of steps between @@ -76,7 +82,7 @@ valid_folder=VALID_FOLDER, empty_cache_and_diag_interval=200, diag_outlier_ratio=1.1, - # use_packed_dataset=False, + use_packed_dataset=False, ) grad_scaler = dict( @@ -152,9 +158,11 @@ ) use_fp32_norm = False +attention_type = "MLA" model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, @@ -222,15 +230,15 @@ interleaved the ranks in the same window to make full use of NIC as much as possible. 
""" parallel = dict( - zero1=dict(size=-1), - tensor=dict(size=2, mode="isp"), + zero1=dict(size=8), + tensor=dict(size=64, mode="isp"), pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), + weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="module"), sequence_2D=dict( - enable=False, - head_size=2, - context_size=4, - window_size=1, + enable=True, + head_size=8, + context_size=8, + window_size=8, device_placement_strategy=dict(head_first=True, interleaved=False), ), ) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 2b82bc1f..dfe64373 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -35,6 +35,7 @@ enable_pytorch_expandable_segments, get_current_device, get_megatron_flops, + get_megatron_flops_mla, launch_time, ) from internlm.utils.gputest import empty_cache_and_diag @@ -334,19 +335,34 @@ def _update_parameters(self): return success_update, grad_norm_groups def _record_metrics(self, batch_count: int, batch, start_time, loss, moe_loss, success_update, grad_norm_groups): - get_tflops_func = partial( - get_megatron_flops, - checkpoint=gpc.config.model.checkpoint, - seq_len=gpc.config.data["seq_len"], - hidden_size=gpc.config.model.hidden_size, - num_layers=gpc.config.model.num_layers, - vocab_size=gpc.config.model.vocab_size, - global_batch_size=gpc.config.data.micro_bsz - * gpc.config.data.micro_num - * gpc.get_world_size(ParallelMode.DATA), - global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), - mlp_ratio=gpc.config.model["mlp_ratio"], - ) + if gpc.config.get("attention_type", "GQA") == "MLA": + get_tflops_func = partial( + get_megatron_flops_mla, + checkpoint=gpc.config.model.checkpoint, + seq_len=gpc.config.data["seq_len"], + hidden_size=gpc.config.model.hidden_size, + num_layers=gpc.config.model.num_layers, + vocab_size=gpc.config.model.vocab_size, + global_batch_size=gpc.config.data.micro_bsz + * gpc.config.data.micro_num + * gpc.get_world_size(ParallelMode.DATA), + global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), + mlp_ratio=gpc.config.model["mlp_ratio"], + ) + else: + get_tflops_func = partial( + get_megatron_flops, + checkpoint=gpc.config.model.checkpoint, + seq_len=gpc.config.data["seq_len"], + hidden_size=gpc.config.model.hidden_size, + num_layers=gpc.config.model.num_layers, + vocab_size=gpc.config.model.vocab_size, + global_batch_size=gpc.config.data.micro_bsz + * gpc.config.data.micro_num + * gpc.get_world_size(ParallelMode.DATA), + global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), + mlp_ratio=gpc.config.model["mlp_ratio"], + ) record_current_batch_training_metrics( get_tflops_func=get_tflops_func, logger=logger, diff --git a/internlm/data/tokenized/dummy_dataset.py b/internlm/data/tokenized/dummy_dataset.py index dcb6c027..1e64e00a 100644 --- a/internlm/data/tokenized/dummy_dataset.py +++ b/internlm/data/tokenized/dummy_dataset.py @@ -4,7 +4,7 @@ import numpy as np from torch.utils.data import Dataset -# from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc class RandomDataset(Dataset): @@ -30,7 +30,7 @@ def __init__(self, num_samples=10000, max_len=1024, fixed_seqlen: bool = False) while len(d) < max_len: r *= 2 d = list(range(n)) * r - # r = r % gpc.config.model.vocab_size + r = r % gpc.config.model.vocab_size d = [n, r] + d d = d[:max_len] data.append(d) 
diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 69da0837..d731a5e8 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -22,7 +22,7 @@ from internlm.model.base_model import BaseModel from internlm.model.modules.embedding import Embedding1D from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import GQA +from internlm.model.modules.mha import GQA, MLA from internlm.model.modules.mlp import new_feed_forward from internlm.model.modules.norm import new_layer_norm from internlm.model.utils import ( @@ -133,25 +133,45 @@ def __init__( self.use_dynamic_ntk_rope = use_dynamic_ntk_rope head_dim = hidden_size // num_attention_heads - self.attention = GQA( - embed_dim=hidden_size, - num_heads=num_attention_heads, - num_kv_heads=num_kv_attention_heads, - dropout=attn_drop_rate, - max_position_embeddings=max_position_embeddings, - softmax_scale=1 / math.sqrt(head_dim), - causal=True, - layer_idx=layer_idx, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - rotary_emb_dim=head_dim, - rotary_emb_scale_base=0, - device=device, - dtype=dtype, - qk_interleaved=qk_interleaved, - bias=not no_bias, - rope_base=rope_base, - enable_qkv_fusion=True, - ) + if gpc.config.get("attention_type", "GQA") == "MLA": + self.attention = MLA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + dropout=attn_drop_rate, + max_position_embeddings=max_position_embeddings, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + device=device, + dtype=dtype, + qk_interleaved=qk_interleaved, + bias=not no_bias, + rope_base=rope_base, + enable_qkv_fusion=True, + ) + else: + self.attention = GQA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_kv_attention_heads, + dropout=attn_drop_rate, + max_position_embeddings=max_position_embeddings, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + device=device, + dtype=dtype, + qk_interleaved=qk_interleaved, + bias=not no_bias, + rope_base=rope_base, + enable_qkv_fusion=True, + ) self.dropout1 = nn.Dropout(drop_rate) self.dropout2 = nn.Dropout(drop_rate) diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 42418a21..f7121f11 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -13,10 +13,13 @@ from internlm.core.context import global_context as gpc from internlm.model.modules.embedding import new_rotary_embedding from internlm.model.modules.linear import new_linear +from internlm.model.modules.norm import new_layer_norm from internlm.model.modules.utils import update_kv_cache from internlm.model.ops.attention import CrossAttention, SelfAttention from internlm.utils.logger import get_logger +import dlblas + logger = get_logger(__file__) @@ -81,6 +84,477 @@ def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) - return state_dict +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, 
max_position_embeddings=2048): + low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min_val, max_val, dim): + if min_val == max_val: + max_val += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV2RotaryEmbedding(nn.Module): + # pylint: disable=missing-class-docstring + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): + # pylint: disable=missing-class-docstring + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freq_inter = 1.0 / ( + self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + 
self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False) + self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class MLA(nn.Module): + """ + Multi-Latent self-attention and cross-attention. + + Args: + embed_dim (int): The dimention of hidden state. + num_heads (int): The number of attention heads. + max_position_embeddings (int): max position embeddings, 2048 by default. + bias (bool): Whether the bias is needed for linears. True by default. + dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. + softmax_scale (float): The temperature to use for the softmax attention. + causal (boolean): Whether to apply causal attention mask. False by default. + layer_idx (int): The index of current layer. None by default. + use_dynamic_ntk_rope (bool): whether use dynamic ntk rope, false by default. + rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. + rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements + XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk is interleaved. True by default. + enable_qkv_fusion (bool): whether wq, wk and wv lienar is fused. True by default. 
+ """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + max_position_embeddings: int = 2048, + bias: bool = True, + dropout: float = 0.0, + softmax_scale: float = None, + causal: bool = False, + layer_idx: int = None, + use_dynamic_ntk_rope: bool = False, + rotary_emb_dim: int = 0, + rotary_emb_scale_base: int = 0, + rope_base: int = 10000, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + qk_interleaved: Optional[bool] = True, + enable_qkv_fusion: bool = True, + out_bias: bool = True, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.causal = causal + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = self.embed_dim // num_heads + self.kv_dim = self.head_dim * num_heads # num_kv_heads equals to num_heads in MHA + self.enable_qkv_fusion = enable_qkv_fusion + + self.use_dynamic_ntk_rope = use_dynamic_ntk_rope + self.rotary_emb_dim = rotary_emb_dim + self.max_position_embeddings = max_position_embeddings + self.interleaved = qk_interleaved + + # MLA config + self.q_lora_rank = 1536 + self.kv_lora_rank = 512 + self.v_head_dim = 128 + self.qk_nope_head_dim = 128 + self.qk_rope_head_dim = 64 + self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + + factory_kwargs = {"device": device, "dtype": dtype} + + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + + self.yarn_embed = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=max_position_embeddings, + base=rope_base, + ) + + self.q_a_layernorm = new_layer_norm("rmsnorm", self.q_lora_rank, eps=1e-6) + self.kv_a_layernorm = new_layer_norm("rmsnorm", self.kv_lora_rank, eps=1e-6) + self.q_a_proj = new_linear( + "wqkv", + embed_dim, + self.q_lora_rank, + bias=bias, + **factory_kwargs, + ) + self.q_b_proj = new_linear( + "wqkv", + self.q_lora_rank, + self.num_heads * self.q_head_dim, + bias=bias, + **factory_kwargs, + ) + self.kv_a_proj_with_mqa = new_linear( + "wqkv", + embed_dim, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=bias, + **factory_kwargs, + ) + self.kv_b_proj = new_linear( + "wqkv", + self.kv_lora_rank, + self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=bias, + **factory_kwargs, + ) + + self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + + # output projection always have the bias (for now) (except for baichuan2 model) + self.out_proj = new_linear( + "out_proj", self.num_heads * self.v_head_dim, embed_dim, bias=out_bias, **factory_kwargs + ) + + def register_checkpoint_compatibility_hooks( + self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None + ): + # Here we explicitly expose the checkpoint compatibility interface of the module, + # hoping that model developers will make good use of it when adapting. + # Is this interface already meeting all reasonable requirements? 
+        self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True)
+        self._register_state_dict_hook(pre_save_hook)
+
+    def forward(self, x, inference_params=None, **kwargs):
+        if inference_params is None:
+            return self._training(x=x, **kwargs)
+        else:
+            return self._inference(x=x, inference_params=inference_params, **kwargs)
+
+    def _training(self, x, **kwargs):
+        """
+        Arguments:
+            x: (batch, seqlen, hidden_dim)
+        """
+        bsz, q_len, _ = x.size()
+
+        # q: down-project to q_lora_rank, normalize, then up-project to num_heads * q_head_dim
+        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
+
+        # kv: the compressed latent carries the nope part, k_pe carries the shared rope part
+        compressed_kv = self.kv_a_proj_with_mqa(x)
+        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
+        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
+            bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
+        )
+
+        # rotary embedding (partial RoPE: only the rope dims of q and k are rotated)
+        cos, sin = self.yarn_embed(q, seq_len=kwargs["max_seqlen"])
+
+        # lazy import: dlblas provides the fused partial-RoPE kernel used below
+        import dlblas.kernels.partial_rotary_emb as partial_rotary_emb
+
+        position_ids = kwargs.pop("indexes", torch.arange(0, q_len)).to(q.device).unsqueeze(0)
+        q, kv = partial_rotary_emb.PartialRotaryEmb.apply(
+            q.contiguous(),
+            k_pe.contiguous(),
+            kv.contiguous(),
+            cos[position_ids].contiguous(),
+            sin[position_ids].contiguous(),
+        )
+
+        # self attention
+        kwargs = _convert_cu_seqlens_for_qksplited(kwargs)
+        if gpc.config.data.use_packed_dataset is False or self.training is False:
+            kwargs.pop("max_seqlen_q", None)
+            kwargs.pop("max_seqlen_k", None)
+
+        context = self.inner_attn(q, kv, **kwargs)
+
+        if self.q_head_dim != self.v_head_dim:
+            context = context[:, :, :, : self.v_head_dim]
+
+        # wo
+        return self.out_proj(rearrange(context, "b s h d -> b s (h d)"))
+
+    def _convert_unpacked_qkv_to_packed(
+        self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor
+    ):
+        cu_seqlens = torch.concat(
+            [
+                torch.tensor([0], dtype=torch.int32, device=attention_mask.device),
+                attention_mask.sum(dim=-1).to(dtype=torch.int32),
+            ],
+            dim=0,
+        ).cumsum(dim=0, dtype=torch.int32)
+
+        cu_seqlens_q = cu_seqlens
+        cu_seqlens_k = cu_seqlens
+
+        max_seqlen_q = attention_mask.shape[-1]
+        max_seqlen_k = attention_mask.shape[-1]
+
+        q_packed = (
+            q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+        )
+        kv_packed = (
+            kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+            .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+            .unsqueeze(0)
+        )
+
+        return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
+
+    def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
+        assert inference_params is not None, "inference_params is required for inference"
+        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+        attention_mask = inference_params.attention_mask
+        sequence_len_offset = inference_params.sequence_len_offset
+        batch_size = x.shape[0]
+
+        # wqkv, output: q, kv
+        if self.enable_qkv_fusion:
+            qkv = self.wqkv(x)
+            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
+
+            q = qkv[:, :, 0].squeeze(2)
+            kv = qkv[:, :, 1:]
+        else:
+            q, k, v = self.wq(x), self.wk(x), self.wv(x)
+            q = rearrange(q, "b s
(h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + kv = torch.stack([k, v], dim=2) + + # rotary embedding, output: q, kv + # q shape: [bsz, nheads, head_dim] + # kv shape: [bsz, seqlen, 2, nheads, head_dim] + if self.use_dynamic_ntk_rope: + # update kv cache fisrt when enable dynamic ntk rope. + kv = update_kv_cache(kv, inference_params, self.layer_idx) + + if sequence_len_offset != 0: + if sequence_len_offset > self.max_position_embeddings: + logger.warning( + "Notice your prompt's length is longer than model's max_position_embeddings: " + f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations." + ) + + if self.rotary_emb_dim > 0: + q = self.rotary_emb( + q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved + ) + k = kv[:, :, 0].squeeze(2) + self.rotary_emb( + k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True + ) # in-place is important + else: + if self.rotary_emb_dim > 0: + q = self.rotary_emb(q, offsets=0, cache_type="query", interleaved=self.interleaved) + k = kv[:, :, 0].squeeze(2) + self.rotary_emb( + k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True + ) # in-place is important + else: + assert self.rotary_emb_dim > 0, "You should use rotary_emb." + + k, v = kv[:, :, 0].squeeze(2), kv[:, :, 1].squeeze(2) + + if attention_mask is None: + q = self.rotary_emb(q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved) + k = self.rotary_emb(k, offsets=sequence_len_offset, cache_type="key", interleaved=self.interleaved) + else: + if sequence_len_offset == 0: + q = self.rotary_emb( + q, offsets=0, cache_type="query", interleaved=self.interleaved, left_padding_mask=attention_mask + ) + k = self.rotary_emb( + k, offsets=0, cache_type="key", interleaved=self.interleaved, left_padding_mask=attention_mask + ) + else: + if sequence_len_offset > self.max_position_embeddings: + logger.warning( + "Notice your prompt's length is longer than model's max_position_embeddings: " + f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations." + ) + + empties = attention_mask[..., -1].sum(dim=-1) + indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties + indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties + # TODO To fit flash_attn apis, we rearrange q&k to pack them here and + # calculate rope for this batch input. Waiting to be optimized + q = rearrange(q, "b s h d -> s b h d", d=self.head_dim) # pack input + k = rearrange(k, "b s h d -> s b h d", d=self.head_dim) + q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved) + k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved) + q = rearrange(q, "s b h d -> b s h d", d=self.head_dim) # unpack + k = rearrange(k, "s b h d -> b s h d", d=self.head_dim) + + kv = torch.stack([k, v], dim=2) + # update kv cache after rotary embedding when disable dynamic ntk rope. + kv = update_kv_cache(kv, inference_params, self.layer_idx) + + # self-attention + if attention_mask is None: + context = self.inner_cross_attn(q, kv) + else: + if sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) + attn_mask = attention_mask[:, None, ...] 
+ attn_mask = torch.logical_or(torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask) + attn_mask4flsh = ~attn_mask[:, :, -1, :].view(batch_size, -1) + + output = self.inner_attn(*self._convert_unpacked_qkv_to_packed(q, kv, batch_size, attn_mask4flsh)) + output = output.to(x.dtype) + + context = torch.zeros_like(q).masked_scatter_(attn_mask4flsh.view(batch_size, -1, 1, 1), output) + else: + attn_mask = attention_mask[:, -1, :].view(batch_size, 1, 1, -1) + + k, v = torch.chunk(kv, 2, dim=2) + k = k.squeeze(2) + v = v.squeeze(2) + sp = k.shape + scores = torch.einsum( + "blhd,bnhd->bhln", + q, + k.reshape(sp[0], sp[1], q.size(2), sp[3]), + ) / math.sqrt(q.size(-1)) + scores = scores.masked_fill(attn_mask, -65000.0) + scores = F.softmax(scores, dim=-1) # bsz x h x L x L + context = torch.einsum( + "bhmn,bnhd->bmhd", + scores, + v.reshape(sp[0], sp[1], q.size(2), sp[3]), + ) + + # wo + return self.out_proj(rearrange(context, "b s h d -> b s (h d)")) + + class MHA(nn.Module): """ Multi-head self-attention and cross-attention. diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 56ebcfbe..1ecb7ea0 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -235,6 +235,118 @@ def get_megatron_flops( return tflops +def get_megatron_flops_mla( + elapsed_time_per_iter, + checkpoint=False, + seq_len=2048, + hidden_size=12, + num_layers=32, + vocab_size=12, + global_batch_size=4, + global_world_size=1, + mlp_ratio=4, + use_swiglu=True, + # moe_kwargs + topk=1, + num_experts=0, + moe_mlp_ratio=0.5, + first_k_dense=-1, + num_shared_experts=0, + # mla_kwargs + num_heads=32, + v_head_dim=128, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + q_lora_rank=1536, + kv_lora_rank=512, +): + """ + Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf + """ + + checkpoint_activations_factor = 4 if checkpoint else 3 + + if use_swiglu: + mlp_ratio = mlp_ratio * 3 / 2 + moe_mlp_ratio = moe_mlp_ratio * 3 / 2 + + # first k dense / dense model + # sum=2*3*(b*s*mlp_ratio*d^2) = 4*mlp_ratio *bs* d^2 + dense_mlp_flops = mlp_ratio * 4 * global_batch_size * seq_len * hidden_size**2 + + if num_experts > 0: + # moe + # total tokens: b*s; processed by E experts, with topk + per_expert_token = global_batch_size * seq_len * topk / num_experts + per_token_flops = moe_mlp_ratio * 4 * hidden_size**2 + moe_flops = per_expert_token * per_token_flops * num_experts + + if num_shared_experts > 0: + shared_mlp_flops = moe_mlp_ratio * num_shared_experts * 4 * global_batch_size * seq_len * hidden_size**2 + moe_flops += shared_mlp_flops + else: + moe_flops = 0 + first_k_dense = num_layers + + global_tokens = global_batch_size * seq_len + if q_lora_rank is not None or kv_lora_rank is not None: + q_head_dim = qk_nope_head_dim + qk_rope_head_dim + q_out_dim = q_head_dim * num_heads + else: + q_out_dim = hidden_size + if q_lora_rank is not None: + # mla + ## q: (bs, d) @(d, q_lora) -> (bs, q_lora);@(q_lora, q_out) ->(bs, q_out) + q_flops = 2 * (global_tokens * hidden_size * q_lora_rank + global_tokens * q_out_dim * q_lora_rank) + else: + # q: (bs, d) @(d, q_out_dim) -> (bs, q_out_dim) + q_flops = 2 * global_tokens * hidden_size * q_out_dim + + if kv_lora_rank is not None: + # kv: + ## (bs, d) @(d, kv_a_out) -> (bs, kv_a_out) + ## (bs, kv_lora_rank) @(kv_lora, kv_b_out) -> (bs, kv_b_out) + kv_a_out = kv_lora_rank + qk_rope_head_dim + kv_b_out = num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim) + + kv_flops = 2 * (global_tokens * kv_a_out * 
hidden_size + global_tokens * kv_b_out * kv_lora_rank) + + ## (bs, d) @(d, v_dim) -> (bs, v_dim) + v_dim = num_heads * v_head_dim + attn_out_flops = 2 * global_tokens * v_dim * hidden_size + else: + kv_flops = 4 * global_tokens * hidden_size**2 + attn_out_flops = 2 * global_tokens * hidden_size**2 + + qkv_flops = kv_flops + q_flops + + # attn: 2*2*bds**2 + ## (b, nh, s, hd) @(b, nh, hd, s) -> (b, nh, s, s) : b*nh*s**2*hd -> b*d*s**2 + ## (b, nh, s, s) @ (b, nh, s, hd) -> (b, nh, s, hd): b*s*d*s -> b*d*s**2 + if q_lora_rank is not None: + attn_hidden_size = num_heads * q_head_dim + attn_flops = 4 * global_batch_size * seq_len**2 * attn_hidden_size + else: + attn_flops = 4 * global_batch_size * seq_len**2 * hidden_size + + # vocab + vocab_flops = 6 * global_batch_size * seq_len * hidden_size * vocab_size + + flops_per_iteration = ( + checkpoint_activations_factor + * ( + dense_mlp_flops * first_k_dense + + (num_layers - first_k_dense) * moe_flops + + num_layers * (qkv_flops + attn_flops + attn_out_flops) + ) + + vocab_flops + ) + + tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12)) + + return tflops + + def enable_pytorch_expandable_segments(): if torch.__version__ >= "2.1.0" and AcceleratorType.GPU == internlm_accelerator.get_accelerator_backend(): _expandable_segments_conf = "expandable_segments:True"
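For reference, a minimal usage sketch of get_megatron_flops_mla: all numeric values below are hypothetical, and the MLA keyword arguments simply mirror the dimensions hard-coded in MLA.__init__ above.

from internlm.utils.common import get_megatron_flops_mla

# hypothetical setup: 32-layer, 4096-hidden MLA model, 8 GPUs, 10 s per iteration
tflops = get_megatron_flops_mla(
    elapsed_time_per_iter=10.0,
    checkpoint=False,
    seq_len=4096,
    hidden_size=4096,
    num_layers=32,
    vocab_size=50304,
    global_batch_size=8,
    global_world_size=8,
    mlp_ratio=3.5,
    num_heads=32,
    v_head_dim=128,
    qk_nope_head_dim=128,
    qk_rope_head_dim=64,
    q_lora_rank=1536,
    kv_lora_rank=512,
)
print(f"{tflops:.2f} TFLOPS per GPU")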