diff --git a/.github/workflows/e2e_test.yaml b/.github/workflows/e2e_test.yaml index f36c0566..28b0e4f1 100644 --- a/.github/workflows/e2e_test.yaml +++ b/.github/workflows/e2e_test.yaml @@ -1,5 +1,5 @@ name: e2e-tests -on: +on: pull_request: branches: - "develop" @@ -232,24 +232,4 @@ jobs: jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_llama2" ./tests/test_training/test_loss.py exit_code=$? - sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname - - training_internlm2: - strategy: - matrix: - runner: [t_cluster] - runs-on: ${{ matrix.runner }} - timeout-minutes: 20 - steps: - - name: mask env - run: | - echo "::add-mask::${{env.WORKSPACE_PREFIX}}" - echo "::add-mask::$path_prefix" - - uses: actions/checkout@v3 - - name: training_internlm2_T - run: | - source activate ${evo_env_torch21_flash2} - jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_internlm2" ./tests/test_training/test_loss.py - exit_code=$? - sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname \ No newline at end of file diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index 2698a82f..de99f917 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -1,5 +1,5 @@ JOB_NAME = "7b_train" -# model_type = "INTERNLM2_PUBLIC" +model_type = "INTERNLM2_PUBLIC" DO_ALERT = False VOCAB_SIZE = 103168 @@ -31,7 +31,7 @@ # 'load_ckpt_info' setting guide: # 1. the 'path' indicate ckpt path, # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined # load function such as "llama" load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering @@ -145,7 +145,7 @@ parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, - # no_bias=True, + no_bias=True, mlp_ratio=MLP_RATIO, apply_post_layer_norm=False, dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" @@ -188,17 +188,17 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. sequence_2D (dict): 1. enable: bool, whether enable the 2D sequence parallel or not. - 2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses). + 2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses). head_size * context_size should be equal tensor size. 3. context_size: int, the parallel degree of context parallelism. head_size * context_size should be equal tensor size. 4. window_size: int, the sliding window size in context parallelism. 5. 
device_placement_strategy: dict, - head_first: bool, if `True`, ranks of the same head parallel group are + head_first: bool, if `True`, ranks of the same head parallel group are given high priority for colocation on the same node; if `False`, ranks of the same context parallel group are given high priority for colocation on the same node; - interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config could + interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config could interleaved the ranks in the same window to make full use of NIC as much as possible. """ parallel = dict( diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index d23cae63..dde4bc52 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -1,6 +1,7 @@ # Copyright (c) InternLM. All rights reserved. from internlm.model.modeling_internlm import InternLM1 +from internlm.model.modeling_internlm2 import InternLM2 from internlm.model.modeling_llama import Llama2 from internlm.utils.logger import get_logger @@ -9,4 +10,5 @@ LOAD_FUNC_DICT = { "llama": Llama2.load_llama_pretrained_weights, "internlm_test": InternLM1.load_internlm_with_dynamic_parallel_size, + "internlm2_test": InternLM2.load_internlm2_with_dynamic_parallel_size, } diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index a4389b63..fedd27c4 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -1,6 +1,7 @@ # Copyright (c) InternLM. All rights reserved. import math import os +from functools import reduce from typing import Optional import torch @@ -11,6 +12,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.shard import partition_uniform from internlm.initialize.initialize_tensor import ( normal_, scaled_init_method_normal, @@ -26,6 +28,7 @@ from internlm.model.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, + get_parallel_size_from_file, ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger @@ -576,6 +579,196 @@ def load_hf_weights(folder: str, model: nn.Module) -> None: internlm_accelerator.empty_cache() + @staticmethod + def load_internlm2_with_dynamic_parallel_size(folder, model): + """Load InternLM2 with dynamic parallel size.""" + assert folder is not None, "Please specify the folder of the pretrained model" + assert gpc.config.model_type in ["INTERNLM2_PUBLIC"], "dynamic_parallel is only for INTERNLM2_PUBLIC" + + fns = get_fns(folder) + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + model_fns, old_tp, old_pp = get_parallel_size_from_file(fns) # pylint: disable=W0612 + + tp = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + assert old_tp % tp == 0 or tp % old_tp == 0, ( + f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in " + f"checkpoint and {tp} in current config" + ) + + correspond_tps = [] + + if old_tp <= tp: + correspond_tps.append(tp_rank // (tp // old_tp)) + ratio = tp // old_tp + rank = tp_rank % ratio + else: + for i in range(old_tp // tp): + correspond_tps.append(tp_rank * (old_tp // tp) + i) + rank = 0 + ratio = 1 + + current_states = {} + + pp = gpc.get_world_size(ParallelMode.PIPELINE) # noqa: 
F841 # pylint: disable=W0612 + + assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary" + + old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1) + + for idx, parts in enumerate(old_pp_partition): + start, end = parts[0] + if model.last_layer <= start or model.first_layer >= end: + continue + tmp_states = {} + + for correspond_tp in correspond_tps: + model_name = f"model_tp{correspond_tp}_pp{idx}.pt" + states = llm_load(os.path.join(folder, model_name), map_location="cpu") + states = {k.replace("model.", ""): v for k, v in states.items()} + for i in range(start, end): + if i >= model.last_layer: + break + if i < model.first_layer: + continue + + for name in list(states.keys()): + if f".{i-start}." in name: + to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.") + + if gpc.config.model_type == "INTERNLM2_PUBLIC": + if "norm" in name: + tmp_states[to_name] = [states.pop(name)] + elif any(x in name for x in ("wo", "w2")): + tmp_states[to_name] = tmp_states.get(to_name, []) + tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=1)[rank]) + elif any(x in name for x in ("w1", "w3")): + tmp_states[to_name] = tmp_states.get(to_name, []) + tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank]) + elif any(x in name for x in ("wqkv",)): + tmp_states[to_name] = tmp_states.get(to_name, []) + if tp > gpc.config.model.num_kv_attention_heads: + assert old_tp <= gpc.config.model.num_kv_attention_heads, ( + f"`old_tp ({old_tp}) => tp ({tp})` is not supported. " + "At least one of `tp` and `old_tp` should be less than or " + "equal to `num_kv_attention_heads`" + ) + # Suitable for cases where the num_kv_attention_head is small, + # but you want to have a large TP Size + q_per_kv = ( + gpc.config.model.num_attention_heads + // gpc.config.model.num_kv_attention_heads + ) + head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads + index = torch.concat( + ( + torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio], + torch.tensor([q_per_kv, q_per_kv + 1]), + ) + ) + index = index + (q_per_kv + 2) * (tp_rank // ratio) + index = index % ( + (q_per_kv + 2) * (gpc.config.model.num_kv_attention_heads / old_tp) + ) + index = index * head_dim + index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat( + index.shape[0] + ) + tmp_states[to_name].append( + torch.index_select(states.pop(name), 0, index.to(torch.int32)) + ) + else: + tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank]) + else: + raise KeyError(f"Unknown key {name}.") + + else: + assert False, "unsupported model type" + + if "tok_embeddings.weight" in states and model.first_layer == 0: + tmp_states["tok_embeddings.weight"] = tmp_states.get("tok_embeddings.weight", []) + tmp_states["tok_embeddings.weight"].append( + states["tok_embeddings.weight"].chunk(ratio, dim=1)[rank] + ) + if "output.weight" in states and model.last_layer == gpc.config.model.num_layers: + tmp_states["norm.weight"] = [states["norm.weight"]] + tmp_states["output.weight"] = tmp_states.get("output.weight", []) + tmp_states["output.weight"].append(states["output.weight"].chunk(ratio, dim=0)[rank]) + + states = {} + + for name in list(tmp_states.keys()): + data = tmp_states.pop(name) + if len(data) == 1: + current_states[name] = data[0] + else: + current_states[name] = torch.concat( + data, dim=1 if name == "tok_embeddings.weight" or any(x in name for x in ("wo", "w2")) else 0 + ) + # Merge copied kv heads + if "wqkv" in name and 
old_tp > gpc.config.model.num_kv_attention_heads: + assert ( + tp <= gpc.config.model.num_kv_attention_heads + ), "new_tp should be less than or equal to num_kv_attention_heads" + head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads + q_per_kv = gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads + copied_times = old_tp // gpc.config.model.num_kv_attention_heads + cur_q_per_kv = q_per_kv // copied_times + + # pylint: disable=all + def duplicate_kv_index(i): + if i % (cur_q_per_kv + 2) >= cur_q_per_kv: + return i + else: + return -100 + + def unique_kv_index(i): + if i // (cur_q_per_kv + 2) == copied_times - 1 or i % (cur_q_per_kv + 2) < cur_q_per_kv: + return i + else: + return -100 + + # pylint: enable=all + + # Verify + duplicate_index = [duplicate_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)] + duplicate_index = [i for i in duplicate_index if i != -100] + duplicate_index = _duplicate_index = torch.tensor(duplicate_index) + for i in range(gpc.config.model.num_kv_attention_heads // tp - 1): + duplicate_index = torch.concat( + (duplicate_index, _duplicate_index + duplicate_index.max() + 1), dim=0 + ) + duplicate_kv = [] + for index in duplicate_index.reshape(-1, copied_times * 2).chunk(copied_times, dim=-1): + index = index.reshape(-1) * head_dim + index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0]) + duplicate_kv.append(torch.index_select(current_states[name], 0, index)) + assert reduce( + lambda x, y: x and y, + [torch.allclose(duplicate_kv[0], x, atol=1e-5) for x in duplicate_kv[1:]], + ), "Copied kv heads are not equal after training!" + + # Merge + unique_index = [unique_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)] + unique_index = [i for i in unique_index if i != -100] + unique_index = _unique_index = torch.tensor(unique_index) + for i in range(gpc.config.model.num_kv_attention_heads // tp - 1): + unique_index = torch.concat((unique_index, _unique_index + unique_index.max() + 1), dim=0) + unique_index = unique_index * head_dim + unique_index = unique_index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat( + unique_index.shape[0] + ) + current_states[name] = torch.index_select(current_states[name], 0, unique_index) + missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True): diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index 6e74d6b6..e51e5897 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -99,12 +99,12 @@ def __init__( self.w1 = new_linear( "w1", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert ) - self.w2 = new_linear( - "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert - ) self.w3 = new_linear( "w3", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert ) + self.w2 = new_linear( + "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert + ) def forward(self, x): if 
not self.mlp_layer_fusion: @@ -177,10 +177,10 @@ def __init__( backend=backend, is_expert=is_expert, ) - self.w2 = new_linear( - "grouped_w2", + self.w3 = new_linear( + "grouped_w3", + in_features, hidden_features, - out_features, bias, device=device, dtype=dtype, @@ -188,10 +188,10 @@ def __init__( backend=backend, is_expert=is_expert, ) - self.w3 = new_linear( - "grouped_w3", - in_features, + self.w2 = new_linear( + "grouped_w2", hidden_features, + out_features, bias, device=device, dtype=dtype, diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index d0a668c8..c2639569 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -928,7 +928,7 @@ def _qkv_without_cu_seqlens(self, qkv, softmax_scale=None, causal=None, key_padd # TODO: more unified interface dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () extra_kwargs = {} if attn_type is AttnType.SlidingWindowZigZagFlash: @@ -944,7 +944,7 @@ def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_p attn_type, op = _select_attn_op(AttnOpType.FixedLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () extra_kwargs = {} if attn_type is AttnType.SlidingWindowZigZagFlash: @@ -960,7 +960,7 @@ def _q_k_v_without_cu_seqlens(self, q, k, v, softmax_scale=None, causal=None, ke attn_type, op = _select_attn_op(AttnOpType.FixedLenQKVSplited) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if (attn_type is AttnType.Torch and key_padding_mask is not None) else () + extra_args = (key_padding_mask,) if (attn_type is AttnType.Torch and key_padding_mask is not None) else () extra_kwargs = {} if attn_type is AttnType.SlidingWindowZigZagFlash: @@ -984,7 +984,7 @@ def _qkv_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenQKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op(qkv, cu_seqlens, max_seqlen, dropout, softmax_scale, causal, *extra_args) @@ -1007,7 +1007,7 @@ def _q_kv_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args @@ -1033,7 +1033,7 @@ def _q_k_v_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenQKVSplited) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args @@ -1088,7 +1088,7 @@ def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_p attn_type, op = 
_select_attn_op(AttnOpType.FixedLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op(q, kv, dropout, softmax_scale, causal, *extra_args) @@ -1100,7 +1100,7 @@ def _q_k_v_without_cu_seqlens(self, q, k, v, softmax_scale=None, causal=None, ke attn_type, op = _select_attn_op(AttnOpType.FixedLenQKVSplited) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op(q, k, v, dropout, softmax_scale, causal, *extra_args) @@ -1123,7 +1123,7 @@ def _q_kv_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args @@ -1149,7 +1149,7 @@ def _q_k_v_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenQKVSplited) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args diff --git a/internlm/model/utils.py b/internlm/model/utils.py index e3ebf44d..7c974abe 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -6,9 +6,12 @@ from internlm.core.context.parallel_context import global_context as gpc from internlm.model.modules.mha import MHA +from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load from internlm.utils.utils import TensorParallelMode +logger = get_logger(__file__) + def internlm1_mha_pre_load_convert( model: MHA, state_dict: Dict, prefix: str, *args, **kwargs # pylint: disable=W0613 @@ -138,3 +141,20 @@ def merge_pp_src_states(states): layer_shift += _layer_shift + 1 merged_states.append(shifted_state) return merged_states + + +def get_parallel_size_from_file(fns, suffix=None): + model_fns, old_tp, old_pp = [], -1, -1 + for fn in fns: + # filter with `_t` is for avoiding conflict with model_config.py + + if fn.startswith("model_t"): + if (suffix and fn.endswith(suffix)) or (suffix is None and not fn.endswith("md5")): + model_fns.append(fn) + _, tp, pp = os.path.splitext(fn)[0].split("_") + old_tp = max(old_tp, int(tp[2:]) + 1) + old_pp = max(old_pp, int(pp[2:]) + 1) + + assert old_tp > 0 and old_pp > 0, f"ckpt with tp:{old_tp} and pp:{old_pp} is illegal" + model_fns.sort() + return model_fns, old_tp, old_pp diff --git a/tests/test_training/7B_check_acc.py b/tests/test_training/7B_check_acc.py index 3b727d7c..cb3902bc 100644 --- a/tests/test_training/7B_check_acc.py +++ b/tests/test_training/7B_check_acc.py @@ -1,16 +1,20 @@ import os -JOB_NAME = "7b_train" +JOB_NAME = "7b_internlm2_train" +model_type = "INTERNLM2_PUBLIC" DO_ALERT = False +VOCAB_SIZE = 92544 SEQ_LEN = 2048 HIDDEN_SIZE = 4096 NUM_ATTENTION_HEAD = 32 -MLP_RATIO = 8 / 3 +NUM_KV_ATTENTION_HEAD = 8 +MLP_RATIO = 3.5 NUM_LAYER = 32 -VOCAB_SIZE = 103168 -MODEL_ONLY_FOLDER = 
os.path.join(os.environ["share_path"], "quailty_assurance/7B_model_weights_ckpt/init") +MODEL_ONLY_FOLDER = os.path.join( + os.environ["share_path"], "quailty_assurance/7B_internlm2_init_dp=2_tp=2_pp=2_ckpt/init" +) # Ckpt folder format: # fs: 'local:/mnt/nfs/XXX' # SAVE_CKPT_FOLDER = "local:llm_ckpts_0925_9" @@ -121,7 +125,8 @@ ) model = dict( - checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + checkpoint=False, + num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, embed_split_hidden=True, vocab_size=VOCAB_SIZE, @@ -129,13 +134,22 @@ parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, + no_bias=True, mlp_ratio=MLP_RATIO, apply_post_layer_norm=False, - dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, use_flash_attn=True, - num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, ) """ zero1 parallel: @@ -150,9 +164,9 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node. 
""" parallel = dict( - zero1=dict(size=8), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True), + zero1=dict(size=-1), + tensor=dict(size=2, mode="mtp"), + pipeline=dict(size=2, interleaved_overlap=True), weight=dict(size=1, overlap=True), ) @@ -165,5 +179,30 @@ enable_feishu_alert=DO_ALERT, feishu_alert_address=None, # feishu webhook to send alert message light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, ), ) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) + +enable_tb = False diff --git a/tests/test_training/7B_check_init.py b/tests/test_training/7B_check_init.py index 6f72c7d7..03107d02 100644 --- a/tests/test_training/7B_check_init.py +++ b/tests/test_training/7B_check_init.py @@ -1,12 +1,14 @@ -JOB_NAME = "7b_train" +JOB_NAME = "7b_internlm2_train" +model_type = "INTERNLM2_PUBLIC" DO_ALERT = False +VOCAB_SIZE = 92544 SEQ_LEN = 2048 HIDDEN_SIZE = 4096 NUM_ATTENTION_HEAD = 32 -MLP_RATIO = 8 / 3 +NUM_KV_ATTENTION_HEAD = 8 +MLP_RATIO = 3.5 NUM_LAYER = 32 -VOCAB_SIZE = 103168 CHECK_INIT = 1 @@ -128,7 +130,8 @@ ) model = dict( - checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + checkpoint=False, + num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, embed_split_hidden=True, vocab_size=VOCAB_SIZE, @@ -136,13 +139,22 @@ parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, + no_bias=True, mlp_ratio=MLP_RATIO, apply_post_layer_norm=False, - dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, use_flash_attn=True, - num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] 
+ qk_interleaved=False, ) parallel = dict( @@ -161,5 +173,30 @@ enable_feishu_alert=DO_ALERT, feishu_alert_address=None, # feishu webhook to send alert message light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, ), ) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) + +enable_tb = False diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index 69c9eb90..48b97bfa 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -56,7 +56,7 @@ checkpoint=True, num_attention_heads=32, embed_split_hidden=True, - vocab_size=103168, + vocab_size=92544, embed_grad_scale=1, parallel_output=False, hidden_size=4096, @@ -68,8 +68,9 @@ layer_norm_epsilon=1e-5, use_flash_attn=False, num_chunks=1, + no_bias=True, ), - model_type="INTERNLM", + model_type="INTERNLM2_PUBLIC", alert_address=None, monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)), grad_scaler=dict( @@ -178,7 +179,7 @@ def train_check_output(args): optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) - train_dl, dataset_types = build_train_loader_with_data_type() + _, dataset_types = build_train_loader_with_data_type() metric = AccPerplex( device=get_current_device(), @@ -226,9 +227,9 @@ def train_check_output(args): if gpc.is_rank_for_log(): standard_output_with_fa = torch.load( - os.path.join(share_path, "quailty_assurance/7B_no_flash_attention/output_with_fa.pt") + os.path.join(share_path, "quailty_assurance/7B_no_flash_attention/output_with_fa_internlm2.pt") ) - tensor1 = standard_output_with_fa + tensor1 = standard_output_with_fa[0][0] tensor2 = output[0][0][0] if torch.equal(tensor1, tensor2): diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 4094c582..d1db7496 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -25,25 +25,26 @@ from internlm.utils.gputest import empty_cache_and_diag from internlm.utils.megatron_timers import megatron_timer as timer -CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_sft.py") -INTERNLM1_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss/model_ckpt") +CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_internlm2.py") +INTERNLM2_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss_pri/model_ckpt") TOTAL_STEPS = 10 LOSS_SPIKE_LIMIT = 1.5 LOSS_DEVIATION_LIMIT = 0.02 # dp_size = 4 BASELINE_LOSS_LIST = [ - 11.63298511505127, - 7.82645320892334, - 6.727725505828857, - 6.182029724121094, - 5.395882606506348, - 5.394383430480957, - 5.053952217102051, - 4.742049694061279, - 4.629276752471924, - 4.616517543792725, + 12.362918853759766, + 12.404379844665527, + 12.348219871520996, + 12.194982528686523, + 11.80469036102295, + 11.573806762695312, + 10.045475006103516, + 9.660882949829102, + 9.172087669372559, + 4.799427032470703, ] + cur_loss_list = [] internlm_accelerator = get_accelerator() @@ -59,7 +60,7 
@@ def train( enable_sp: bool = False, save_ckpt: bool = False, load_ckpt: bool = False, - model_type: str = "INTERNLM", + model_type: str = "INTERNLM2_PUBLIC", optimizer_ver: str = "v1", pp_mode: str = "1F1B", ): @@ -67,24 +68,31 @@ def train( config = Config.from_file(CONFIG_FILE_PATH) # init setting - config.data.total_steps = TOTAL_STEPS + config.data.total_steps = 50000 config.data.fixed_random_dataset_seqlen = False - config.lr_scheduler.total_steps = TOTAL_STEPS + config.data.micro_num = 4 + config.data.micro_bsz = 2 + config.lr_scheduler.total_steps = config.data.total_steps config.model_type = model_type config.ckpt.load_ckpt_folder = None config.ckpt.load_ckpt_info = None config.ckpt.auto_resume = False - total_steps = config.data.total_steps + total_steps = TOTAL_STEPS skip_batches = config.data.skip_batches label_smoothing = config.loss.label_smoothing + config.parallel.zero1 = dict(size=-1) + config.parallel.tensor = dict(size=1, mode="mtp") + config.parallel.pipeline = dict(size=1, interleaved_overlap=True, mode="1f1b") + config.parallel.weight = dict(size=1, overlap=True) if optimizer_ver == "v2": config.hybrid_zero_optimizer.use_split_tensor_optim = True config.all_gather_size = 512 * 1024 * 1024 + config.model.checkpoint = True # update ckpt config - if model_type == "INTERNLM" and tp_mode != "isp" and interleaved is False: - config.ckpt.load_ckpt_info = dict(path=INTERNLM1_CKPT_PATH, content=("model",), ckpt_type="internlm_test") + if model_type == "INTERNLM2_PUBLIC" and tp_mode != "isp" and interleaved is False: + config.ckpt.load_ckpt_info = dict(path=INTERNLM2_CKPT_PATH, content=("model",), ckpt_type="internlm2_test") if save_ckpt: config.ckpt.enable_save_ckpt = True @@ -213,7 +221,7 @@ def train( train_iter = iter(train_dl) - if model_type == "INTERNLM": + if model_type == "INTERNLM2_PUBLIC": data_path = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss/data_batch_4DP") data_batch = torch.load(f"{data_path}/{gpc.get_local_rank(ParallelMode.DATA)}_data_batch.pt") @@ -222,7 +230,7 @@ def train( empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) timer("one-batch").start() - if model_type == "INTERNLM": + if model_type == "INTERNLM2_PUBLIC": if batch_count >= 10: batch = data_batch[batch_count - 10] else: @@ -296,7 +304,6 @@ def check_loss_spike(): def check_loss_accuracy(): if gpc.is_rank_for_log(): - print(f"cur_loss_list:{cur_loss_list}", flush=True) for cur, target in zip(cur_loss_list, BASELINE_LOSS_LIST): assert ( abs(cur - target) < LOSS_DEVIATION_LIMIT @@ -464,16 +471,16 @@ def test_training_with_isp(): global CONFIG_FILE_PATH, BASELINE_LOSS_LIST CONFIG_FILE_PATH = "./configs/7B_isp_sft.py" BASELINE_LOSS_LIST = [ - 11.595988273620605, - 7.988386154174805, - 6.821506500244141, - 6.2768449783325195, - 5.478013515472412, - 5.4622697830200195, - 5.162247180938721, - 4.854615211486816, - 4.744818210601807, - 4.75523567199707, + 12.225811004638672, + 12.103824615478516, + 12.223844528198242, + 11.87704849243164, + 11.651590347290039, + 11.629219055175781, + 10.242591857910156, + 9.768388748168945, + 9.330610275268555, + 5.505439758300781, ] # model training @@ -516,12 +523,3 @@ def test_training_llama2(): CONFIG_FILE_PATH = "./configs/7B_llama2.py" train(dp_size=8, model_type="LLAMA2") - - -@pytest.mark.training_internlm2 -def test_training_internlm2(): - # update config file - global CONFIG_FILE_PATH - CONFIG_FILE_PATH = "./configs/7B_internlm2.py" - - train(dp_size=8, model_type="INTERNLM2_PUBLIC") diff --git 
a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index b33cf4c3..7926bae5 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -20,7 +20,7 @@ from internlm.checkpoint import CheckpointManager # noqa: E402 from internlm.core.context import ParallelMode # noqa: E402 from internlm.core.context import global_context as gpc # noqa: E402 -from internlm.core.trainer import TrainState, Trainer # noqa: E402 +from internlm.core.trainer import Trainer, TrainState # noqa: E402 from internlm.data import ( # noqa: E402 build_train_loader_with_data_type, build_valid_loader_with_data_type, @@ -60,6 +60,7 @@ def check_model_weights(model, ckpt_path, total_equal=False): + model = model.model model1_dict = torch.load(ckpt_path, map_location="cuda") model2_dict = model.state_dict() @@ -214,13 +215,14 @@ def main(args): # check model init weights if hasattr(gpc.config, "CHECK_INIT") and gpc.config.CHECK_INIT == 1: ckpt_name = ( - f"model_dp{gpc.get_local_rank(ParallelMode.DATA)}" + f"model" f"_tp{gpc.get_local_rank(ParallelMode.TENSOR)}" f"_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}.pt" ) - ckpt_path = os.path.join(os.environ["share_path"], "quailty_assurance/7B_init_dp=2_tp=2_pp=2_ckpt", ckpt_name) + ckpt_path = os.path.join( + os.environ["share_path"], "quailty_assurance/7B_internlm2_init_dp=2_tp=2_pp=2_ckpt/init", ckpt_name + ) check_model_weights(model, ckpt_path, total_equal=True) - with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: # start iterating the train data and begin training for batch_count in range(train_state.batch_count, total_steps): @@ -327,12 +329,17 @@ def main(args): ) # check model weights - if gpc.is_rank_for_log() and batch_count > 0 and batch_count % 100 == 0: + if batch_count > 0 and batch_count % 100 == 0: + ckpt_name = ( + f"model" + f"_tp{gpc.get_local_rank(ParallelMode.TENSOR)}" + f"_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}.pt" + ) ckpt_path = os.path.join( os.environ["share_path"], - "quailty_assurance/7B_model_weights_ckpt", + "quailty_assurance/7B_internlm2_init_dp=2_tp=2_pp=2_ckpt", str(batch_count), - "model_tp0_pp0.pt", + ckpt_name, ) check_model_weights(model, ckpt_path)
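
The new load_internlm2_with_dynamic_parallel_size path added in internlm/model/modeling_internlm2.py rests on two pieces: recovering the old tensor/pipeline parallel sizes from checkpoint file names (get_parallel_size_from_file, which parses names like model_tp0_pp1.pt), and re-chunking each saved tensor-parallel shard for the current TP rank (split an old shard further when TP grows, concatenate several old shards when TP shrinks). The sketch below is a minimal, self-contained illustration of those two steps under simplifying assumptions — it is not code from the patch; file names, shapes, and the dim-0-only (column-parallel) case are illustrative, and pipeline re-partitioning and the wqkv head-index remapping are left out.

# Standalone sketch (separate from the patch): filename parsing and TP re-sharding.
import torch


def parse_parallel_sizes(filenames):
    """Recover (model_fns, old_tp, old_pp) from names like 'model_tp0_pp1.pt',
    mirroring the idea behind get_parallel_size_from_file."""
    old_tp, old_pp = -1, -1
    model_fns = []
    for fn in filenames:
        if fn.startswith("model_t") and not fn.endswith("md5"):
            model_fns.append(fn)
            _, tp, pp = fn.rsplit(".", 1)[0].split("_")
            old_tp = max(old_tp, int(tp[2:]) + 1)  # 'tp0' -> 0, sizes are max index + 1
            old_pp = max(old_pp, int(pp[2:]) + 1)
    assert old_tp > 0 and old_pp > 0, "no model_tp*_pp*.pt files found"
    return sorted(model_fns), old_tp, old_pp


def reshard_column_parallel(weight_shards, old_tp, new_tp, new_tp_rank):
    """Return the dim-0 slice of a column-parallel weight owned by `new_tp_rank`
    when moving from old_tp saved shards to new_tp shards."""
    if old_tp <= new_tp:
        # TP grows: read one old shard and keep one of its sub-chunks.
        ratio = new_tp // old_tp
        src = weight_shards[new_tp_rank // ratio]
        return src.chunk(ratio, dim=0)[new_tp_rank % ratio]
    # TP shrinks: each new rank concatenates several consecutive old shards.
    ratio = old_tp // new_tp
    srcs = [weight_shards[new_tp_rank * ratio + i] for i in range(ratio)]
    return torch.cat(srcs, dim=0)


if __name__ == "__main__":
    fns = ["model_tp0_pp0.pt", "model_tp1_pp0.pt", "model_tp0_pp1.pt", "model_tp1_pp1.pt"]
    print(parse_parallel_sizes(fns))  # (sorted names, old_tp=2, old_pp=2)

    full = torch.arange(8.0).reshape(8, 1)   # toy column-parallel weight
    old_shards = list(full.chunk(2, dim=0))  # as if saved with old_tp=2
    new_shards = [reshard_column_parallel(old_shards, 2, 4, r) for r in range(4)]
    assert torch.equal(torch.cat(new_shards, dim=0), full)

Row-parallel weights (wo, w2, tok_embeddings) follow the same pattern with dim=1 instead of dim=0, which is why the patch chooses the concat dimension per parameter name when merging shards.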