diff --git a/README-ja-JP.md b/README-ja-JP.md index a9327ecc..009f1d27 100644 --- a/README-ja-JP.md +++ b/README-ja-JP.md @@ -2,7 +2,7 @@
- + [![Documentation Status](https://readthedocs.org/projects/internevo/badge/?version=latest)](https://internevo.readthedocs.io/zh_CN/latest/?badge=latest) [![license](./doc/imgs/license.svg)](./LICENSE) @@ -143,6 +143,10 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py - diff --git a/README-zh-Hans.md b/README-zh-Hans.md index ea38b04e..652cd8db 100644 --- a/README-zh-Hans.md +++ b/README-zh-Hans.md @@ -2,7 +2,7 @@
- + [![使用文档](https://readthedocs.org/projects/internevo/badge/?version=latest)](https://internevo.readthedocs.io/zh_CN/latest/?badge=latest) [![license](./doc/imgs/license.svg)](./LICENSE) @@ -143,6 +143,10 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py - diff --git a/README.md b/README.md index 9a61d573..f04f6f14 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@
- + [![Documentation Status](https://readthedocs.org/projects/internevo/badge/?version=latest)](https://internevo.readthedocs.io/zh_CN/latest/?badge=latest) [![license](./doc/imgs/license.svg)](./LICENSE) @@ -143,6 +143,10 @@ Please refer to the [System Architecture document](./doc/en/structure.md) for ar diff --git a/configs/7B_baichuan2.py b/configs/7B_baichuan2.py new file mode 100644 index 00000000..fdc1b0ab --- /dev/null +++ b/configs/7B_baichuan2.py @@ -0,0 +1,225 @@ +JOB_NAME = "7b_baichuan2_train" +model_type = "BAICHUAN2" +DO_ALERT = False + +VOCAB_SIZE = 125696 +SEQ_LEN = 2048 +HIDDEN_SIZE = 4096 +NUM_ATTENTION_HEAD = 32 +MLP_RATIO = 8 / 3 +NUM_LAYER = 32 + + +MODEL_ONLY_FOLDER = "local:llm_ckpts_baichuan2/xxxx" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts_baichuan2" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # load function such as "llama" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"), + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. + # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. + auto_resume=False, + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +TRAIN_FOLDER = None +VALID_FOLDER = None # "/path/to/dataset" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=1, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=0, + pack_sample_into_one=False, + total_steps=20, + skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. 
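+    # Added note (illustrative, not from the original config): following the format above, a setting
+    # such as rampup_batch_size="4 4 50" would start micro_num at 4 and raise it by 4 every 50 steps;
+    # the empty string below, like the None default, is expected to leave ramp-up disabled.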
+ rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, + diag_outlier_ratio=1.1, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +use_fp32_norm = False +model = dict( + checkpoint=False, + num_chunks=1, + num_attention_heads=NUM_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + no_bias=True, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", + norm_type="rmsnorm", + layer_norm_epsilon=1e-6, + use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, +) + +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. 
+ fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. +""" +parallel = dict( + zero1=dict(size=-1), + tensor=dict(size=1, mode="mtp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), +) + +cudnn_deterministic = False +cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, + ), +) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) diff --git a/configs/7B_gemma.py b/configs/7B_gemma.py new file mode 100644 index 00000000..6ef8c99b --- /dev/null +++ b/configs/7B_gemma.py @@ -0,0 +1,232 @@ +JOB_NAME = "7b_gemma_train" +model_type = "GEMMA" +DO_ALERT = False + +VOCAB_SIZE = 256000 +SEQ_LEN = 2048 +HIDDEN_SIZE = 3072 +NUM_ATTENTION_HEAD = 16 +NUM_KV_ATTENTION_HEAD = 16 +HEAD_DIM = 256 +MLP_RATIO = 8 +NUM_LAYER = 28 + + +MODEL_ONLY_FOLDER = "local:llm_ckpts_gemma/xxxx" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts_gemma" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # load function such as "llama" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"), + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. 
+ # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. + auto_resume=False, + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +TRAIN_FOLDER = None +VALID_FOLDER = None # "/path/to/dataset" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=1, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=0, + pack_sample_into_one=False, + total_steps=20, + skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. + rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, + diag_outlier_ratio=1.1, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +use_fp32_norm = False +model = dict( + checkpoint=False, + num_chunks=1, + num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, + max_position_embeddings=8192, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + no_bias=True, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", + add_unit_offset=True, + norm_type="rmsnorm", + layer_norm_epsilon=1e-6, + head_dim=HEAD_DIM, + use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. 
+ # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, + use_swiglu=False, +) + +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. 
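+
+Worked example (added for illustration, not part of the original file): assuming the usual relation
+data parallel size = world size / (tensor size * pipeline size), running this config on 32 GPUs with
+tensor=dict(size=4, mode="msp") and pipeline=dict(size=2) leaves a data parallel size of
+32 / (4 * 2) = 4; with zero1=dict(size=-1) the optimizer states are then sharded across those 4 ranks.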
+""" +parallel = dict( + zero1=dict(size=-1), + tensor=dict(size=1, mode="mtp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), +) + +cudnn_deterministic = False +cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, + ), +) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) diff --git a/configs/7B_qwen2.py b/configs/7B_qwen2.py new file mode 100644 index 00000000..07e572d2 --- /dev/null +++ b/configs/7B_qwen2.py @@ -0,0 +1,232 @@ +JOB_NAME = "7b_qwen2_train" +model_type = "QWEN2" +DO_ALERT = False + +VOCAB_SIZE = 152064 +SEQ_LEN = 2048 +HIDDEN_SIZE = 3584 +NUM_ATTENTION_HEAD = 28 +NUM_KV_ATTENTION_HEAD = 4 +MLP_RATIO = 5.25 +NUM_LAYER = 28 + + +MODEL_ONLY_FOLDER = "local:llm_ckpts_qwen2/xxxx/" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts_qwen2" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # load function such as "llama" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"), + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. + # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. + auto_resume=False, + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. 
+) + +TRAIN_FOLDER = None +VALID_FOLDER = None # "/path/to/dataset" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=1, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=0, + pack_sample_into_one=False, + total_steps=20, + skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. + rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, + diag_outlier_ratio=1.1, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +use_fp32_norm = False +model = dict( + checkpoint=False, + num_chunks=1, + num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + qkv_bias=True, + o_bias=False, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", + norm_type="rmsnorm", + layer_norm_epsilon=1e-6, + use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, + rope_base=1000000, + use_sliding_window=False, + sliding_window=32768, + max_window_layers=28, +) + +""" +zero1 parallel (dict): + 1. 
size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. +""" +parallel = dict( + zero1=dict(size=-1), + tensor=dict(size=1, mode="mtp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), +) + +cudnn_deterministic = False +cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, + ), +) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index e7f581dc..6937b8e4 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -44,7 +44,7 @@ def get_tokenized_train_loader_items(data_cfg): if data_cfg.get("is_multimodal", False): image_token_size = int(data_cfg.image_size // data_cfg.patch_size) ** 2 train_ds = RandomDatasetMultimodal( - num_samples=100000, + num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len, image_size=data_cfg.image_size, image_token_size=image_token_size, @@ -54,7 +54,9 @@ def get_tokenized_train_loader_items(data_cfg): ) else: train_ds = RandomDataset( - num_samples=1000000, max_len=data_cfg.seq_len, fixed_seqlen=data_cfg.fixed_random_dataset_seqlen + num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, + max_len=data_cfg.seq_len, + fixed_seqlen=data_cfg.fixed_random_dataset_seqlen, ) if data_cfg.pack_sample_into_one: diff --git 
a/internlm/model/modeling_baichuan2.py b/internlm/model/modeling_baichuan2.py new file mode 100644 index 00000000..8674b811 --- /dev/null +++ b/internlm/model/modeling_baichuan2.py @@ -0,0 +1,639 @@ +# Copyright (c) InternLM. All rights reserved. +import math +import os +from typing import Optional + +import torch +from einops import rearrange +from torch import nn +from tqdm import tqdm + +from internlm.accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.initialize.initialize_tensor import ( + normal_, + scaled_init_method_normal, + scaled_init_method_uniform, + uniform_, +) +from internlm.model.base_model import BaseModel +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import MHA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm +from internlm.model.utils import ( + convert_attn_args_to_kwargs, + convert_attn_kwargs_to_args, +) +from internlm.solver.activation_checkpoint import activation_checkpoint +from internlm.utils.logger import get_logger +from internlm.utils.storage_manager import get_fns, llm_load, llm_save +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) + +internlm_accelerator = get_accelerator() +logger = get_logger(__file__) + + +class Baichuan2Decoder(nn.Module): + """ + 1D Packed Flash Llama Layer. + + Args: + hidden_size (int): The hidden size of model. 768 by default. + num_attention_heads (int): The number of attention heads. 12 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + dtype (torch.dtype): Type of data. torch.float by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_idx (int): The index of current layer. 0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. + norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.006 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.0015 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.006 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.0015 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
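+        Note (added for clarity, not in the original docstring): Baichuan2 checkpoints store a fused
+        query/key/value projection named `W_pack`; it maps onto the fused `wqkv` weight of the `MHA`
+        module created here with enable_qkv_fusion=True (see `Baichuan2.load_hf_weights` in this file).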
+ """ + + def __init__( + self, + hidden_size: int = 768, + num_attention_heads: int = 12, + mlp_ratio: int = 4, + attn_drop_rate: float = 0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + layer_norm_epsilon: float = 1e-6, + checkpoint: bool = False, + layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + residual_in_fp32: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm: bool = False, + fused_dropout_add_ln: bool = True, + no_bias: bool = False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + attn_wqkv_init_std: float = 0.006, + attn_other_init_std: float = 0.0015, + ffn_uplayer_init_std: float = 0.006, + ffn_other_init_std: float = 0.0015, + init_type: str = "normal", + rope_base: int = 10000, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + max_position_embeddings: int = 2048, + ): + super().__init__() + self.checkpoint = checkpoint + # dropout selective checkpoint can only be enabled when checkpoint is disabled. + self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False + self.layer_idx = layer_idx + self.prenorm = not apply_post_layer_norm + assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" + self.fused_dropout_add_ln = fused_dropout_add_ln + self.attn_wqkv_init_std = attn_wqkv_init_std + self.attn_other_init_std = attn_other_init_std + self.ffn_uplayer_init_std = ffn_uplayer_init_std + self.ffn_other_init_std = ffn_other_init_std + + head_dim = hidden_size // num_attention_heads + + self.attention = MHA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + max_position_embeddings=max_position_embeddings, + bias=not no_bias, + dropout=attn_drop_rate, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + rope_base=rope_base, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + device=device, + dtype=dtype, + qk_interleaved=qk_interleaved, + enable_qkv_fusion=True, + out_bias=False, + ) + + self.dropout1 = nn.Dropout(drop_rate) + self.dropout2 = nn.Dropout(drop_rate) + self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + + self.feed_forward = new_feed_forward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "gelu", + ) + + self.use_swiglu = use_swiglu + self.use_scaled_init = use_scaled_init + self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm + self.return_residual = False + + if init_type == "normal": + self.init_func = normal_ + self.scaled_init_func = scaled_init_method_normal + else: + self.init_func = uniform_ + self.scaled_init_func = scaled_init_method_uniform + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for name, param in self.attention.named_parameters(): + if param.ndim == 1: + param.data.zero_() + elif "wq" in name or "wk" in name or "wv" in name: + self.init_func(std=self.attn_wqkv_init_std)(param.data) + elif self.use_scaled_init: # wo + self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + 
else: + self.init_func(std=self.attn_other_init_std)(param.data) + + for name, param in self.feed_forward.named_parameters(): + if self.use_swiglu: + if self.use_scaled_init and "w2" in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + # candidate: w1, w3, fused_w1_w3 + self.init_func( + std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std + )(param.data) + else: + if self.use_scaled_init and "fc1" not in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( + param.data + ) + + def forward(self, hidden_states, residual=None, **kwargs): + if self.checkpoint and self.training: + args = convert_attn_kwargs_to_args(kwargs) + return activation_checkpoint(self._forward, False, hidden_states, residual, *args) + else: + return self._forward(hidden_states, residual, **kwargs) + + def _forward(self, hidden_states, residual, *args, **kwargs): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Attn/MLP(LN(residual)) + cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 + indexes: the length of index is same as hidden states, which stand for the current position + """ + if self.prenorm: + + def _dropout_and_norm_attn(_residual, _hidden_states): + _dropped = self.dropout1(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) + hidden_states = self.attention(hidden_states, **mixer_kwargs) + + if not isinstance(self.feed_forward, nn.Identity): + if not self.fused_dropout_add_ln: + + def _dropout_and_norm_ffn(_residual, _hidden_states): + _dropped = self.dropout2(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint( + _dropout_and_norm_ffn, False, residual, hidden_states + ) + else: + residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + hidden_states = self.feed_forward(hidden_states) + + return hidden_states + residual + else: + assert residual is None + + mixer_out = self.attention(hidden_states, **kwargs) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( + dtype=self.attention_norm.weight.dtype + ) + if not isinstance(self.feed_forward, nn.Identity): + mlp_out = self.feed_forward(hidden_states) + if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( 
+ dtype=self.ffn_norm.weight.dtype + ) + return hidden_states + + +class Baichuan2(BaseModel): + """ + 1D Packed Flash Llama. + + Args: + num_layers (int): The number of layer. 12 by default. + hidden_size (int): The size of hidden state. 768 by default. + num_attention_heads (int): The number of attention head. 12 by default. + vocab_size (int): The size of vocabulary. 50304 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. + drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + dtype (torch.dtype): The type of data. torch.float by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 1.0 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + first (bool): Whether input embedding layer or not. False by default. + last (bool): Whether output embedding layer or not. False by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. None by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + embedding_init_std (float): std used to init embedding weight. 0.0052 by default, + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.006 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.0015 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.006 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.0015 by default, + out_head_init_std (float): std used to init output lmhead weight. 0.0052 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
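+        norm_head (bool): Passed through to the output `new_linear` head; presumably enables the
+            normalized lm head ("NormHead") used by Baichuan2. False by default.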
+ """ + + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + mlp_ratio: int = 4, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + checkpoint: float = 1.0, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False, + embed_grad_scale: float = 0.1, + parallel_output: bool = True, + start_layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm=False, + no_bias=False, + residual_in_fp32: bool = False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + is_reward: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + embedding_init_std: float = 0.0052, + attn_wqkv_init_std: float = 0.006, + attn_other_init_std: float = 0.0015, + ffn_uplayer_init_std: float = 0.006, + ffn_other_init_std: float = 0.0015, + out_head_init_std: float = 0.0052, + init_type: str = "normal", + norm_head: bool = False, + rope_base: int = 10000, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + ): + super().__init__() + + checkpoint_layer_num = int(num_layers * checkpoint) + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output + + if first: + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + + for _, param in self.tok_embeddings.named_parameters(): + if init_type == "normal": + normal_(std=embedding_init_std)(param) + else: + uniform_(std=embedding_init_std)(param) + + self.layers = nn.ModuleList( + [ + Baichuan2Decoder( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + max_position_embeddings=max_position_embeddings, + dtype=dtype, + layer_norm_epsilon=layer_norm_epsilon, + checkpoint=lid < checkpoint_layer_num, + layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + residual_in_fp32=residual_in_fp32, + device=device, + apply_post_layer_norm=apply_post_layer_norm, + fused_dropout_add_ln=False, + no_bias=no_bias, + norm_type=norm_type, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + qk_interleaved=qk_interleaved, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + init_type=init_type, + rope_base=rope_base, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + ) + for lid in range(num_layers) + ] + ) + + if last: + if not apply_post_layer_norm: + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + + self.output = new_linear( + name="output", + in_features=hidden_size, + out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, + bias=False, + device=device, + dtype=dtype, + is_reward=is_reward, + weight_scale=embed_grad_scale, + norm_head=norm_head, + ) + + for _, param in self.output.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + + def forward(self, hidden_states=None, input_ids=None, **kwargs): + # attention_mask: compute attention on the places where the value is 1 + if hasattr(self, 
"tok_embeddings") and input_ids is not None: + hidden_states = self.tok_embeddings(input_ids) + if self.embed_grad_scale != 1: + hidden_states = ( + self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() + ) + + for _, block in enumerate(self.layers): + hidden_states = block(hidden_states, residual=None, **kwargs) + + if hasattr(self, "norm"): + hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype)) + if hasattr(self, "output"): + hidden_states = self.output(hidden_states) + + return hidden_states + + @staticmethod + def load_hf_weights(folder: str, model: nn.Module) -> None: + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [ + os.path.join(folder, fn) + for fn in fns + if (fn.endswith(".bin") and fn.startswith("pytorch_model")) + or (fn.endswith(".safetensors") and fn.startswith("model")) + ] + model_fns.sort() + + state_dict = {} + for model_fn in model_fns: + state_dict.update(llm_load(model_fn, map_location="cpu")) + + tp_size = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + wp_size = gpc.get_world_size(ParallelMode.WEIGHT) + wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) + tp_mode = gpc.config.parallel.tensor["mode"] + split_size = wp_size if tp_mode == "isp" else tp_size + local_rank = wp_rank if tp_mode == "isp" else tp_rank + row_dim = 0 if tp_mode == "isp" else 1 + if gpc.config.model.get("embed_split_hidden", True): + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + new_state_dict = {} + + # embedding + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): + new_state_dict["tok_embeddings.weight"] = torch.chunk( + state_dict.pop("model.embed_tokens.weight"), + split_size, + dim=embed_concat_dim, + )[local_rank] + + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + layer_ids = i + + # attn + state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.W_pack.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.out_proj.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # ffn + state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # attn norm + state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.input_layernorm.weight" + ) + # ffn norm + state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.post_attention_layernorm.weight" + ) + + # replace value within decoder layer + for name in list(state_dict.keys()): + if name.startswith(f"layers.{i}"): + new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) + + # output + if gpc.is_last_rank(ParallelMode.PIPELINE): + new_state_dict["output.weight"] = torch.chunk( + 
state_dict.pop("lm_head.weight"), + split_size, + dim=0, + )[local_rank] + new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") + + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + if len(state_dict) > 0: + logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.") + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() + + @staticmethod + def convert_internevo2hf_weights(src: str, tgt: str) -> None: + def permute(qkv, num_heads, num_kv_heads, head_dim, qk_interleaved=False): + if not qk_interleaved: + return qkv + q_per_kv = num_heads // num_kv_heads + qkv = rearrange(qkv.T, "o (g n i) -> o g n i", n=q_per_kv + 2, i=head_dim) + q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2:-1, :], qkv[..., -1:, :] + q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) + k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) + qkv = torch.cat((q, k, v), dim=2) + qkv = rearrange(qkv, "o g n i -> o (g n i)").T + return qkv + + model_config = gpc.config.model + tp_mode = gpc.config.parallel.tensor["mode"] + row_dim = 0 if tp_mode == "isp" else 1 + if model_config["embed_split_hidden"]: + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + # load states + states, num_shards = Baichuan2.load_sharded_states(src) + + # convert state_dict + state_dict = {} + embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None] + for layer_i in tqdm(range(model_config["num_layers"])): + # attn norm, ffn norm + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": states[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + ) + # attn + state_dict[f"model.layers.{layer_i}.self_attn.W_pack.weight"] = permute( + torch.cat([states[i][f"layers.{layer_i}.attention.wqkv.weight"] for i in range(num_shards)], dim=0), + num_heads=model_config["num_attention_heads"], + # num_kv_attention_heads equals to num_attention_heads in MHA + num_kv_heads=model_config["num_attention_heads"], + head_dim=model_config["hidden_size"] // model_config["num_attention_heads"], + qk_interleaved=model_config.get("qk_interleaved", False), + ) + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.out_proj.weight"] for i in range(num_shards)], dim=row_dim + ) + # ffn + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + ) + # embedding, output + for embedding_key in embedding_key_list: + if embedding_key in states[0]: + break + if embedding_key is None: + raise KeyError("Cannot find embedding key!") + state_dict.update( + { + "model.norm.weight": states[0]["norm.weight"], + 
"model.embed_tokens.weight": torch.cat( + [states[i][embedding_key] for i in range(num_shards)], dim=embed_concat_dim + ), + "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), + }, + ) + + # save state_dict to hf format + shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) + for shard_file, shard in shards.items(): + llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) + if index is not None: + llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) diff --git a/internlm/model/modeling_gemma.py b/internlm/model/modeling_gemma.py new file mode 100644 index 00000000..a43843a8 --- /dev/null +++ b/internlm/model/modeling_gemma.py @@ -0,0 +1,752 @@ +# Copyright (c) InternLM. All rights reserved. +import math +import os +from typing import Optional + +import torch +from torch import nn +from tqdm import tqdm + +from internlm.accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.initialize.initialize_tensor import ( + normal_, + scaled_init_method_normal, + scaled_init_method_uniform, + uniform_, +) +from internlm.model.base_model import BaseModel +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import GQA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm +from internlm.model.utils import ( + convert_attn_args_to_kwargs, + convert_attn_kwargs_to_args, +) +from internlm.solver.activation_checkpoint import activation_checkpoint +from internlm.utils.logger import get_logger +from internlm.utils.storage_manager import get_fns, llm_load, llm_save +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) + +try: + from flash_attn.modules.mlp import ParallelFusedMLP +except ImportError: + pass + +internlm_accelerator = get_accelerator() +logger = get_logger(__file__) + + +class GemmaDecoder(nn.Module): + """ + 1D Packed Flash Llama Layer. + + Args: + hidden_size (int): The hidden size of model. 768 by default. + num_attention_heads (int): The number of attention heads. 12 by default. + head_dim (int): The dimention of attention head dimention. hidden_size divided by num_heads by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + dtype (torch.dtype): Type of data. torch.float by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_idx (int): The index of current layer. 0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. + add_unit_offset(bool): Add one to RMSNorm weight multiply by normed input. False by default. + use_glu (bool): Whether to use glu. True by default. + use_swiglu (bool): Whether to use swiglu. True by default. + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 
0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. + tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], + "mtp" by default. + """ + + def __init__( + self, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + head_dim: int = None, + mlp_ratio: int = 4, + attn_drop_rate: float = 0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + layer_norm_epsilon: float = 1e-6, + checkpoint: bool = False, + layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + residual_in_fp32: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm: bool = False, + fused_dropout_add_ln: bool = True, + no_bias: bool = False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + add_unit_offset: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_glu: bool = True, + use_swiglu: bool = True, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + tp_mode: str = "mtp", + ): + super().__init__() + self.checkpoint = checkpoint + # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
+ self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False + self.layer_idx = layer_idx + self.prenorm = not apply_post_layer_norm + assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" + self.fused_dropout_add_ln = fused_dropout_add_ln + self.attn_wqkv_init_std = attn_wqkv_init_std + self.attn_other_init_std = attn_other_init_std + self.ffn_uplayer_init_std = ffn_uplayer_init_std + self.ffn_other_init_std = ffn_other_init_std + + if not head_dim: + head_dim = hidden_size // num_attention_heads + + self.attention = GQA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_kv_attention_heads, + head_dim=head_dim, + dropout=attn_drop_rate, + max_position_embeddings=max_position_embeddings, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + device=device, + dtype=dtype, + qk_interleaved=qk_interleaved, + bias=not no_bias, + rope_base=rope_base, + enable_qkv_fusion=False, + ) + + self.dropout1 = nn.Dropout(drop_rate) + self.dropout2 = nn.Dropout(drop_rate) + self.attention_norm = new_layer_norm( + norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset + ) + self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset) + + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + parallel_mode = ParallelMode.WEIGHT if tp_mode == "isp" else ParallelMode.TENSOR + + if use_glu: + self.feed_forward = new_feed_forward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + activation_type="swiglu" if use_swiglu else "gelu", + ) + else: + self.feed_forward = ParallelFusedMLP( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + activation="gelu_approx", + process_group=gpc.get_group(parallel_mode), + bias1=False, + bias2=False, + sequence_parallel=sequence_parallel, + checkpoint_lvl=0, + heuristic="auto", + device=device, + dtype=dtype, + ) + + self.use_glu = use_glu + self.use_swiglu = use_swiglu + self.use_scaled_init = use_scaled_init + self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm + self.return_residual = False + + if init_type == "normal": + self.init_func = normal_ + self.scaled_init_func = scaled_init_method_normal + else: + self.init_func = uniform_ + self.scaled_init_func = scaled_init_method_uniform + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for name, param in self.attention.named_parameters(): + if param.ndim == 1: + param.data.zero_() + elif "wq" in name or "wk" in name or "wv" in name: + self.init_func(std=self.attn_wqkv_init_std)(param.data) + elif self.use_scaled_init: # wo + self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.attn_other_init_std)(param.data) + + for name, param in self.feed_forward.named_parameters(): + if self.use_glu: + if self.use_scaled_init and "w2" in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func( + std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std + )(param.data) + else: + if self.use_scaled_init and "fc1" not in name: + 
self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( + param.data + ) + + def forward(self, hidden_states, residual=None, **kwargs): + if self.checkpoint and self.training: + args = convert_attn_kwargs_to_args(kwargs) + return activation_checkpoint(self._forward, False, hidden_states, residual, *args) + else: + return self._forward(hidden_states, residual, **kwargs) + + def _forward(self, hidden_states, residual, *args, **kwargs): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Attn/MLP(LN(residual)) + cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 + indexes: the length of index is same as hidden states, which stand for the current position + """ + if self.prenorm: + + def _dropout_and_norm_attn(_residual, _hidden_states): + _dropped = self.dropout1(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) + hidden_states = self.attention(hidden_states, **mixer_kwargs) + + if not isinstance(self.feed_forward, nn.Identity): + if not self.fused_dropout_add_ln: + + def _dropout_and_norm_ffn(_residual, _hidden_states): + _dropped = self.dropout2(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint( + _dropout_and_norm_ffn, False, residual, hidden_states + ) + else: + residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + hidden_states = self.feed_forward(hidden_states) + + return hidden_states + residual + else: + assert residual is None + + mixer_out = self.attention(hidden_states, **kwargs) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( + dtype=self.attention_norm.weight.dtype + ) + if not isinstance(self.feed_forward, nn.Identity): + mlp_out = self.feed_forward(hidden_states) + if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( + dtype=self.ffn_norm.weight.dtype + ) + return hidden_states + + +class Gemma(BaseModel): + """ + 1D Packed Flash Llama. + + Args: + num_layers (int): The number of layer. 12 by default. + hidden_size (int): The size of hidden state. 768 by default. + num_attention_heads (int): The number of attention head. 12 by default. + head_dim (int): The dimention of attention head dimention. hidden_size divided by num_heads by default. + vocab_size (int): The size of vocabulary. 50304 by default. 
+ mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. + drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + dtype (torch.dtype): The type of data. torch.float by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 1.0 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + first (bool): Whether input embedding layer or not. False by default. + last (bool): Whether output embedding layer or not. False by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. None by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + add_unit_offset(bool): Add one to RMSNorm weight multiply by normed input. False by default. + use_glu (bool): Whether to use glu. True by default. + use_swiglu (bool): Whether to use swiglu. True by default. + embedding_init_std (float): std used to init embedding weight. 0.02 by default, + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
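+        num_kv_attention_heads (int): The number of key/value attention heads. 12 by default.
+        max_position_embeddings (int): The maximum position embeddings used by rotary embedding.
+            2048 by default.
+        norm_type (str): Use RMSNorm or LayerNorm. "rmsnorm" by default.
+        no_bias (bool): Whether to disable bias in the attention linears. False by default.
+        norm_head (bool): Whether to use a normalized output head (passed through to `new_linear`).
+            False by default.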
+ """ + + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + head_dim: int = None, + vocab_size: int = 50304, + mlp_ratio: int = 4, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + checkpoint: float = 1.0, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False, + embed_grad_scale: float = 0.1, + parallel_output: bool = True, + start_layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm=False, + no_bias=False, + residual_in_fp32: bool = False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + add_unit_offset: bool = False, + is_reward: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_glu: bool = True, + use_swiglu: bool = False, + embedding_init_std: float = 0.02, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + out_head_init_std: float = 0.02, + init_type: str = "normal", + extra_pred_tokens: int = 0, + rope_base: int = 10000, + norm_head: bool = False, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + ): + super().__init__() + + checkpoint_layer_num = int(num_layers * checkpoint) + self.hidden_size = hidden_size + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output + self.tp_mode = "mtp" + if isinstance(gpc.config.parallel["tensor"], dict): + self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") + + if first: + self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.embed_tokens.named_parameters(): + if init_type == "normal": + normal_(std=embedding_init_std)(param) + else: + uniform_(std=embedding_init_std)(param) + + self.layers = nn.ModuleList( + [ + GemmaDecoder( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_attention_heads=num_kv_attention_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + max_position_embeddings=max_position_embeddings, + dtype=dtype, + layer_norm_epsilon=layer_norm_epsilon, + checkpoint=lid < checkpoint_layer_num, + layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + residual_in_fp32=residual_in_fp32, + device=device, + apply_post_layer_norm=apply_post_layer_norm, + fused_dropout_add_ln=False, + no_bias=no_bias, + norm_type=norm_type, + add_unit_offset=add_unit_offset, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_glu=use_glu, + use_swiglu=use_swiglu, + qk_interleaved=qk_interleaved, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + init_type=init_type, + rope_base=rope_base, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + tp_mode=self.tp_mode, + ) + for lid in range(num_layers) + ] + ) + + if last: + if not apply_post_layer_norm: + self.norm = new_layer_norm( + norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset + ) + + self.output = new_linear( + name="output", + in_features=hidden_size, + out_features=gpc.get_world_size(ParallelMode.TENSOR) if 
is_reward else vocab_size, + bias=False, + device=device, + is_reward=is_reward, + dtype=dtype, + weight_scale=embed_grad_scale, + norm_head=norm_head, + ) + for _, param in self.output.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + + if extra_pred_tokens > 0: + self.extra_pred_tokens = extra_pred_tokens + assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF" + self.extra_outputs = nn.ModuleList( + [ + new_linear( + name="output", + in_features=hidden_size, + out_features=vocab_size, + bias=False, + device=device, + is_reward=is_reward, + dtype=dtype, + weight_scale=embed_grad_scale, + norm_head=norm_head, + ) + for _ in range(self.extra_pred_tokens) + ] + ) + for _, param in self.extra_outputs.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + + def forward(self, hidden_states=None, input_ids=None, **kwargs): + # attention_mask: compute attention on the places where the value is 1 + if hasattr(self, "embed_tokens"): + hidden_states = self.embed_tokens(input_ids) + if self.embed_grad_scale != 1: + hidden_states = ( + self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() + ) + hidden_states = hidden_states * (self.hidden_size**0.5) + + for _, block in enumerate(self.layers): + hidden_states = block(hidden_states, residual=None, **kwargs) + + if hasattr(self, "norm"): + hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype)) + if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0: + extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)] + else: + extra_hidden_states_list = None + if hasattr(self, "output"): + hidden_states = self.output(hidden_states) + + if extra_hidden_states_list is not None: + return (hidden_states, extra_hidden_states_list) + + return hidden_states + + @staticmethod + def load_hf_weights(folder: str, model: nn.Module) -> None: + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [ + os.path.join(folder, fn) + for fn in fns + if (fn.endswith(".bin") and fn.startswith("pytorch_model")) + or (fn.endswith(".safetensors") and fn.startswith("model")) + ] + model_fns.sort() + + state_dict = {} + for model_fn in model_fns: + state_dict.update(llm_load(model_fn, map_location="cpu")) + + tp_size = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + wp_size = gpc.get_world_size(ParallelMode.WEIGHT) + wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) + tp_mode = gpc.config.parallel.tensor["mode"] + split_size = wp_size if tp_mode == "isp" else tp_size + local_rank = wp_rank if tp_mode == "isp" else tp_rank + row_dim = 0 if tp_mode == "isp" else 1 + if gpc.config.model.get("embed_split_hidden", True): + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + new_state_dict = {} + + # embedding + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): + new_state_dict["embed_tokens.weight"] = torch.chunk( + state_dict.get("model.embed_tokens.weight"), + split_size, + dim=embed_concat_dim, + )[local_rank] + + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + layer_ids = i + + # 
attn + state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # ffn + state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # attn norm + state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.input_layernorm.weight" + ) + # ffn norm + state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.post_attention_layernorm.weight" + ) + + # replace value within decoder layer + for name in list(state_dict.keys()): + if name.startswith(f"layers.{i}"): + new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) + + # output + if gpc.is_last_rank(ParallelMode.PIPELINE): + if "lm_head.weight" in state_dict: + new_state_dict["output.weight"] = torch.chunk( + state_dict.pop("lm_head.weight"), # we do not tie lm head with embedding + split_size, + dim=0, + )[local_rank] + state_dict.pop("model.embed_tokens.weight") + else: + new_state_dict["output.weight"] = torch.chunk( + # gemma model ties lm head with embedding in transformers implementation + state_dict.pop("model.embed_tokens.weight"), + split_size, + dim=0, + )[local_rank] + new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") + + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + if len(state_dict) > 0: + logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.") + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() + + @staticmethod + def convert_internevo2hf_weights(src: str, tgt: str) -> None: + model_config = gpc.config.model + tp_mode = gpc.config.parallel.tensor["mode"] + row_dim = 0 if tp_mode == "isp" else 1 + + # load states + states, num_shards = Gemma.load_sharded_states(src) + + # convert state_dict + state_dict = {} + embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None] + for layer_i in tqdm(range(model_config["num_layers"])): + # attn norm, mlp norm + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": states[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + 
f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + ) + # attn wqkv weight and bias + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], + dim=0, + ) + # attn wo weight + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim + ) + + # mlp + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + ) + + # embedding, head + for embedding_key in embedding_key_list: + if embedding_key in states[0]: + break + if embedding_key is None: + raise KeyError("Cannot find embedding key!") + if model_config["embed_split_hidden"]: + embed_concat_dim = 1 + tok_emb_list = [states[i][embedding_key] for i in range(num_shards)] + else: + embed_concat_dim = 0 + _, size_1 = states[0][embedding_key].shape + embdim_pertp = size_1 // num_shards + tok_emb_list = [ + torch.concat( + [ + states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)] + for tp in range(num_shards) + ], + dim=0, + ) + for local_rank in range(num_shards) + ] + state_dict.update( + { + "model.norm.weight": states[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim), + "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), + }, + ) + + # save state_dict to hf format + shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) + for shard_file, shard in shards.items(): + llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) + if index is not None: + # Save the index as well + llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) diff --git a/internlm/model/modeling_qwen2.py b/internlm/model/modeling_qwen2.py new file mode 100644 index 00000000..d3700baa --- /dev/null +++ b/internlm/model/modeling_qwen2.py @@ -0,0 +1,752 @@ +# Copyright (c) InternLM. All rights reserved. 
+import math +import os +from typing import Optional + +import torch +from torch import nn +from tqdm import tqdm + +from internlm.accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.initialize.initialize_tensor import ( + normal_, + scaled_init_method_normal, + scaled_init_method_uniform, + uniform_, +) +from internlm.model.base_model import BaseModel +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import SWA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm +from internlm.model.utils import ( + convert_attn_args_to_kwargs, + convert_attn_kwargs_to_args, +) +from internlm.solver.activation_checkpoint import activation_checkpoint +from internlm.utils.logger import get_logger +from internlm.utils.storage_manager import get_fns, llm_load, llm_save +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) + +internlm_accelerator = get_accelerator() +logger = get_logger(__file__) + + +class Qwen2Decoder(nn.Module): + """ + 1D Packed Flash Qwen Layer. + + Args: + hidden_size (int): The hidden size of model. 768 by default. + num_attention_heads (int): The number of attention heads. 12 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + dtype (torch.dtype): Type of data. torch.float by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_idx (int): The index of current layer. 0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. + norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
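+        qkv_bias (bool): Whether the q/k/v projections have bias. True by default.
+        o_bias (bool): Whether the attention output projection has bias. False by default.
+        mlp_bias (bool): Whether the feed-forward linears have bias. False by default.
+        use_sliding_window (bool): Whether to use sliding window attention. False by default.
+        sliding_window (int): The sliding window size. None by default.
+        rope_scaling_factor (float): The scaling factor for rotary position embeddings. 1.0 by default.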
+ """ + + def __init__( + self, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + mlp_ratio: int = 4, + attn_drop_rate: float = 0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + layer_norm_epsilon: float = 1e-6, + checkpoint: bool = False, + layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + residual_in_fp32: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm: bool = False, + fused_dropout_add_ln: bool = True, + qkv_bias=True, + o_bias=False, + mlp_bias=False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + init_type: str = "normal", + rope_type: str = "normal", + rope_base: int = 10000, + rope_scaling_factor: float = 1.0, + use_sliding_window: bool = False, + sliding_window: int = None, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + scale_attn_weights: bool = False, # Qwen1 + use_logn_attn: bool = False, # Qwen1 + ): + super().__init__() + self.checkpoint = checkpoint + # dropout selective checkpoint can only be enabled when checkpoint is disabled. + self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False + self.layer_idx = layer_idx + self.prenorm = not apply_post_layer_norm + assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" + self.fused_dropout_add_ln = fused_dropout_add_ln + self.attn_wqkv_init_std = attn_wqkv_init_std + self.attn_other_init_std = attn_other_init_std + self.ffn_uplayer_init_std = ffn_uplayer_init_std + self.ffn_other_init_std = ffn_other_init_std + + head_dim = hidden_size // num_attention_heads + + if scale_attn_weights: + softmax_scale = None + else: + softmax_scale = 1 / math.sqrt(head_dim) + self.attention = SWA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_kv_attention_heads, + dropout=attn_drop_rate, + max_position_embeddings=max_position_embeddings, + softmax_scale=softmax_scale, + causal=True, + layer_idx=layer_idx, + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + device=device, + dtype=dtype, + qk_interleaved=qk_interleaved, + qkv_bias=qkv_bias, + o_bias=o_bias, + rope_type=rope_type, + rope_base=rope_base, + rope_scaling_factor=rope_scaling_factor, + use_sliding_window=use_sliding_window, + sliding_window=sliding_window, + use_logn_attn=use_logn_attn, + ) + + self.dropout1 = nn.Dropout(drop_rate) + self.dropout2 = nn.Dropout(drop_rate) + self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + + self.feed_forward = new_feed_forward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + bias=mlp_bias, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + activation_type="swiglu" if use_swiglu else "gelu", + ) + + self.use_swiglu = use_swiglu + self.use_scaled_init = use_scaled_init + self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm + self.return_residual = False + + if init_type == "normal": + self.init_func = normal_ + self.scaled_init_func = scaled_init_method_normal + else: + 
self.init_func = uniform_ + self.scaled_init_func = scaled_init_method_uniform + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for name, param in self.attention.named_parameters(): + if param.ndim == 1: + param.data.zero_() + elif "wq" in name or "wk" in name or "wv" in name: + self.init_func(std=self.attn_wqkv_init_std)(param.data) + elif self.use_scaled_init: # wo + self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.attn_other_init_std)(param.data) + + for name, param in self.feed_forward.named_parameters(): + if self.use_swiglu: + if self.use_scaled_init and "w2" in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + # candidate: w1, w3, fused_w1_w3 + self.init_func( + std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std + )(param.data) + else: + if self.use_scaled_init and "fc1" not in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( + param.data + ) + + def forward(self, hidden_states, residual=None, **kwargs): + if self.checkpoint and self.training: + args = convert_attn_kwargs_to_args(kwargs) + return activation_checkpoint(self._forward, False, hidden_states, residual, *args) + else: + return self._forward(hidden_states, residual, **kwargs) + + def _forward(self, hidden_states, residual, *args, **kwargs): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Attn/MLP(LN(residual)) + cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 + indexes: the length of index is same as hidden states, which stand for the current position + """ + if self.prenorm: + + def _dropout_and_norm_attn(_residual, _hidden_states): + _dropped = self.dropout1(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) + hidden_states = self.attention(hidden_states, **mixer_kwargs) + + if not isinstance(self.feed_forward, nn.Identity): + if not self.fused_dropout_add_ln: + + def _dropout_and_norm_ffn(_residual, _hidden_states): + _dropped = self.dropout2(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint( + _dropout_and_norm_ffn, False, residual, hidden_states + ) + else: + residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + hidden_states = self.feed_forward(hidden_states) + + return hidden_states + residual + else: + assert residual is None + + mixer_out = self.attention(hidden_states, 
**kwargs) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( + dtype=self.attention_norm.weight.dtype + ) + if not isinstance(self.feed_forward, nn.Identity): + mlp_out = self.feed_forward(hidden_states) + if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( + dtype=self.ffn_norm.weight.dtype + ) + return hidden_states + + +class Qwen2(BaseModel): + """ + 1D Packed Flash Qwen. + + Args: + num_layers (int): The number of layer. 12 by default. + hidden_size (int): The size of hidden state. 768 by default. + num_attention_heads (int): The number of attention head. 12 by default. + vocab_size (int): The size of vocabulary. 50304 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. + drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + dtype (torch.dtype): The type of data. torch.float by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + first (bool): Whether input embedding layer or not. False by default. + last (bool): Whether output embedding layer or not. False by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. None by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + embedding_init_std (float): std used to init embedding weight. 0.02 by default, + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
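+        num_kv_attention_heads (int): The number of key/value attention heads. 12 by default.
+        qkv_bias (bool): Whether the q/k/v projections have bias. True by default.
+        o_bias (bool): Whether the attention output projection has bias. False by default.
+        mlp_bias (bool): Whether the feed-forward linears have bias. False by default.
+        use_sliding_window (bool): Whether to use sliding window attention. False by default.
+        max_window_layers (int): Sliding window attention is only enabled for layers whose index is
+            >= max_window_layers. 0 by default.
+        sliding_window (int): The sliding window size. None by default.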
+ """ + + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + vocab_size: int = 50304, + mlp_ratio: int = 4, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + checkpoint: float = 1.0, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False, + embed_grad_scale: float = 0.1, + parallel_output: bool = True, + start_layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm=False, + qkv_bias=True, + o_bias=False, + mlp_bias=False, + residual_in_fp32: bool = False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + is_reward: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + embedding_init_std: float = 0.02, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + out_head_init_std: float = 0.02, + init_type: str = "normal", + extra_pred_tokens: int = 0, + rope_type: str = "normal", + rope_base: int = 10000, + rope_scaling_factor: float = 1.0, + use_sliding_window: bool = False, + max_window_layers: int = 0, + sliding_window: int = None, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + scale_attn_weights: bool = False, # Qwen1 + use_logn_attn: bool = False, # Qwen1 + ): + super().__init__() + + self.embed_grad_scale = embed_grad_scale + + checkpoint_layer_num = int(num_layers * checkpoint) + + if first: + self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.embed_tokens.named_parameters(): + if init_type == "normal": + normal_(std=embedding_init_std)(param) + else: + uniform_(std=embedding_init_std)(param) + + self.layers = nn.ModuleList( + [ + Qwen2Decoder( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_attention_heads=num_kv_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + dtype=dtype, + layer_norm_epsilon=layer_norm_epsilon, + checkpoint=lid < checkpoint_layer_num, + layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + residual_in_fp32=residual_in_fp32, + device=device, + apply_post_layer_norm=apply_post_layer_norm, + fused_dropout_add_ln=False, + qkv_bias=qkv_bias, + o_bias=o_bias, + mlp_bias=mlp_bias, + norm_type=norm_type, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + qk_interleaved=qk_interleaved, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + init_type=init_type, + rope_type=rope_type, + rope_base=rope_base, + rope_scaling_factor=rope_scaling_factor, + use_sliding_window=use_sliding_window and lid >= max_window_layers, + sliding_window=sliding_window, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + max_position_embeddings=max_position_embeddings, + scale_attn_weights=scale_attn_weights, + use_logn_attn=use_logn_attn, + ) + for lid in range(num_layers) + ] + ) + + if last: + if not apply_post_layer_norm: + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + + self.output = new_linear( + 
name="output", + in_features=hidden_size, + out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, + bias=False, + device=device, + dtype=dtype, + is_reward=is_reward, + weight_scale=embed_grad_scale, + ) + + for _, param in self.output.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + + if extra_pred_tokens > 0: + self.extra_pred_tokens = extra_pred_tokens + assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF" + self.extra_outputs = nn.ModuleList( + [ + new_linear( + name="output", + in_features=hidden_size, + out_features=vocab_size, + bias=False, + device=device, + dtype=dtype, + is_reward=is_reward, + weight_scale=embed_grad_scale, + ) + for _ in range(self.extra_pred_tokens) + ] + ) + for _, param in self.extra_outputs.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + + self.parallel_output = parallel_output + + def forward(self, hidden_states=None, input_ids=None, **kwargs): + # attention_mask: compute attention on the places where the value is 1 + if hasattr(self, "embed_tokens"): + hidden_states = self.embed_tokens(input_ids) + if self.embed_grad_scale != 1: + hidden_states = ( + self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() + ) + + for _, block in enumerate(self.layers): + hidden_states = block( + hidden_states, + residual=None, + **kwargs, + ) + + if hasattr(self, "norm"): + hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype)) + if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0: + extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)] + else: + extra_hidden_states_list = None + if hasattr(self, "output"): + hidden_states = self.output(hidden_states) + + if extra_hidden_states_list is not None: + return (hidden_states, extra_hidden_states_list) + + return hidden_states + + @staticmethod + def load_hf_weights(folder: str, model: nn.Module) -> None: + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [ + os.path.join(folder, fn) + for fn in fns + if (fn.endswith(".bin") and fn.startswith("pytorch_model")) + or (fn.endswith(".safetensors") and fn.startswith("model")) + ] + model_fns.sort() + + state_dict = {} + for model_fn in model_fns: + state_dict.update(llm_load(model_fn, map_location="cpu")) + + tp_size = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + wp_size = gpc.get_world_size(ParallelMode.WEIGHT) + wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) + tp_mode = gpc.config.parallel.tensor["mode"] + split_size = wp_size if tp_mode == "isp" else tp_size + local_rank = wp_rank if tp_mode == "isp" else tp_rank + row_dim = 0 if tp_mode == "isp" else 1 + if gpc.config.model.get("embed_split_hidden", True): + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + new_state_dict = {} + + # embedding + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): + new_state_dict["embed_tokens.weight"] = torch.chunk( + state_dict.pop("model.embed_tokens.weight"), + split_size, + dim=embed_concat_dim, + )[local_rank] + + for idx, i in 
enumerate(range(model.first_layer, model.last_layer)): + layer_ids = i + + # attn + state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wq.bias"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.bias"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wk.bias"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.bias"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wv.bias"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.bias"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # ffn + state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # attn norm + state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.input_layernorm.weight" + ) + # ffn norm + state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.post_attention_layernorm.weight" + ) + + # replace value within decoder layer + for name in list(state_dict.keys()): + if name.startswith(f"layers.{i}"): + new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) + + # output + if gpc.is_last_rank(ParallelMode.PIPELINE): + new_state_dict["output.weight"] = torch.chunk( + state_dict.pop("lm_head.weight"), + split_size, + dim=0, + )[local_rank] + new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") + + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + if len(state_dict) > 0: + logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.") + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() + + @staticmethod + def convert_internevo2hf_weights(src: str, tgt: str) -> None: + model_config = gpc.config.model + tp_mode = gpc.config.parallel.tensor["mode"] + row_dim = 0 if tp_mode == "isp" else 1 + + # load states + states, num_shards = Qwen2.load_sharded_states(src) + + # convert state_dict + state_dict = {} + embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None] + for layer_i in 
tqdm(range(model_config["num_layers"])): + # attn norm, mlp norm + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": states[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + ) + # attn wqkv weight and bias + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.bias"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wq.bias"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.bias"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wk.bias"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.bias"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wv.bias"] for i in range(num_shards)], + dim=0, + ) + # attn wo weight + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim + ) + + # mlp + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + ) + + # embedding, head + for embedding_key in embedding_key_list: + if embedding_key in states[0]: + break + if embedding_key is None: + raise KeyError("Cannot find embedding key!") + if model_config["embed_split_hidden"]: + embed_concat_dim = 1 + tok_emb_list = [states[i][embedding_key] for i in range(num_shards)] + else: + embed_concat_dim = 0 + _, size_1 = states[0][embedding_key].shape + embdim_pertp = size_1 // num_shards + tok_emb_list = [ + torch.concat( + [ + states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)] + for tp in range(num_shards) + ], + dim=0, + ) + for local_rank in range(num_shards) + ] + state_dict.update( + { + "model.norm.weight": states[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim), + "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), + }, + ) + + # save state_dict to hf format + shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) + for shard_file, shard in shards.items(): + llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) + if index is not None: + # Save the index as well + llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index cd8eaff2..8370606b 100644 --- a/internlm/model/modules/mha.py 
+++ b/internlm/model/modules/mha.py @@ -1,5 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import inspect import math from typing import Callable, Dict, Optional @@ -75,6 +76,7 @@ def __init__( dtype: Optional[torch.dtype] = None, qk_interleaved: Optional[bool] = True, enable_qkv_fusion: bool = True, + out_bias: bool = True, ) -> None: super().__init__() self.layer_idx = layer_idx @@ -83,6 +85,7 @@ def __init__( self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = self.embed_dim // num_heads + self.kv_dim = self.head_dim * num_heads # num_kv_heads equals to num_heads in MHA self.enable_qkv_fusion = enable_qkv_fusion self.use_dynamic_ntk_rope = use_dynamic_ntk_rope @@ -116,8 +119,8 @@ def __init__( self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - # output projection always have the bias (for now) - self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=True, **factory_kwargs) + # output projection always have the bias (for now) (except for baichuan2 model) + self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, **factory_kwargs) def register_checkpoint_compatibility_hooks( self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None @@ -355,6 +358,7 @@ def __init__( num_heads: int, num_kv_heads: int, max_position_embeddings: int = 2048, + head_dim: int = None, bias: bool = False, dropout: float = 0.0, softmax_scale: float = None, @@ -375,9 +379,15 @@ def __init__( self.embed_dim = embed_dim self.num_heads = num_heads + + if head_dim: + self.head_dim = head_dim + q_dim = head_dim * num_heads + else: + self.head_dim = self.embed_dim // num_heads + q_dim = embed_dim self.num_kv_heads = num_kv_heads self.q_per_kv = num_heads // num_kv_heads - self.head_dim = self.embed_dim // num_heads self.kv_dim = self.head_dim * num_kv_heads self.enable_qkv_fusion = enable_qkv_fusion @@ -405,7 +415,7 @@ def __init__( if enable_qkv_fusion: self.wqkv = new_linear("wqkv", embed_dim, embed_dim + 2 * self.kv_dim, bias, **factory_kwargs) else: - self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs) + self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) @@ -416,7 +426,7 @@ def __init__( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx ) - self.wo = new_linear("wo", embed_dim, embed_dim, bias, **factory_kwargs) + self.wo = new_linear("wo", q_dim, embed_dim, bias, **factory_kwargs) def register_checkpoint_compatibility_hooks( self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None @@ -624,3 +634,337 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 # wo return self.wo(rearrange(context, "b s h d -> b s (h d)")) + + +try: + from flash_attn import flash_attn_func + + # flash_attn >= v2.3.0 + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +except (ModuleNotFoundError, ImportError): + _flash_supports_window_size = False + + +class SWA(nn.Module): + """ + sliding window attention + + Args: + embed_dim (int): The dimention of hidden state. + num_heads (int): The number of attention heads. 
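+        num_kv_heads (int): The number of key/value heads for grouped-query attention.
+        qkv_bias (boolean): Whether the q/k/v projections have bias. True by default.
+        o_bias (boolean): Whether the output projection has bias. False by default.
+        use_sliding_window (bool): Whether to use sliding window attention. False by default.
+        sliding_window (int): The sliding window size; must be set when `use_sliding_window` is True.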
+ process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. + sequence_process_group (torch.distributed.ProcessGroup): The process group for attention calculation. + bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and + output projection. True by default. + dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. + softmax_scale (float): The temperature to use for the softmax attention. + causal (boolean): Whether to apply causal attention mask. False by default. + layer_idx (int): The index of current layer. None by default. + rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. + rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements + XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], + "mtp" by default. + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + qkv_bias: bool = True, + o_bias: bool = False, + max_position_embeddings: int = 2048, + dropout: float = 0.0, + softmax_scale: float = None, + causal: bool = False, + layer_idx: int = None, + use_dynamic_ntk_rope: bool = False, + rope_type: str = "normal", + rope_base: int = 10000, + rope_scaling_factor: float = 1.0, + rotary_emb_dim: int = 0, + rotary_emb_scale_base: int = 0, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + use_sliding_window: bool = False, + sliding_window: int = None, + tp_mode: str = "mtp", + qk_interleaved: Optional[bool] = True, + use_logn_attn: bool = False, # Qwen1 + ) -> None: + assert embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" + assert (not use_sliding_window) or ( + sliding_window is not None + ), "Must set `sliding windows` size when `use_sliding_window` is True." + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.head_dim = self.embed_dim // num_heads + self.num_kv_heads = num_kv_heads + self.kv_dim = self.head_dim * num_kv_heads + self.causal = causal + self.layer_idx = layer_idx + self.use_dynamic_ntk_rope = use_dynamic_ntk_rope + self.rotary_emb_dim = rotary_emb_dim + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.dtype = dtype + self.tp_mode = tp_mode + self.rope_type = rope_type + self.use_logn_attn = use_logn_attn + self.interleaved = qk_interleaved + + factory_kwargs = {"device": device, "dtype": dtype} + + assert self.use_dynamic_ntk_rope is False, "Not support dynamic ntk rope yet." 
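+        # When num_kv_heads < num_heads this behaves as grouped-query attention: wk/wv project to
+        # kv_dim = head_dim * num_kv_heads and each key/value head is shared by a group of query heads.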
+ assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" + + if self.rotary_emb_dim > 0: + self.rotary_emb = new_rotary_embedding( + self.rotary_emb_dim, + base=rope_base, + scale_base=rotary_emb_scale_base, + device=device, + max_position_embeddings=max_position_embeddings, + scaling_factor=rope_scaling_factor, + rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native", + ) + + # notice here should change bias=True + self.wq = new_linear( + "wq", + embed_dim, + embed_dim, + qkv_bias, + **factory_kwargs, + ) + self.wk = new_linear( + "wk", + embed_dim, + self.kv_dim, + qkv_bias, + **factory_kwargs, + ) + self.wv = new_linear( + "wv", + embed_dim, + self.kv_dim, + qkv_bias, + **factory_kwargs, + ) + + self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + + self.inner_cross_attn_causal = causal + self.inner_cross_attn_softmax_scale = softmax_scale + self.inner_cross_attn_dropout = dropout + + self.wo = new_linear( + "wo", + embed_dim, + embed_dim, + o_bias, + **factory_kwargs, + ) + + def forward(self, x, inference_params=None, **kwargs): + if inference_params is None: + return self._training(x=x, **kwargs) + else: + return self._inference(x=x, inference_params=inference_params, **kwargs) + + def _training(self, x, **kwargs): + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b t (h d) -> b t h d", d=self.head_dim) + k = rearrange(k, "b t (h d) -> b t h d", d=self.head_dim) + v = rearrange(v, "b t (h d) -> b t h d", d=self.head_dim) + + kv_seq_len = k.size(0) + use_window_circumstance = ( + _flash_supports_window_size + and self.use_sliding_window + and self.sliding_window + and kv_seq_len > self.sliding_window + ) + + kwargs = _convert_cu_seqlens_for_qksplited(kwargs) + + # rotary embedding + if self.rotary_emb_dim > 0: + indexes = kwargs.pop("indexes", 0) + max_seqlen_q = kwargs.get("max_seqlen_q", None) + max_seqlen_k = kwargs.get("max_seqlen_k", None) + + q = self.rotary_emb( + q, offsets=indexes, max_seqlen=max_seqlen_q, cache_type="query", interleaved=self.interleaved + ) + k = self.rotary_emb( + k, offsets=indexes, max_seqlen=max_seqlen_k, cache_type="key", interleaved=self.interleaved + ) + + kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) + + if use_window_circumstance: + kwargs["window_size"] = (self.sliding_window, 0) + + # self attention + context = self.inner_attn(q, kv, **kwargs) + + # wo + return self.wo(rearrange(context, "b s h d -> b s (h d)")) + + def _convert_unpacked_qkv_to_packed( + self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor + ): + cu_seqlens = torch.concat( + [ + torch.tensor([0], dtype=torch.int32, device=attention_mask.device), + attention_mask.sum(dim=-1).to(dtype=torch.int32), + ], + dim=0, + ).cumsum(dim=0, dtype=torch.int32) + + cu_seqlens_q = cu_seqlens + cu_seqlens_k = cu_seqlens + + max_seqlen_q = attention_mask.shape[-1] + max_seqlen_k = attention_mask.shape[-1] + + q_packed = ( + q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0) + ) + kv_packed = ( + kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)) + .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1]) + .unsqueeze(0) + ) + + return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k + + def _inference(self, x, inference_params=None, 
**kwargs): # pylint: disable=W0613 + assert inference_params is not None, "inference_params is required for inference" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + attention_mask = inference_params.attention_mask + sequence_len_offset = inference_params.sequence_len_offset + window_size = inference_params.window_size + + bsz = x.shape[0] + + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + + kv_seq_len = k.size(0) + use_window_circumstance = ( + _flash_supports_window_size + and self.use_sliding_window + and self.sliding_window + and kv_seq_len > self.sliding_window + ) + + assert self.rotary_emb_dim > 0 + if attention_mask is None: + raise NotImplementedError( + "You should make sure you are aware that you are changing the method of generating." + "According to your generation function instead of inference/seq_generator_module.py, " + "You may implement here for normal running." + ) + else: + if inference_params.sequence_len_offset == 0: + q = self.rotary_emb( + q, offsets=0, cache_type="query", interleaved=self.interleaved, left_padding_mask=attention_mask + ) + k = self.rotary_emb( + k, offsets=0, cache_type="key", interleaved=self.interleaved, left_padding_mask=attention_mask + ) + else: + empties = attention_mask[..., -1].sum(dim=-1) + indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties + indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties + q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved) + k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved) + + kv = torch.stack([k, v], dim=2) + + if window_size is None or window_size > sequence_len_offset: + kv = update_kv_cache(kv, inference_params, self.layer_idx) + else: # window_size <= sequence_len_offset + assert kv.size(1) == 1, "update kv length more than 1" + + inference_params.key_value_memory_dict[self.layer_idx][ + :, inference_params.keep_first : inference_params.window_size - 1, ... + ] = inference_params.key_value_memory_dict[self.layer_idx][ + :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... + ].clone() + inference_params.real_sequence_len_offset = inference_params.sequence_len_offset + inference_params.sequence_len_offset = inference_params.window_size - 1 + + kv = update_kv_cache(kv, inference_params, self.layer_idx) + + inference_params.sequence_len_offset = inference_params.real_sequence_len_offset + + # When using FP16, there is a high probability of NAN in the KV. + # Since NAN cannot be removed by multiplying with and 0, it needs + # to be removed manually here. + kv = torch.where(torch.isnan(kv), 0, kv) + + # attention + if attention_mask is None: + context = self.inner_cross_attn(q, kv) + else: + if sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) + attn_mask = attention_mask[:, None, ...] 
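+                # Merge the padding mask with a strict upper-triangular causal mask (True = masked),
+                # then take the last query row per sample and invert it to get the valid-token mask
+                # used to pack q/kv for the flash-attention path below.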
+                attn_mask = torch.logical_or(
+                    torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask
+                )
+                attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1)
+
+                if use_window_circumstance:
+                    output = self.inner_attn(
+                        *self._convert_unpacked_qkv_to_packed(q, kv, bsz, attn_mask4flsh),
+                        window_size=(self.sliding_window, 0),
+                    )
+                else:
+                    output = self.inner_attn(*self._convert_unpacked_qkv_to_packed(q, kv, bsz, attn_mask4flsh))
+                output = output.to(x.dtype)
+
+                context = torch.zeros_like(q).masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output)
+
+            else:
+                attn_mask = attention_mask[:, -1, :].view(bsz, 1, 1, -1)
+                if window_size is not None and window_size <= sequence_len_offset:
+                    attn_mask = torch.concat(
+                        [
+                            attn_mask[..., : inference_params.keep_first],
+                            attn_mask[..., -(window_size - inference_params.keep_first) :],
+                        ],
+                        dim=-1,
+                    )
+
+                k, v = torch.chunk(kv, 2, dim=2)
+                k = k.squeeze(2)
+                v = v.squeeze(2)
+                sp = k.shape
+                expansion = q.size(2) // k.size(2)
+                scores = torch.einsum(
+                    "blhd,bnhd->bhln",
+                    q,
+                    k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                ) / math.sqrt(q.size(-1))
+                scores = scores.masked_fill(attn_mask, -65000.0)
+                scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
+                context = torch.einsum(
+                    "bhmn,bnhd->bmhd",
+                    scores,
+                    v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                )
+
+        # wo
+        return self.wo(rearrange(context, "b s h d -> b s (h d)"))
diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py
index 897e1363..b836ff3d 100644
--- a/internlm/model/modules/mlp.py
+++ b/internlm/model/modules/mlp.py
@@ -7,8 +7,9 @@
 from torch import nn
 
 from internlm.model.modules.linear import new_linear
-from internlm.model.modules.utils import Silu
+from internlm.model.modules.utils import Gelu, Silu
 from internlm.utils.logger import get_logger
+from internlm.utils.utils import ActivationType
 
 logger = get_logger(__file__)
 
@@ -71,10 +72,13 @@ def __init__(
     ):
         super().__init__()
 
-        # TODO: support gelu...
-        assert activation_type in ("swiglu"), f"Unsupported activation type: {activation_type}"
+        assert activation_type in (
+            ActivationType.swiglu.name,
+            ActivationType.gelu.name,
+        ), f"Unsupported activation type: {activation_type}"
 
         self.mlp_layer_fusion = mlp_layer_fusion
+        self.activation_type = activation_type
 
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
 
@@ -98,7 +102,12 @@ def forward(self, x):
         else:
             fussed_out = self.fused_w1_w3(x)
             w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1)
-        out = self.w2(Silu(w1_o, w3_o))
+
+        if self.activation_type is ActivationType.swiglu.name:
+            out = self.w2(Silu(w1_o, w3_o))
+        else:
+            out = self.w2(Gelu(w1_o, w3_o))
+
         return out
diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py
index b94cdd43..2a9700f8 100644
--- a/internlm/model/modules/norm.py
+++ b/internlm/model/modules/norm.py
@@ -2,6 +2,7 @@
 layer norm modules
 """
 
+import inspect
 from typing import List, Union
 
 import torch
@@ -12,8 +13,12 @@
 Shape = Union[int, List[int], torch.Size]
 
 
-def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5):
+def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False):
     if norm_type == "rmsnorm":
-        return RMSNorm(normalized_shape, eps)
+        rmsnorm_params = inspect.signature(RMSNorm).parameters
+        if "add_unit_offset" in rmsnorm_params:
+            return RMSNorm(normalized_shape, eps, add_unit_offset)
+        else:
+            return RMSNorm(normalized_shape, eps)
     else:  # default: layernorm
         return nn.LayerNorm(normalized_shape, eps)
diff --git a/internlm/model/modules/utils.py b/internlm/model/modules/utils.py
index dd86cb1c..bf1ae048 100644
--- a/internlm/model/modules/utils.py
+++ b/internlm/model/modules/utils.py
@@ -20,7 +20,12 @@ def Silu(w1_o, w2_o):
     return F.silu(w1_o) * w2_o
 
 
+def Gelu(w1_o, w2_o):
+    return F.gelu(w1_o) * w2_o
+
+
 Silu = torch.jit.script(Silu)
+Gelu = torch.jit.script(Gelu)
 
 
 def update_kv_cache(kv, inference_params, layer_idx):
diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py
index 3cd43dab..34e7c007 100644
--- a/internlm/model/ops/norm.py
+++ b/internlm/model/ops/norm.py
@@ -35,7 +35,7 @@
 torchnpu_rmsnorm_impl = False
 
 
-def manual_rms_norm(my_input, weight, normalized_shape, eps):
+def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False):
     # layer norm should always be calculated in float32
     dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
     variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
@@ -48,13 +48,16 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps):
     if weight.dtype in [torch.float16, torch.bfloat16]:
         my_input = my_input.to(weight.dtype)
 
-    return weight * my_input
+    if add_unit_offset:
+        return (1 + weight) * my_input
+    else:
+        return weight * my_input
 
 
 class _RMSNorm(torch.nn.Module):
     """A generic module for RMS normalization."""
 
-    def __init__(self, normalized_shape, eps=1e-5):
+    def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False):
         super().__init__()
 
         if isinstance(normalized_shape, numbers.Integral):
@@ -62,18 +65,22 @@ def __init__(self, normalized_shape, eps=1e-5):
         self.normalized_shape = torch.Size(normalized_shape)
         self.eps = eps
         self.weight = Parameter(torch.empty(*normalized_shape))
+        self.add_unit_offset = add_unit_offset
         self.reset_parameters()
 
     def forward(self, _input: torch.Tensor):
         if apex_rmsnorm_impl:
             _norm_func = mixed_dtype_fused_rms_norm_affine
+            return _norm_func(_input, self.weight, self.normalized_shape, self.eps)
         else:
             _norm_func = manual_rms_norm
-
-        return _norm_func(_input, self.weight, self.normalized_shape, self.eps)
+            return _norm_func(_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset)
 
     def reset_parameters(self):
-        init.ones_(self.weight)
+        if self.add_unit_offset:
+            init.zeros_(self.weight)
+        else:
+            init.ones_(self.weight)
 
     def extra_repr(self):
         return f"{self.normalized_shape}, eps={self.eps}, "
diff --git a/internlm/model/registry.py b/internlm/model/registry.py
index a1921ab6..c923ec20 100644
--- a/internlm/model/registry.py
+++ b/internlm/model/registry.py
@@ -3,11 +3,14 @@
 
 from typing import Callable
 
+from internlm.model.modeling_baichuan2 import Baichuan2
+from internlm.model.modeling_gemma import Gemma
 from internlm.model.modeling_internlm import InternLM1
 from internlm.model.modeling_internlm2 import InternLM2
 from internlm.model.modeling_llama import Llama2
 from internlm.model.modeling_llava import Llava
 from internlm.model.modeling_moe import Internlm1MoE
+from internlm.model.modeling_qwen2 import Qwen2
 from internlm.utils.common import SingletonMeta
 from internlm.utils.utils import ModelType
 
@@ -83,6 +86,9 @@ def register_model_initializer() -> None:
     model_initializer.register_module(ModelType.LLAMA2.name, Llama2)
     model_initializer.register_module(ModelType.INTERNLM_MoE.name, Internlm1MoE)
     model_initializer.register_module(ModelType.LLAVA.name, Llava)
+    model_initializer.register_module(ModelType.QWEN2.name, Qwen2)
+    model_initializer.register_module(ModelType.BAICHUAN2.name, Baichuan2)
+    model_initializer.register_module(ModelType.GEMMA.name, Gemma)
 
 
 register_model_initializer()
diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py
index 03dee6df..ca6b3215 100644
--- a/internlm/utils/utils.py
+++ b/internlm/utils/utils.py
@@ -47,6 +47,9 @@ class ModelType(Enum):
     LLAMA2 = 3
     INTERNLM_MoE = 4
     LLAVA = 5
+    QWEN2 = 6
+    BAICHUAN2 = 7
+    GEMMA = 8
 
 
 class DataType(Enum):
@@ -61,6 +64,11 @@ class TensorParallelMode(Enum):
     isp = 4
 
 
+class ActivationType(Enum):
+    swiglu = 1
+    gelu = 2
+
+
 def check_attention_argument(*args, **kwargs) -> str:
     # self, qkv, ...
     # self, q, kv, ....
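For reference, a minimal sketch (not part of the patch above) of the two behaviours this patch introduces: the `add_unit_offset` variant of RMSNorm, where the weight is zero-initialized and applied as `(1 + weight)` so a fresh layer starts as a pure normalization (the convention used by Gemma-style models), and the GELU-gated MLP activation that mirrors the existing SiLU gating. Function names and the simplified dtype handling below are illustrative only.

import torch
import torch.nn.functional as F


def rms_norm_reference(x, weight, eps=1e-5, add_unit_offset=False):
    # Normalize in float32 over the last dimension (a simplification of manual_rms_norm's dtype handling).
    variance = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    x_normed = (x.to(torch.float32) * torch.rsqrt(variance + eps)).to(x.dtype)
    # With add_unit_offset=True the zero-initialized weight is applied as (1 + weight).
    return (1 + weight) * x_normed if add_unit_offset else weight * x_normed


def gated_gelu_reference(w1_o, w3_o):
    # Mirrors Gelu(w1_o, w2_o) in internlm/model/modules/utils.py: GELU gating instead of SiLU.
    return F.gelu(w1_o) * w3_o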