From d9bb33fa8e23ed38936221c9e44e1d7c72728cac Mon Sep 17 00:00:00 2001
From: sallyjunjun <72725839+sallyjunjun@users.noreply.github.com>
Date: Wed, 11 Sep 2024 14:37:58 +0800
Subject: [PATCH] Add new models (#312)
---
README-ja-JP.md | 6 +-
README-zh-Hans.md | 6 +-
README.md | 6 +-
configs/7B_baichuan2.py | 225 ++++++++
configs/7B_gemma.py | 232 +++++++++
configs/7B_qwen2.py | 232 +++++++++
internlm/data/build_dataloader.py | 6 +-
internlm/model/modeling_baichuan2.py | 639 +++++++++++++++++++++++
internlm/model/modeling_gemma.py | 752 +++++++++++++++++++++++++++
internlm/model/modeling_qwen2.py | 752 +++++++++++++++++++++++++++
internlm/model/modules/mha.py | 354 ++++++++++++-
internlm/model/modules/mlp.py | 17 +-
internlm/model/modules/norm.py | 9 +-
internlm/model/modules/utils.py | 5 +
internlm/model/ops/norm.py | 19 +-
internlm/model/registry.py | 6 +
internlm/utils/utils.py | 8 +
17 files changed, 3252 insertions(+), 22 deletions(-)
create mode 100644 configs/7B_baichuan2.py
create mode 100644 configs/7B_gemma.py
create mode 100644 configs/7B_qwen2.py
create mode 100644 internlm/model/modeling_baichuan2.py
create mode 100644 internlm/model/modeling_gemma.py
create mode 100644 internlm/model/modeling_qwen2.py
diff --git a/README-ja-JP.md b/README-ja-JP.md
index a9327ecc..009f1d27 100644
--- a/README-ja-JP.md
+++ b/README-ja-JP.md
@@ -2,7 +2,7 @@
-
+
[![Documentation Status](https://readthedocs.org/projects/internevo/badge/?version=latest)](https://internevo.readthedocs.io/zh_CN/latest/?badge=latest)
[![license](./doc/imgs/license.svg)](./LICENSE)
@@ -143,6 +143,10 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
diff --git a/README-zh-Hans.md b/README-zh-Hans.md
index ea38b04e..652cd8db 100644
--- a/README-zh-Hans.md
+++ b/README-zh-Hans.md
@@ -2,7 +2,7 @@
-
+
[![使用文档](https://readthedocs.org/projects/internevo/badge/?version=latest)](https://internevo.readthedocs.io/zh_CN/latest/?badge=latest)
[![license](./doc/imgs/license.svg)](./LICENSE)
@@ -143,6 +143,10 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
|
diff --git a/README.md b/README.md
index 9a61d573..f04f6f14 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
-
+
[![Documentation Status](https://readthedocs.org/projects/internevo/badge/?version=latest)](https://internevo.readthedocs.io/zh_CN/latest/?badge=latest)
[![license](./doc/imgs/license.svg)](./LICENSE)
@@ -143,6 +143,10 @@ Please refer to the [System Architecture document](./doc/en/structure.md) for ar
|
diff --git a/configs/7B_baichuan2.py b/configs/7B_baichuan2.py
new file mode 100644
index 00000000..fdc1b0ab
--- /dev/null
+++ b/configs/7B_baichuan2.py
@@ -0,0 +1,225 @@
+JOB_NAME = "7b_baichuan2_train"
+model_type = "BAICHUAN2"
+DO_ALERT = False
+
+VOCAB_SIZE = 125696
+SEQ_LEN = 2048
+HIDDEN_SIZE = 4096
+NUM_ATTENTION_HEAD = 32
+MLP_RATIO = 8 / 3
+NUM_LAYER = 32
+
+
+MODEL_ONLY_FOLDER = "local:llm_ckpts_baichuan2/xxxx"
+# Ckpt folder format:
+# fs: 'local:/mnt/nfs/XXX'
+SAVE_CKPT_FOLDER = "local:llm_ckpts_baichuan2"
+
+# boto3 Ckpt folder format:
+# import os
+# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
+# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
+CHECKPOINT_EVERY = 50
+ckpt = dict(
+ enable_save_ckpt=False, # enable ckpt save.
+ enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format.
+ save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
+ # 'load_ckpt_info' setting guide:
+ # 1. the 'path' indicate ckpt path,
+ # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
+ # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
+ # load function such as "llama"
+ load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
+ # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
+ # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
+ # with an automatic restart mechanism upon training reboot.
+ # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
+ # path specified in `load_ckpt_info` by default.
+ # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
+ # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
+ auto_resume=False,
+ checkpoint_every=CHECKPOINT_EVERY,
+ async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
+ async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
+ oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
+)
+
+TRAIN_FOLDER = None
+VALID_FOLDER = None # "/path/to/dataset"
+data = dict(
+ seq_len=SEQ_LEN,
+ # micro_num means the number of micro_batch contained in one gradient update
+ micro_num=4,
+ # packed_length = micro_bsz * SEQ_LEN
+ micro_bsz=1,
+ # defaults to the value of micro_num
+ valid_micro_num=4,
+ # defaults to 0, means disable evaluate
+ valid_every=0,
+ pack_sample_into_one=False,
+ total_steps=20,
+ skip_batches="",
+ # rampup_batch_size (str): A string with three space-separated integers representing the
+ # starting batch size, the increment, and the number of steps between
+ # each increment. For example, "192 24 8" means that the batch size (micro_num)
+ # starts at 192 and increases by 24 every 8 steps. Defaults to None.
+ # (IMPORTANT): The interval step size is 'micro_bsz'.
+ rampup_batch_size="",
+ # Datasets with less than 50 rows will be discarded
+ min_length=50,
+ train_folder=TRAIN_FOLDER,
+ valid_folder=VALID_FOLDER,
+ empty_cache_and_diag_interval=200,
+ diag_outlier_ratio=1.1,
+)
+
+grad_scaler = dict(
+ fp16=dict(
+ # the initial loss scale, defaults to 2**16
+ initial_scale=2**16,
+ # the minimum loss scale, defaults to None
+ min_scale=1,
+ # the number of steps to increase loss scale when no overflow occurs
+ growth_interval=1000,
+ ),
+ # the multiplication factor for increasing loss scale, defaults to 2
+ growth_factor=2,
+ # the multiplication factor for decreasing loss scale, defaults to 0.5
+ backoff_factor=0.5,
+ # the maximum loss scale, defaults to None
+ max_scale=2**24,
+ # the number of overflows before decreasing loss scale, defaults to 2
+ hysteresis=2,
+)
+
+hybrid_zero_optimizer = dict(
+ # Enable low_level_optimzer overlap_communication
+ overlap_sync_grad=True,
+ overlap_sync_param=False,
+ # bucket size for nccl communication params
+ reduce_bucket_size=512 * 1024 * 1024,
+ # grad clipping
+ clip_grad_norm=1.0,
+)
+
+loss = dict(
+ label_smoothing=0,
+)
+
+adam = dict(
+ lr=1e-4,
+ adam_beta1=0.9,
+ adam_beta2=0.95,
+ adam_beta2_c=0,
+ adam_eps=1e-8,
+ weight_decay=0.01,
+)
+
+lr_scheduler = dict(
+ total_steps=data["total_steps"],
+ init_steps=0, # optimizer_warmup_step
+ warmup_ratio=0.01,
+ eta_min=1e-5,
+ last_epoch=-1,
+)
+
+beta2_scheduler = dict(
+ init_beta2=adam["adam_beta2"],
+ c=adam["adam_beta2_c"],
+ cur_iter=-1,
+)
+
+use_fp32_norm = False
+model = dict(
+ checkpoint=False,
+ num_chunks=1,
+ num_attention_heads=NUM_ATTENTION_HEAD,
+ embed_split_hidden=True,
+ vocab_size=VOCAB_SIZE,
+ embed_grad_scale=1,
+ parallel_output=True,
+ hidden_size=HIDDEN_SIZE,
+ num_layers=NUM_LAYER,
+ no_bias=True,
+ mlp_ratio=MLP_RATIO,
+ apply_post_layer_norm=False,
+ dtype="torch.bfloat16",
+ norm_type="rmsnorm",
+ layer_norm_epsilon=1e-6,
+ use_flash_attn=True,
+ # Whether the odd and even columns of the query and key in the model are normally interleaved.
+ # If it's True, the model's odd and even columns are normally ordered; if it's False,
+ # it means that the model has prematurely concatenated all odd columns and even columns in front
+ # and back, in order to improve the RoPE's computational efficiency.
+ # Example:
+ # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
+ # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
+ qk_interleaved=False,
+)
+
+"""
+zero1 parallel (dict):
+ 1. size: int
+ * if size <= 0, the size of the zero process group is equal to the size of the dp process group,
+ so parameters will be divided within the range of dp.
+ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
+ * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
+ For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
+ 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
+tensor parallel (dict):
+ 1. size: int, the size of tensor parallel.
+ 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
+ defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
+ msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
+ fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
+ isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel.
+pipeline parallel (dict):
+ 1. size: int, the size of pipeline parallel.
+ 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
+ defaults to False.
+weight parallel (dict):
+ 1. size: int, the size of weight parallel.
+ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+ 3. memory_pool: bool, enable/disable memory pool, defaults to False.
+"""
+parallel = dict(
+ zero1=dict(size=-1),
+ tensor=dict(size=1, mode="mtp"),
+ pipeline=dict(size=1, interleaved_overlap=True),
+ weight=dict(size=1, overlap=True, memory_pool=True),
+)
+
+cudnn_deterministic = False
+cudnn_benchmark = False
+
+monitor = dict(
+ # feishu alert configs
+ alert=dict(
+ enable_feishu_alert=DO_ALERT,
+ feishu_alert_address=None, # feishu webhook to send alert message
+ light_monitor_address=None, # light_monitor address to send heartbeat
+ alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
+ ),
+ tensorboard=dict(
+ queue_max_length=10,
+ ),
+)
+
+# metric_dtype can be "fp32" or other string
+# only when set to "fp32" will use fp32 to calc in metrics
+# metric_dtype = "fp32"
+
+generation = dict(
+ ckpt_folder="/path/to/saved/ckpt",
+ output_folder="/path/to/save/generation",
+ batch_size=1,
+ eos_id=[2, 0],
+ bos_id=1,
+ max_length=100,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ top_p=1.0,
+ repetition_penalty=1,
+ length_penalty=1.0,
+)
diff --git a/configs/7B_gemma.py b/configs/7B_gemma.py
new file mode 100644
index 00000000..6ef8c99b
--- /dev/null
+++ b/configs/7B_gemma.py
@@ -0,0 +1,232 @@
+JOB_NAME = "7b_gemma_train"
+model_type = "GEMMA"
+DO_ALERT = False
+
+VOCAB_SIZE = 256000
+SEQ_LEN = 2048
+HIDDEN_SIZE = 3072
+NUM_ATTENTION_HEAD = 16
+NUM_KV_ATTENTION_HEAD = 16
+HEAD_DIM = 256
+MLP_RATIO = 8
+NUM_LAYER = 28
+
+
+MODEL_ONLY_FOLDER = "local:llm_ckpts_gemma/xxxx"
+# Ckpt folder format:
+# fs: 'local:/mnt/nfs/XXX'
+SAVE_CKPT_FOLDER = "local:llm_ckpts_gemma"
+
+# boto3 Ckpt folder format:
+# import os
+# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
+# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
+CHECKPOINT_EVERY = 50
+ckpt = dict(
+ enable_save_ckpt=False, # enable ckpt save.
+ enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format.
+ save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
+ # 'load_ckpt_info' setting guide:
+ # 1. the 'path' indicate ckpt path,
+ # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
+ # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
+ # load function such as "llama"
+ load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
+ # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
+ # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
+ # with an automatic restart mechanism upon training reboot.
+ # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
+ # path specified in `load_ckpt_info` by default.
+ # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
+ # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
+ auto_resume=False,
+ checkpoint_every=CHECKPOINT_EVERY,
+ async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
+ async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
+ oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
+)
+
+TRAIN_FOLDER = None
+VALID_FOLDER = None # "/path/to/dataset"
+data = dict(
+ seq_len=SEQ_LEN,
+ # micro_num means the number of micro_batch contained in one gradient update
+ micro_num=4,
+ # packed_length = micro_bsz * SEQ_LEN
+ micro_bsz=1,
+ # defaults to the value of micro_num
+ valid_micro_num=4,
+ # defaults to 0, means disable evaluate
+ valid_every=0,
+ pack_sample_into_one=False,
+ total_steps=20,
+ skip_batches="",
+ # rampup_batch_size (str): A string with three space-separated integers representing the
+ # starting batch size, the increment, and the number of steps between
+ # each increment. For example, "192 24 8" means that the batch size (micro_num)
+ # starts at 192 and increases by 24 every 8 steps. Defaults to None.
+ # (IMPORTANT): The interval step size is 'micro_bsz'.
+ rampup_batch_size="",
+ # Datasets with less than 50 rows will be discarded
+ min_length=50,
+ train_folder=TRAIN_FOLDER,
+ valid_folder=VALID_FOLDER,
+ empty_cache_and_diag_interval=200,
+ diag_outlier_ratio=1.1,
+)
+
+grad_scaler = dict(
+ fp16=dict(
+ # the initial loss scale, defaults to 2**16
+ initial_scale=2**16,
+ # the minimum loss scale, defaults to None
+ min_scale=1,
+ # the number of steps to increase loss scale when no overflow occurs
+ growth_interval=1000,
+ ),
+ # the multiplication factor for increasing loss scale, defaults to 2
+ growth_factor=2,
+ # the multiplication factor for decreasing loss scale, defaults to 0.5
+ backoff_factor=0.5,
+ # the maximum loss scale, defaults to None
+ max_scale=2**24,
+ # the number of overflows before decreasing loss scale, defaults to 2
+ hysteresis=2,
+)
+
+hybrid_zero_optimizer = dict(
+ # Enable low_level_optimzer overlap_communication
+ overlap_sync_grad=True,
+ overlap_sync_param=False,
+ # bucket size for nccl communication params
+ reduce_bucket_size=512 * 1024 * 1024,
+ # grad clipping
+ clip_grad_norm=1.0,
+)
+
+loss = dict(
+ label_smoothing=0,
+)
+
+adam = dict(
+ lr=1e-4,
+ adam_beta1=0.9,
+ adam_beta2=0.95,
+ adam_beta2_c=0,
+ adam_eps=1e-8,
+ weight_decay=0.01,
+)
+
+lr_scheduler = dict(
+ total_steps=data["total_steps"],
+ init_steps=0, # optimizer_warmup_step
+ warmup_ratio=0.01,
+ eta_min=1e-5,
+ last_epoch=-1,
+)
+
+beta2_scheduler = dict(
+ init_beta2=adam["adam_beta2"],
+ c=adam["adam_beta2_c"],
+ cur_iter=-1,
+)
+
+use_fp32_norm = False
+model = dict(
+ checkpoint=False,
+ num_chunks=1,
+ num_attention_heads=NUM_ATTENTION_HEAD,
+ num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
+ max_position_embeddings=8192,
+ embed_split_hidden=True,
+ vocab_size=VOCAB_SIZE,
+ embed_grad_scale=1,
+ parallel_output=True,
+ hidden_size=HIDDEN_SIZE,
+ num_layers=NUM_LAYER,
+ no_bias=True,
+ mlp_ratio=MLP_RATIO,
+ apply_post_layer_norm=False,
+ dtype="torch.bfloat16",
+ add_unit_offset=True,
+ norm_type="rmsnorm",
+ layer_norm_epsilon=1e-6,
+ head_dim=HEAD_DIM,
+ use_flash_attn=True,
+ # Whether the odd and even columns of the query and key in the model are normally interleaved.
+ # If it's True, the model's odd and even columns are normally ordered; if it's False,
+ # it means that the model has prematurely concatenated all odd columns and even columns in front
+ # and back, in order to improve the RoPE's computational efficiency.
+ # Example:
+ # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
+ # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
+ qk_interleaved=False,
+ use_swiglu=False,
+)
+
+"""
+zero1 parallel (dict):
+ 1. size: int
+ * if size <= 0, the size of the zero process group is equal to the size of the dp process group,
+ so parameters will be divided within the range of dp.
+ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
+ * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
+ For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
+ 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
+tensor parallel (dict):
+ 1. size: int, the size of tensor parallel.
+ 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
+ defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
+ msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
+ fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
+ isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel.
+pipeline parallel (dict):
+ 1. size: int, the size of pipeline parallel.
+ 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
+ defaults to False.
+weight parallel (dict):
+ 1. size: int, the size of weight parallel.
+ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+ 3. memory_pool: bool, enable/disable memory pool, defaults to False.
+"""
+parallel = dict(
+ zero1=dict(size=-1),
+ tensor=dict(size=1, mode="mtp"),
+ pipeline=dict(size=1, interleaved_overlap=True),
+ weight=dict(size=1, overlap=True, memory_pool=True),
+)
+
+cudnn_deterministic = False
+cudnn_benchmark = False
+
+monitor = dict(
+ # feishu alert configs
+ alert=dict(
+ enable_feishu_alert=DO_ALERT,
+ feishu_alert_address=None, # feishu webhook to send alert message
+ light_monitor_address=None, # light_monitor address to send heartbeat
+ alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
+ ),
+ tensorboard=dict(
+ queue_max_length=10,
+ ),
+)
+
+# metric_dtype can be "fp32" or other string
+# only when set to "fp32" will use fp32 to calc in metrics
+# metric_dtype = "fp32"
+
+generation = dict(
+ ckpt_folder="/path/to/saved/ckpt",
+ output_folder="/path/to/save/generation",
+ batch_size=1,
+ eos_id=[2, 0],
+ bos_id=1,
+ max_length=100,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ top_p=1.0,
+ repetition_penalty=1,
+ length_penalty=1.0,
+)
diff --git a/configs/7B_qwen2.py b/configs/7B_qwen2.py
new file mode 100644
index 00000000..07e572d2
--- /dev/null
+++ b/configs/7B_qwen2.py
@@ -0,0 +1,232 @@
+JOB_NAME = "7b_qwen2_train"
+model_type = "QWEN2"
+DO_ALERT = False
+
+VOCAB_SIZE = 152064
+SEQ_LEN = 2048
+HIDDEN_SIZE = 3584
+NUM_ATTENTION_HEAD = 28
+NUM_KV_ATTENTION_HEAD = 4
+MLP_RATIO = 5.25
+NUM_LAYER = 28
+
+
+MODEL_ONLY_FOLDER = "local:llm_ckpts_qwen2/xxxx/"
+# Ckpt folder format:
+# fs: 'local:/mnt/nfs/XXX'
+SAVE_CKPT_FOLDER = "local:llm_ckpts_qwen2"
+
+# boto3 Ckpt folder format:
+# import os
+# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
+# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
+CHECKPOINT_EVERY = 50
+ckpt = dict(
+ enable_save_ckpt=False, # enable ckpt save.
+ enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format.
+ save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
+ # 'load_ckpt_info' setting guide:
+ # 1. the 'path' indicate ckpt path,
+ # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
+ # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
+ # load function such as "llama"
+ load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
+ # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
+ # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
+ # with an automatic restart mechanism upon training reboot.
+ # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
+ # path specified in `load_ckpt_info` by default.
+ # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
+ # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
+ auto_resume=False,
+ checkpoint_every=CHECKPOINT_EVERY,
+ async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
+ async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
+ oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
+)
+
+TRAIN_FOLDER = None
+VALID_FOLDER = None # "/path/to/dataset"
+data = dict(
+ seq_len=SEQ_LEN,
+ # micro_num means the number of micro_batch contained in one gradient update
+ micro_num=4,
+ # packed_length = micro_bsz * SEQ_LEN
+ micro_bsz=1,
+ # defaults to the value of micro_num
+ valid_micro_num=4,
+ # defaults to 0, means disable evaluate
+ valid_every=0,
+ pack_sample_into_one=False,
+ total_steps=20,
+ skip_batches="",
+ # rampup_batch_size (str): A string with three space-separated integers representing the
+ # starting batch size, the increment, and the number of steps between
+ # each increment. For example, "192 24 8" means that the batch size (micro_num)
+ # starts at 192 and increases by 24 every 8 steps. Defaults to None.
+ # (IMPORTANT): The interval step size is 'micro_bsz'.
+ rampup_batch_size="",
+ # Datasets with less than 50 rows will be discarded
+ min_length=50,
+ train_folder=TRAIN_FOLDER,
+ valid_folder=VALID_FOLDER,
+ empty_cache_and_diag_interval=200,
+ diag_outlier_ratio=1.1,
+)
+
+grad_scaler = dict(
+ fp16=dict(
+ # the initial loss scale, defaults to 2**16
+ initial_scale=2**16,
+ # the minimum loss scale, defaults to None
+ min_scale=1,
+ # the number of steps to increase loss scale when no overflow occurs
+ growth_interval=1000,
+ ),
+ # the multiplication factor for increasing loss scale, defaults to 2
+ growth_factor=2,
+ # the multiplication factor for decreasing loss scale, defaults to 0.5
+ backoff_factor=0.5,
+ # the maximum loss scale, defaults to None
+ max_scale=2**24,
+ # the number of overflows before decreasing loss scale, defaults to 2
+ hysteresis=2,
+)
+
+hybrid_zero_optimizer = dict(
+ # Enable low_level_optimzer overlap_communication
+ overlap_sync_grad=True,
+ overlap_sync_param=False,
+ # bucket size for nccl communication params
+ reduce_bucket_size=512 * 1024 * 1024,
+ # grad clipping
+ clip_grad_norm=1.0,
+)
+
+loss = dict(
+ label_smoothing=0,
+)
+
+adam = dict(
+ lr=1e-4,
+ adam_beta1=0.9,
+ adam_beta2=0.95,
+ adam_beta2_c=0,
+ adam_eps=1e-8,
+ weight_decay=0.01,
+)
+
+lr_scheduler = dict(
+ total_steps=data["total_steps"],
+ init_steps=0, # optimizer_warmup_step
+ warmup_ratio=0.01,
+ eta_min=1e-5,
+ last_epoch=-1,
+)
+
+beta2_scheduler = dict(
+ init_beta2=adam["adam_beta2"],
+ c=adam["adam_beta2_c"],
+ cur_iter=-1,
+)
+
+use_fp32_norm = False
+model = dict(
+ checkpoint=False,
+ num_chunks=1,
+ num_attention_heads=NUM_ATTENTION_HEAD,
+ num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
+ embed_split_hidden=True,
+ vocab_size=VOCAB_SIZE,
+ embed_grad_scale=1,
+ parallel_output=True,
+ hidden_size=HIDDEN_SIZE,
+ num_layers=NUM_LAYER,
+ qkv_bias=True,
+ o_bias=False,
+ mlp_ratio=MLP_RATIO,
+ apply_post_layer_norm=False,
+ dtype="torch.bfloat16",
+ norm_type="rmsnorm",
+ layer_norm_epsilon=1e-6,
+ use_flash_attn=True,
+ # Whether the odd and even columns of the query and key in the model are normally interleaved.
+ # If it's True, the model's odd and even columns are normally ordered; if it's False,
+ # it means that the model has prematurely concatenated all odd columns and even columns in front
+ # and back, in order to improve the RoPE's computational efficiency.
+ # Example:
+ # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
+ # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
+ qk_interleaved=False,
+ rope_base=1000000,
+ use_sliding_window=False,
+ sliding_window=32768,
+ max_window_layers=28,
+)
+
+"""
+zero1 parallel (dict):
+ 1. size: int
+ * if size <= 0, the size of the zero process group is equal to the size of the dp process group,
+ so parameters will be divided within the range of dp.
+ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
+ * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
+ For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
+ 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
+tensor parallel (dict):
+ 1. size: int, the size of tensor parallel.
+ 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
+ defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
+ msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
+ fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
+ isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel.
+pipeline parallel (dict):
+ 1. size: int, the size of pipeline parallel.
+ 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
+ defaults to False.
+weight parallel (dict):
+ 1. size: int, the size of weight parallel.
+ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+ 3. memory_pool: bool, enable/disable memory pool, defaults to False.
+"""
+parallel = dict(
+ zero1=dict(size=-1),
+ tensor=dict(size=1, mode="mtp"),
+ pipeline=dict(size=1, interleaved_overlap=True),
+ weight=dict(size=1, overlap=True, memory_pool=True),
+)
+
+cudnn_deterministic = False
+cudnn_benchmark = False
+
+monitor = dict(
+ # feishu alert configs
+ alert=dict(
+ enable_feishu_alert=DO_ALERT,
+ feishu_alert_address=None, # feishu webhook to send alert message
+ light_monitor_address=None, # light_monitor address to send heartbeat
+ alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
+ ),
+ tensorboard=dict(
+ queue_max_length=10,
+ ),
+)
+
+# metric_dtype can be "fp32" or other string
+# only when set to "fp32" will use fp32 to calc in metrics
+# metric_dtype = "fp32"
+
+generation = dict(
+ ckpt_folder="/path/to/saved/ckpt",
+ output_folder="/path/to/save/generation",
+ batch_size=1,
+ eos_id=[2, 0],
+ bos_id=1,
+ max_length=100,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ top_p=1.0,
+ repetition_penalty=1,
+ length_penalty=1.0,
+)
diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py
index e7f581dc..6937b8e4 100644
--- a/internlm/data/build_dataloader.py
+++ b/internlm/data/build_dataloader.py
@@ -44,7 +44,7 @@ def get_tokenized_train_loader_items(data_cfg):
if data_cfg.get("is_multimodal", False):
image_token_size = int(data_cfg.image_size // data_cfg.patch_size) ** 2
train_ds = RandomDatasetMultimodal(
- num_samples=100000,
+ num_samples=gpc.get_world_size(ParallelMode.DATA) * 500,
max_len=data_cfg.seq_len,
image_size=data_cfg.image_size,
image_token_size=image_token_size,
@@ -54,7 +54,9 @@ def get_tokenized_train_loader_items(data_cfg):
)
else:
train_ds = RandomDataset(
- num_samples=1000000, max_len=data_cfg.seq_len, fixed_seqlen=data_cfg.fixed_random_dataset_seqlen
+ num_samples=gpc.get_world_size(ParallelMode.DATA) * 500,
+ max_len=data_cfg.seq_len,
+ fixed_seqlen=data_cfg.fixed_random_dataset_seqlen,
)
if data_cfg.pack_sample_into_one:
diff --git a/internlm/model/modeling_baichuan2.py b/internlm/model/modeling_baichuan2.py
new file mode 100644
index 00000000..8674b811
--- /dev/null
+++ b/internlm/model/modeling_baichuan2.py
@@ -0,0 +1,639 @@
+# Copyright (c) InternLM. All rights reserved.
+import math
+import os
+from typing import Optional
+
+import torch
+from einops import rearrange
+from torch import nn
+from tqdm import tqdm
+
+from internlm.accelerator import get_accelerator
+from internlm.core.context import ParallelMode
+from internlm.core.context.parallel_context import global_context as gpc
+from internlm.initialize.initialize_tensor import (
+ normal_,
+ scaled_init_method_normal,
+ scaled_init_method_uniform,
+ uniform_,
+)
+from internlm.model.base_model import BaseModel
+from internlm.model.modules.embedding import Embedding1D
+from internlm.model.modules.linear import new_linear
+from internlm.model.modules.mha import MHA
+from internlm.model.modules.mlp import new_feed_forward
+from internlm.model.modules.norm import new_layer_norm
+from internlm.model.utils import (
+ convert_attn_args_to_kwargs,
+ convert_attn_kwargs_to_args,
+)
+from internlm.solver.activation_checkpoint import activation_checkpoint
+from internlm.utils.logger import get_logger
+from internlm.utils.storage_manager import get_fns, llm_load, llm_save
+from transformers.modeling_utils import (
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ shard_checkpoint,
+)
+
+internlm_accelerator = get_accelerator()
+logger = get_logger(__file__)
+
+
+class Baichuan2Decoder(nn.Module):
+ """
+ 1D Packed Flash Llama Layer.
+
+ Args:
+ hidden_size (int): The hidden size of model. 768 by default.
+ num_attention_heads (int): The number of attention heads. 12 by default.
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
+ attn_drop_rate (float): The dropout rate of attention module. 0 by default.
+ drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
+ dtype (torch.dtype): Type of data. torch.float by default.
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
+ checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
+ layer_idx (int): The index of current layer. 0 by default.
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
+ device (Optional[Union[str, torch.device]]): The device will be used.
+ norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
+ attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.006 by default,
+ attn_other_init_std (float): std used to init attn_other weight. 0.0015 by default,
+ ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
+ otherwise init fc1 weight in ffn. 0.006 by default,
+ ffn_other_init_std (float): std used to init ffn_other weight. 0.0015 by default,
+ init_type (str): Initialization type. Use uniform or normal. "normal" by default,
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+ multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0,
+ drop_rate: float = 0.0,
+ dtype: torch.dtype = torch.float,
+ layer_norm_epsilon: float = 1e-6,
+ checkpoint: bool = False,
+ layer_idx: int = 0,
+ use_dynamic_ntk_rope: bool = False,
+ residual_in_fp32: bool = False,
+ device: Optional[torch.device] = None,
+ apply_post_layer_norm: bool = False,
+ fused_dropout_add_ln: bool = True,
+ no_bias: bool = False,
+ norm_type: str = "rmsnorm",
+ qk_interleaved: bool = False,
+ dropout_selective_checkpoint: bool = True,
+ use_scaled_init: bool = True,
+ use_swiglu: bool = True,
+ attn_wqkv_init_std: float = 0.006,
+ attn_other_init_std: float = 0.0015,
+ ffn_uplayer_init_std: float = 0.006,
+ ffn_other_init_std: float = 0.0015,
+ init_type: str = "normal",
+ rope_base: int = 10000,
+ mlp_layer_fusion: bool = False,
+ multiple_of: int = 256,
+ max_position_embeddings: int = 2048,
+ ):
+ super().__init__()
+ self.checkpoint = checkpoint
+ # dropout selective checkpoint can only be enabled when checkpoint is disabled.
+ self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
+ self.layer_idx = layer_idx
+ self.prenorm = not apply_post_layer_norm
+ assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here"
+ self.fused_dropout_add_ln = fused_dropout_add_ln
+ self.attn_wqkv_init_std = attn_wqkv_init_std
+ self.attn_other_init_std = attn_other_init_std
+ self.ffn_uplayer_init_std = ffn_uplayer_init_std
+ self.ffn_other_init_std = ffn_other_init_std
+
+ head_dim = hidden_size // num_attention_heads
+
+ self.attention = MHA(
+ embed_dim=hidden_size,
+ num_heads=num_attention_heads,
+ max_position_embeddings=max_position_embeddings,
+ bias=not no_bias,
+ dropout=attn_drop_rate,
+ softmax_scale=1 / math.sqrt(head_dim),
+ causal=True,
+ layer_idx=layer_idx,
+ use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+ rope_base=rope_base,
+ rotary_emb_dim=head_dim,
+ rotary_emb_scale_base=0,
+ device=device,
+ dtype=dtype,
+ qk_interleaved=qk_interleaved,
+ enable_qkv_fusion=True,
+ out_bias=False,
+ )
+
+ self.dropout1 = nn.Dropout(drop_rate)
+ self.dropout2 = nn.Dropout(drop_rate)
+ self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
+ self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
+
+ self.feed_forward = new_feed_forward(
+ hidden_size,
+ int(hidden_size * mlp_ratio),
+ out_features=hidden_size,
+ bias=False,
+ device=device,
+ dtype=dtype,
+ mlp_layer_fusion=mlp_layer_fusion,
+ multiple_of=multiple_of,
+ # TODO: to support more activation functions
+ activation_type="swiglu" if use_swiglu else "gelu",
+ )
+
+ self.use_swiglu = use_swiglu
+ self.use_scaled_init = use_scaled_init
+ self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
+ self.return_residual = False
+
+ if init_type == "normal":
+ self.init_func = normal_
+ self.scaled_init_func = scaled_init_method_normal
+ else:
+ self.init_func = uniform_
+ self.scaled_init_func = scaled_init_method_uniform
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ with torch.no_grad():
+ for name, param in self.attention.named_parameters():
+ if param.ndim == 1:
+ param.data.zero_()
+ elif "wq" in name or "wk" in name or "wv" in name:
+ self.init_func(std=self.attn_wqkv_init_std)(param.data)
+ elif self.use_scaled_init: # wo
+ self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
+ else:
+ self.init_func(std=self.attn_other_init_std)(param.data)
+
+ for name, param in self.feed_forward.named_parameters():
+ if self.use_swiglu:
+ if self.use_scaled_init and "w2" in name:
+ self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
+ else:
+ # candidate: w1, w3, fused_w1_w3
+ self.init_func(
+ std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
+ )(param.data)
+ else:
+ if self.use_scaled_init and "fc1" not in name:
+ self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
+ else:
+ self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)(
+ param.data
+ )
+
+ def forward(self, hidden_states, residual=None, **kwargs):
+ if self.checkpoint and self.training:
+ args = convert_attn_kwargs_to_args(kwargs)
+ return activation_checkpoint(self._forward, False, hidden_states, residual, *args)
+ else:
+ return self._forward(hidden_states, residual, **kwargs)
+
+ def _forward(self, hidden_states, residual, *args, **kwargs):
+ r"""Pass the input through the encoder layer.
+
+ Args:
+ hidden_states: the sequence to the encoder layer (required).
+ residual: hidden_states = Attn/MLP(LN(residual))
+ cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
+ indexes: the length of index is same as hidden states, which stand for the current position
+ """
+ if self.prenorm:
+
+ def _dropout_and_norm_attn(_residual, _hidden_states):
+ _dropped = self.dropout1(_hidden_states)
+ _residual = (_dropped + _residual) if _residual is not None else _dropped
+ _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype))
+
+ return _residual, _hidden_states
+
+ if self.dropout_selective_checkpoint:
+ residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states)
+ else:
+ residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states)
+
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+ mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
+ hidden_states = self.attention(hidden_states, **mixer_kwargs)
+
+ if not isinstance(self.feed_forward, nn.Identity):
+ if not self.fused_dropout_add_ln:
+
+ def _dropout_and_norm_ffn(_residual, _hidden_states):
+ _dropped = self.dropout2(_hidden_states)
+ _residual = (_dropped + _residual) if _residual is not None else _dropped
+ _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))
+
+ return _residual, _hidden_states
+
+ if self.dropout_selective_checkpoint:
+ residual, hidden_states = activation_checkpoint(
+ _dropout_and_norm_ffn, False, residual, hidden_states
+ )
+ else:
+ residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
+
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+ hidden_states = self.feed_forward(hidden_states)
+
+ return hidden_states + residual
+ else:
+ assert residual is None
+
+ mixer_out = self.attention(hidden_states, **kwargs)
+ if self.return_residual: # mixer out is actually a pair here
+ mixer_out, hidden_states = mixer_out
+ hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to(
+ dtype=self.attention_norm.weight.dtype
+ )
+ if not isinstance(self.feed_forward, nn.Identity):
+ mlp_out = self.feed_forward(hidden_states)
+ if self.return_residual: # mlp out is actually a pair here
+ mlp_out, hidden_states = mlp_out
+ hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to(
+ dtype=self.ffn_norm.weight.dtype
+ )
+ return hidden_states
+
+
+class Baichuan2(BaseModel):
+ """
+ 1D Packed Flash Llama.
+
+ Args:
+ num_layers (int): The number of layer. 12 by default.
+ hidden_size (int): The size of hidden state. 768 by default.
+ num_attention_heads (int): The number of attention head. 12 by default.
+ vocab_size (int): The size of vocabulary. 50304 by default.
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
+ attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
+ drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
+ dtype (torch.dtype): The type of data. torch.float by default.
+ checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
+ checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number
+ of layers. 1.0 by default.
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
+ first (bool): Whether input embedding layer or not. False by default.
+ last (bool): Whether output embedding layer or not. False by default.
+ embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
+ parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
+ start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
+ device (Optional[Union[str, torch.device]]): The device will be used. None by default.
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
+ norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
+ qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved.
+ embedding_init_std (float): std used to init embedding weight. 0.0052 by default,
+ attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.006 by default,
+ attn_other_init_std (float): std used to init attn_other weight. 0.0015 by default,
+ ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
+ otherwise init fc1 weight in ffn. 0.006 by default,
+ ffn_other_init_std (float): std used to init ffn_other weight. 0.0015 by default,
+ out_head_init_std (float): std used to init output lmhead weight. 0.0052 by default,
+ init_type (str): Initialization type. Use uniform or normal. "normal" by default,
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+ multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
+ """
+
+ def __init__(
+ self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0.0,
+ drop_rate: float = 0.0,
+ max_position_embeddings: int = 2048,
+ dtype: torch.dtype = torch.float,
+ checkpoint: float = 1.0,
+ layer_norm_epsilon: float = 1e-5,
+ first: bool = False,
+ last: bool = False,
+ embed_grad_scale: float = 0.1,
+ parallel_output: bool = True,
+ start_layer_idx: int = 0,
+ use_dynamic_ntk_rope: bool = False,
+ device: Optional[torch.device] = None,
+ apply_post_layer_norm=False,
+ no_bias=False,
+ residual_in_fp32: bool = False,
+ norm_type: str = "rmsnorm",
+ qk_interleaved: bool = False,
+ is_reward: bool = False,
+ dropout_selective_checkpoint: bool = True,
+ use_scaled_init: bool = True,
+ use_swiglu: bool = True,
+ embedding_init_std: float = 0.0052,
+ attn_wqkv_init_std: float = 0.006,
+ attn_other_init_std: float = 0.0015,
+ ffn_uplayer_init_std: float = 0.006,
+ ffn_other_init_std: float = 0.0015,
+ out_head_init_std: float = 0.0052,
+ init_type: str = "normal",
+ norm_head: bool = False,
+ rope_base: int = 10000,
+ mlp_layer_fusion: bool = False,
+ multiple_of: int = 256,
+ ):
+ super().__init__()
+
+ checkpoint_layer_num = int(num_layers * checkpoint)
+ self.embed_grad_scale = embed_grad_scale
+ self.parallel_output = parallel_output
+
+ if first:
+ self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
+
+ for _, param in self.tok_embeddings.named_parameters():
+ if init_type == "normal":
+ normal_(std=embedding_init_std)(param)
+ else:
+ uniform_(std=embedding_init_std)(param)
+
+ self.layers = nn.ModuleList(
+ [
+ Baichuan2Decoder(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ mlp_ratio=mlp_ratio,
+ attn_drop_rate=attn_drop_rate,
+ drop_rate=drop_rate,
+ max_position_embeddings=max_position_embeddings,
+ dtype=dtype,
+ layer_norm_epsilon=layer_norm_epsilon,
+ checkpoint=lid < checkpoint_layer_num,
+ layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
+ use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+ residual_in_fp32=residual_in_fp32,
+ device=device,
+ apply_post_layer_norm=apply_post_layer_norm,
+ fused_dropout_add_ln=False,
+ no_bias=no_bias,
+ norm_type=norm_type,
+ dropout_selective_checkpoint=dropout_selective_checkpoint,
+ use_scaled_init=use_scaled_init,
+ use_swiglu=use_swiglu,
+ qk_interleaved=qk_interleaved,
+ attn_wqkv_init_std=attn_wqkv_init_std,
+ attn_other_init_std=attn_other_init_std,
+ ffn_uplayer_init_std=ffn_uplayer_init_std,
+ ffn_other_init_std=ffn_other_init_std,
+ init_type=init_type,
+ rope_base=rope_base,
+ mlp_layer_fusion=mlp_layer_fusion,
+ multiple_of=multiple_of,
+ )
+ for lid in range(num_layers)
+ ]
+ )
+
+ if last:
+ if not apply_post_layer_norm:
+ self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
+
+ self.output = new_linear(
+ name="output",
+ in_features=hidden_size,
+ out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
+ bias=False,
+ device=device,
+ dtype=dtype,
+ is_reward=is_reward,
+ weight_scale=embed_grad_scale,
+ norm_head=norm_head,
+ )
+
+ for _, param in self.output.named_parameters():
+ if init_type == "normal":
+ normal_(std=out_head_init_std)(param)
+ else:
+ uniform_(std=out_head_init_std)(param)
+
+ def forward(self, hidden_states=None, input_ids=None, **kwargs):
+ # attention_mask: compute attention on the places where the value is 1
+ if hasattr(self, "tok_embeddings") and input_ids is not None:
+ hidden_states = self.tok_embeddings(input_ids)
+ if self.embed_grad_scale != 1:
+ hidden_states = (
+ self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
+ )
+
+ for _, block in enumerate(self.layers):
+ hidden_states = block(hidden_states, residual=None, **kwargs)
+
+ if hasattr(self, "norm"):
+ hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype))
+ if hasattr(self, "output"):
+ hidden_states = self.output(hidden_states)
+
+ return hidden_states
+
+ @staticmethod
+ def load_hf_weights(folder: str, model: nn.Module) -> None:
+ assert folder is not None, "Please specify the folder of the pretrained model"
+ if gpc.is_rank_for_log():
+ logger.info(f"Loading pretrained model from {folder}")
+
+ fns = get_fns(folder)
+ model_fns = [
+ os.path.join(folder, fn)
+ for fn in fns
+ if (fn.endswith(".bin") and fn.startswith("pytorch_model"))
+ or (fn.endswith(".safetensors") and fn.startswith("model"))
+ ]
+ model_fns.sort()
+
+ state_dict = {}
+ for model_fn in model_fns:
+ state_dict.update(llm_load(model_fn, map_location="cpu"))
+
+ tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+ tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+ wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
+ wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
+ tp_mode = gpc.config.parallel.tensor["mode"]
+ split_size = wp_size if tp_mode == "isp" else tp_size
+ local_rank = wp_rank if tp_mode == "isp" else tp_rank
+ row_dim = 0 if tp_mode == "isp" else 1
+ if gpc.config.model.get("embed_split_hidden", True):
+ embed_concat_dim = 1
+ else:
+ embed_concat_dim = 0
+
+ new_state_dict = {}
+
+ # embedding
+ if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
+ new_state_dict["tok_embeddings.weight"] = torch.chunk(
+ state_dict.pop("model.embed_tokens.weight"),
+ split_size,
+ dim=embed_concat_dim,
+ )[local_rank]
+
+ for idx, i in enumerate(range(model.first_layer, model.last_layer)):
+ layer_ids = i
+
+ # attn
+ state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.W_pack.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.out_proj.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
+ split_size,
+ dim=row_dim,
+ )[local_rank]
+
+ # ffn
+ state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
+ split_size,
+ dim=row_dim,
+ )[local_rank]
+
+ # attn norm
+ state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop(
+ f"model.layers.{layer_ids}.input_layernorm.weight"
+ )
+ # ffn norm
+ state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop(
+ f"model.layers.{layer_ids}.post_attention_layernorm.weight"
+ )
+
+ # replace value within decoder layer
+ for name in list(state_dict.keys()):
+ if name.startswith(f"layers.{i}"):
+ new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)
+
+ # output
+ if gpc.is_last_rank(ParallelMode.PIPELINE):
+ new_state_dict["output.weight"] = torch.chunk(
+ state_dict.pop("lm_head.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")
+
+ missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
+ if len(state_dict) > 0:
+ logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.")
+
+ if gpc.get_local_rank(ParallelMode.DATA) == 0:
+ pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
+ logger.info(
+ f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
+ f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
+ )
+
+ internlm_accelerator.empty_cache()
+
+ @staticmethod
+ def convert_internevo2hf_weights(src: str, tgt: str) -> None:
+ def permute(qkv, num_heads, num_kv_heads, head_dim, qk_interleaved=False):
+ if not qk_interleaved:
+ return qkv
+ q_per_kv = num_heads // num_kv_heads
+ qkv = rearrange(qkv.T, "o (g n i) -> o g n i", n=q_per_kv + 2, i=head_dim)
+ q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2:-1, :], qkv[..., -1:, :]
+ q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
+ k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)
+ qkv = torch.cat((q, k, v), dim=2)
+ qkv = rearrange(qkv, "o g n i -> o (g n i)").T
+ return qkv
+
+ model_config = gpc.config.model
+ tp_mode = gpc.config.parallel.tensor["mode"]
+ row_dim = 0 if tp_mode == "isp" else 1
+ if model_config["embed_split_hidden"]:
+ embed_concat_dim = 1
+ else:
+ embed_concat_dim = 0
+
+ # load states
+ states, num_shards = Baichuan2.load_sharded_states(src)
+
+ # convert state_dict
+ state_dict = {}
+ embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None]
+ for layer_i in tqdm(range(model_config["num_layers"])):
+ # attn norm, ffn norm
+ state_dict.update(
+ {
+ f"model.layers.{layer_i}.input_layernorm.weight": states[0][
+ f"layers.{layer_i}.attention_norm.weight"
+ ].clone(),
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][
+ f"layers.{layer_i}.ffn_norm.weight"
+ ].clone(),
+ }
+ )
+ # attn
+ state_dict[f"model.layers.{layer_i}.self_attn.W_pack.weight"] = permute(
+ torch.cat([states[i][f"layers.{layer_i}.attention.wqkv.weight"] for i in range(num_shards)], dim=0),
+ num_heads=model_config["num_attention_heads"],
+ # num_kv_attention_heads equals to num_attention_heads in MHA
+ num_kv_heads=model_config["num_attention_heads"],
+ head_dim=model_config["hidden_size"] // model_config["num_attention_heads"],
+ qk_interleaved=model_config.get("qk_interleaved", False),
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.out_proj.weight"] for i in range(num_shards)], dim=row_dim
+ )
+ # ffn
+ state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
+ )
+ state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim
+ )
+ state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
+ )
+ # embedding, output
+ for embedding_key in embedding_key_list:
+ if embedding_key in states[0]:
+ break
+ if embedding_key is None:
+ raise KeyError("Cannot find embedding key!")
+ state_dict.update(
+ {
+ "model.norm.weight": states[0]["norm.weight"],
+ "model.embed_tokens.weight": torch.cat(
+ [states[i][embedding_key] for i in range(num_shards)], dim=embed_concat_dim
+ ),
+ "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0),
+ },
+ )
+
+ # save state_dict to hf format
+ shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME)
+ for shard_file, shard in shards.items():
+ llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"})
+ if index is not None:
+ llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index)
diff --git a/internlm/model/modeling_gemma.py b/internlm/model/modeling_gemma.py
new file mode 100644
index 00000000..a43843a8
--- /dev/null
+++ b/internlm/model/modeling_gemma.py
@@ -0,0 +1,752 @@
+# Copyright (c) InternLM. All rights reserved.
+import math
+import os
+from typing import Optional
+
+import torch
+from torch import nn
+from tqdm import tqdm
+
+from internlm.accelerator import get_accelerator
+from internlm.core.context import ParallelMode
+from internlm.core.context.parallel_context import global_context as gpc
+from internlm.initialize.initialize_tensor import (
+ normal_,
+ scaled_init_method_normal,
+ scaled_init_method_uniform,
+ uniform_,
+)
+from internlm.model.base_model import BaseModel
+from internlm.model.modules.embedding import Embedding1D
+from internlm.model.modules.linear import new_linear
+from internlm.model.modules.mha import GQA
+from internlm.model.modules.mlp import new_feed_forward
+from internlm.model.modules.norm import new_layer_norm
+from internlm.model.utils import (
+ convert_attn_args_to_kwargs,
+ convert_attn_kwargs_to_args,
+)
+from internlm.solver.activation_checkpoint import activation_checkpoint
+from internlm.utils.logger import get_logger
+from internlm.utils.storage_manager import get_fns, llm_load, llm_save
+from transformers.modeling_utils import (
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ shard_checkpoint,
+)
+
+try:
+ from flash_attn.modules.mlp import ParallelFusedMLP
+except ImportError:
+ pass
+
+internlm_accelerator = get_accelerator()
+logger = get_logger(__file__)
+
+
+class GemmaDecoder(nn.Module):
+ """
+ 1D Packed Flash Llama Layer.
+
+ Args:
+ hidden_size (int): The hidden size of model. 768 by default.
+ num_attention_heads (int): The number of attention heads. 12 by default.
+ head_dim (int): The dimention of attention head dimention. hidden_size divided by num_heads by default.
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
+ attn_drop_rate (float): The dropout rate of attention module. 0 by default.
+ drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
+ dtype (torch.dtype): Type of data. torch.float by default.
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
+ checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
+ layer_idx (int): The index of current layer. 0 by default.
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
+ device (Optional[Union[str, torch.device]]): The device will be used.
+ add_unit_offset(bool): Add one to RMSNorm weight multiply by normed input. False by default.
+ use_glu (bool): Whether to use glu. True by default.
+ use_swiglu (bool): Whether to use swiglu. True by default.
+ attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
+ attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
+ ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
+ otherwise init fc1 weight in ffn. 0.02 by default,
+ ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
+ init_type (str): Initialization type. Use uniform or normal. "normal" by default,
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+ multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
+ tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"],
+ "mtp" by default.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ num_kv_attention_heads: int = 12,
+ head_dim: int = None,
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0,
+ drop_rate: float = 0.0,
+ max_position_embeddings: int = 2048,
+ dtype: torch.dtype = torch.float,
+ layer_norm_epsilon: float = 1e-6,
+ checkpoint: bool = False,
+ layer_idx: int = 0,
+ use_dynamic_ntk_rope: bool = False,
+ residual_in_fp32: bool = False,
+ device: Optional[torch.device] = None,
+ apply_post_layer_norm: bool = False,
+ fused_dropout_add_ln: bool = True,
+ no_bias: bool = False,
+ norm_type: str = "rmsnorm",
+ qk_interleaved: bool = False,
+ add_unit_offset: bool = False,
+ dropout_selective_checkpoint: bool = True,
+ use_scaled_init: bool = True,
+ use_glu: bool = True,
+ use_swiglu: bool = True,
+ attn_wqkv_init_std: float = 0.02,
+ attn_other_init_std: float = 0.02,
+ ffn_uplayer_init_std: float = 0.02,
+ ffn_other_init_std: float = 0.02,
+ init_type: str = "normal",
+ rope_base: int = 10000,
+ mlp_layer_fusion: bool = False,
+ multiple_of: int = 256,
+ tp_mode: str = "mtp",
+ ):
+ super().__init__()
+ self.checkpoint = checkpoint
+ # dropout selective checkpoint can only be enabled when checkpoint is disabled.
+ self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
+ self.layer_idx = layer_idx
+ self.prenorm = not apply_post_layer_norm
+ assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here"
+ self.fused_dropout_add_ln = fused_dropout_add_ln
+ self.attn_wqkv_init_std = attn_wqkv_init_std
+ self.attn_other_init_std = attn_other_init_std
+ self.ffn_uplayer_init_std = ffn_uplayer_init_std
+ self.ffn_other_init_std = ffn_other_init_std
+
+ if not head_dim:
+ head_dim = hidden_size // num_attention_heads
+
+ self.attention = GQA(
+ embed_dim=hidden_size,
+ num_heads=num_attention_heads,
+ num_kv_heads=num_kv_attention_heads,
+ head_dim=head_dim,
+ dropout=attn_drop_rate,
+ max_position_embeddings=max_position_embeddings,
+ softmax_scale=1 / math.sqrt(head_dim),
+ causal=True,
+ layer_idx=layer_idx,
+ use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+ rotary_emb_dim=head_dim,
+ rotary_emb_scale_base=0,
+ device=device,
+ dtype=dtype,
+ qk_interleaved=qk_interleaved,
+ bias=not no_bias,
+ rope_base=rope_base,
+ enable_qkv_fusion=False,
+ )
+
+ self.dropout1 = nn.Dropout(drop_rate)
+ self.dropout2 = nn.Dropout(drop_rate)
+ self.attention_norm = new_layer_norm(
+ norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset
+ )
+ self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset)
+
+ sequence_parallel = gpc.config.parallel.get("sequence_parallel", False)
+ parallel_mode = ParallelMode.WEIGHT if tp_mode == "isp" else ParallelMode.TENSOR
+
+ if use_glu:
+ self.feed_forward = new_feed_forward(
+ hidden_size,
+ int(hidden_size * mlp_ratio),
+ out_features=hidden_size,
+ bias=False,
+ device=device,
+ dtype=dtype,
+ mlp_layer_fusion=mlp_layer_fusion,
+ multiple_of=multiple_of,
+ activation_type="swiglu" if use_swiglu else "gelu",
+ )
+ else:
+ self.feed_forward = ParallelFusedMLP(
+ hidden_size,
+ int(hidden_size * mlp_ratio),
+ out_features=hidden_size,
+ activation="gelu_approx",
+ process_group=gpc.get_group(parallel_mode),
+ bias1=False,
+ bias2=False,
+ sequence_parallel=sequence_parallel,
+ checkpoint_lvl=0,
+ heuristic="auto",
+ device=device,
+ dtype=dtype,
+ )
+
+ self.use_glu = use_glu
+ self.use_swiglu = use_swiglu
+ self.use_scaled_init = use_scaled_init
+ self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
+ self.return_residual = False
+
+ if init_type == "normal":
+ self.init_func = normal_
+ self.scaled_init_func = scaled_init_method_normal
+ else:
+ self.init_func = uniform_
+ self.scaled_init_func = scaled_init_method_uniform
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ with torch.no_grad():
+ for name, param in self.attention.named_parameters():
+ if param.ndim == 1:
+ param.data.zero_()
+ elif "wq" in name or "wk" in name or "wv" in name:
+ self.init_func(std=self.attn_wqkv_init_std)(param.data)
+ elif self.use_scaled_init: # wo
+ self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
+ else:
+ self.init_func(std=self.attn_other_init_std)(param.data)
+
+ for name, param in self.feed_forward.named_parameters():
+ if self.use_glu:
+ if self.use_scaled_init and "w2" in name:
+ self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
+ else:
+ self.init_func(
+ std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
+ )(param.data)
+ else:
+ if self.use_scaled_init and "fc1" not in name:
+ self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
+ else:
+ self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)(
+ param.data
+ )
+
+ def forward(self, hidden_states, residual=None, **kwargs):
+ if self.checkpoint and self.training:
+ args = convert_attn_kwargs_to_args(kwargs)
+ return activation_checkpoint(self._forward, False, hidden_states, residual, *args)
+ else:
+ return self._forward(hidden_states, residual, **kwargs)
+
+ def _forward(self, hidden_states, residual, *args, **kwargs):
+ r"""Pass the input through the encoder layer.
+
+ Args:
+ hidden_states: the sequence to the encoder layer (required).
+ residual: hidden_states = Attn/MLP(LN(residual))
+ cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
+ indexes: the length of index is same as hidden states, which stand for the current position
+ """
+ if self.prenorm:
+
+ def _dropout_and_norm_attn(_residual, _hidden_states):
+ _dropped = self.dropout1(_hidden_states)
+ _residual = (_dropped + _residual) if _residual is not None else _dropped
+ _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype))
+
+ return _residual, _hidden_states
+
+ if self.dropout_selective_checkpoint:
+ residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states)
+ else:
+ residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states)
+
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+
+ mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
+ hidden_states = self.attention(hidden_states, **mixer_kwargs)
+
+ if not isinstance(self.feed_forward, nn.Identity):
+ if not self.fused_dropout_add_ln:
+
+ def _dropout_and_norm_ffn(_residual, _hidden_states):
+ _dropped = self.dropout2(_hidden_states)
+ _residual = (_dropped + _residual) if _residual is not None else _dropped
+ _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))
+
+ return _residual, _hidden_states
+
+ if self.dropout_selective_checkpoint:
+ residual, hidden_states = activation_checkpoint(
+ _dropout_and_norm_ffn, False, residual, hidden_states
+ )
+ else:
+ residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
+
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+ hidden_states = self.feed_forward(hidden_states)
+
+ return hidden_states + residual
+ else:
+ assert residual is None
+
+ mixer_out = self.attention(hidden_states, **kwargs)
+ if self.return_residual: # mixer out is actually a pair here
+ mixer_out, hidden_states = mixer_out
+ hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to(
+ dtype=self.attention_norm.weight.dtype
+ )
+ if not isinstance(self.feed_forward, nn.Identity):
+ mlp_out = self.feed_forward(hidden_states)
+ if self.return_residual: # mlp out is actually a pair here
+ mlp_out, hidden_states = mlp_out
+ hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to(
+ dtype=self.ffn_norm.weight.dtype
+ )
+ return hidden_states
+
+
+class Gemma(BaseModel):
+ """
+ 1D Packed Flash Llama.
+
+ Args:
+ num_layers (int): The number of layer. 12 by default.
+ hidden_size (int): The size of hidden state. 768 by default.
+ num_attention_heads (int): The number of attention head. 12 by default.
+ head_dim (int): The dimention of attention head dimention. hidden_size divided by num_heads by default.
+ vocab_size (int): The size of vocabulary. 50304 by default.
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
+ attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
+ drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
+ dtype (torch.dtype): The type of data. torch.float by default.
+ checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
+ checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number
+ of layers. 1.0 by default.
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
+ first (bool): Whether input embedding layer or not. False by default.
+ last (bool): Whether output embedding layer or not. False by default.
+ embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
+ parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
+ start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
+ device (Optional[Union[str, torch.device]]): The device will be used. None by default.
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
+ add_unit_offset(bool): Add one to RMSNorm weight multiply by normed input. False by default.
+ use_glu (bool): Whether to use glu. True by default.
+ use_swiglu (bool): Whether to use swiglu. True by default.
+ embedding_init_std (float): std used to init embedding weight. 0.02 by default,
+ attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
+ attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
+ ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
+ otherwise init fc1 weight in ffn. 0.02 by default,
+ ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
+ out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
+ init_type (str): Initialization type. Use uniform or normal. "normal" by default,
+ extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default.
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+ multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
+ """
+
+ def __init__(
+ self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ num_kv_attention_heads: int = 12,
+ head_dim: int = None,
+ vocab_size: int = 50304,
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0.0,
+ drop_rate: float = 0.0,
+ max_position_embeddings: int = 2048,
+ dtype: torch.dtype = torch.float,
+ checkpoint: float = 1.0,
+ layer_norm_epsilon: float = 1e-5,
+ first: bool = False,
+ last: bool = False,
+ embed_grad_scale: float = 0.1,
+ parallel_output: bool = True,
+ start_layer_idx: int = 0,
+ use_dynamic_ntk_rope: bool = False,
+ device: Optional[torch.device] = None,
+ apply_post_layer_norm=False,
+ no_bias=False,
+ residual_in_fp32: bool = False,
+ norm_type: str = "rmsnorm",
+ qk_interleaved: bool = False,
+ add_unit_offset: bool = False,
+ is_reward: bool = False,
+ dropout_selective_checkpoint: bool = True,
+ use_scaled_init: bool = True,
+ use_glu: bool = True,
+ use_swiglu: bool = False,
+ embedding_init_std: float = 0.02,
+ attn_wqkv_init_std: float = 0.02,
+ attn_other_init_std: float = 0.02,
+ ffn_uplayer_init_std: float = 0.02,
+ ffn_other_init_std: float = 0.02,
+ out_head_init_std: float = 0.02,
+ init_type: str = "normal",
+ extra_pred_tokens: int = 0,
+ rope_base: int = 10000,
+ norm_head: bool = False,
+ mlp_layer_fusion: bool = False,
+ multiple_of: int = 256,
+ ):
+ super().__init__()
+
+ checkpoint_layer_num = int(num_layers * checkpoint)
+ self.hidden_size = hidden_size
+ self.embed_grad_scale = embed_grad_scale
+ self.parallel_output = parallel_output
+ self.tp_mode = "mtp"
+ if isinstance(gpc.config.parallel["tensor"], dict):
+ self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp")
+
+ if first:
+ self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
+ for _, param in self.embed_tokens.named_parameters():
+ if init_type == "normal":
+ normal_(std=embedding_init_std)(param)
+ else:
+ uniform_(std=embedding_init_std)(param)
+
+ self.layers = nn.ModuleList(
+ [
+ GemmaDecoder(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_kv_attention_heads=num_kv_attention_heads,
+ head_dim=head_dim,
+ mlp_ratio=mlp_ratio,
+ attn_drop_rate=attn_drop_rate,
+ drop_rate=drop_rate,
+ max_position_embeddings=max_position_embeddings,
+ dtype=dtype,
+ layer_norm_epsilon=layer_norm_epsilon,
+ checkpoint=lid < checkpoint_layer_num,
+ layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
+ use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+ residual_in_fp32=residual_in_fp32,
+ device=device,
+ apply_post_layer_norm=apply_post_layer_norm,
+ fused_dropout_add_ln=False,
+ no_bias=no_bias,
+ norm_type=norm_type,
+ add_unit_offset=add_unit_offset,
+ dropout_selective_checkpoint=dropout_selective_checkpoint,
+ use_scaled_init=use_scaled_init,
+ use_glu=use_glu,
+ use_swiglu=use_swiglu,
+ qk_interleaved=qk_interleaved,
+ attn_wqkv_init_std=attn_wqkv_init_std,
+ attn_other_init_std=attn_other_init_std,
+ ffn_uplayer_init_std=ffn_uplayer_init_std,
+ ffn_other_init_std=ffn_other_init_std,
+ init_type=init_type,
+ rope_base=rope_base,
+ mlp_layer_fusion=mlp_layer_fusion,
+ multiple_of=multiple_of,
+ tp_mode=self.tp_mode,
+ )
+ for lid in range(num_layers)
+ ]
+ )
+
+ if last:
+ if not apply_post_layer_norm:
+ self.norm = new_layer_norm(
+ norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset
+ )
+
+ self.output = new_linear(
+ name="output",
+ in_features=hidden_size,
+ out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
+ bias=False,
+ device=device,
+ is_reward=is_reward,
+ dtype=dtype,
+ weight_scale=embed_grad_scale,
+ norm_head=norm_head,
+ )
+ for _, param in self.output.named_parameters():
+ if init_type == "normal":
+ normal_(std=out_head_init_std)(param)
+ else:
+ uniform_(std=out_head_init_std)(param)
+
+ if extra_pred_tokens > 0:
+ self.extra_pred_tokens = extra_pred_tokens
+ assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF"
+ self.extra_outputs = nn.ModuleList(
+ [
+ new_linear(
+ name="output",
+ in_features=hidden_size,
+ out_features=vocab_size,
+ bias=False,
+ device=device,
+ is_reward=is_reward,
+ dtype=dtype,
+ weight_scale=embed_grad_scale,
+ norm_head=norm_head,
+ )
+ for _ in range(self.extra_pred_tokens)
+ ]
+ )
+ for _, param in self.extra_outputs.named_parameters():
+ if init_type == "normal":
+ normal_(std=out_head_init_std)(param)
+ else:
+ uniform_(std=out_head_init_std)(param)
+
+ def forward(self, hidden_states=None, input_ids=None, **kwargs):
+ # attention_mask: compute attention on the places where the value is 1
+ if hasattr(self, "embed_tokens"):
+ hidden_states = self.embed_tokens(input_ids)
+ if self.embed_grad_scale != 1:
+ hidden_states = (
+ self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
+ )
+ hidden_states = hidden_states * (self.hidden_size**0.5)
+
+ for _, block in enumerate(self.layers):
+ hidden_states = block(hidden_states, residual=None, **kwargs)
+
+ if hasattr(self, "norm"):
+ hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype))
+ if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0:
+ extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)]
+ else:
+ extra_hidden_states_list = None
+ if hasattr(self, "output"):
+ hidden_states = self.output(hidden_states)
+
+ if extra_hidden_states_list is not None:
+ return (hidden_states, extra_hidden_states_list)
+
+ return hidden_states
+
+ @staticmethod
+ def load_hf_weights(folder: str, model: nn.Module) -> None:
+ assert folder is not None, "Please specify the folder of the pretrained model"
+ if gpc.is_rank_for_log():
+ logger.info(f"Loading pretrained model from {folder}")
+
+ fns = get_fns(folder)
+ model_fns = [
+ os.path.join(folder, fn)
+ for fn in fns
+ if (fn.endswith(".bin") and fn.startswith("pytorch_model"))
+ or (fn.endswith(".safetensors") and fn.startswith("model"))
+ ]
+ model_fns.sort()
+
+ state_dict = {}
+ for model_fn in model_fns:
+ state_dict.update(llm_load(model_fn, map_location="cpu"))
+
+ tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+ tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+ wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
+ wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
+ tp_mode = gpc.config.parallel.tensor["mode"]
+ split_size = wp_size if tp_mode == "isp" else tp_size
+ local_rank = wp_rank if tp_mode == "isp" else tp_rank
+ row_dim = 0 if tp_mode == "isp" else 1
+ if gpc.config.model.get("embed_split_hidden", True):
+ embed_concat_dim = 1
+ else:
+ embed_concat_dim = 0
+
+ new_state_dict = {}
+
+ # embedding
+ if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
+ new_state_dict["embed_tokens.weight"] = torch.chunk(
+ state_dict.get("model.embed_tokens.weight"),
+ split_size,
+ dim=embed_concat_dim,
+ )[local_rank]
+
+ for idx, i in enumerate(range(model.first_layer, model.last_layer)):
+ layer_ids = i
+
+ # attn
+ state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
+ split_size,
+ dim=row_dim,
+ )[local_rank]
+
+ # ffn
+ state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
+ split_size,
+ dim=row_dim,
+ )[local_rank]
+
+ # attn norm
+ state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop(
+ f"model.layers.{layer_ids}.input_layernorm.weight"
+ )
+ # ffn norm
+ state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop(
+ f"model.layers.{layer_ids}.post_attention_layernorm.weight"
+ )
+
+ # replace value within decoder layer
+ for name in list(state_dict.keys()):
+ if name.startswith(f"layers.{i}"):
+ new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)
+
+ # output
+ if gpc.is_last_rank(ParallelMode.PIPELINE):
+ if "lm_head.weight" in state_dict:
+ new_state_dict["output.weight"] = torch.chunk(
+ state_dict.pop("lm_head.weight"), # we do not tie lm head with embedding
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict.pop("model.embed_tokens.weight")
+ else:
+ new_state_dict["output.weight"] = torch.chunk(
+ # gemma model ties lm head with embedding in transformers implementation
+ state_dict.pop("model.embed_tokens.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")
+
+ missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
+ if len(state_dict) > 0:
+ logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.")
+
+ if gpc.get_local_rank(ParallelMode.DATA) == 0:
+ pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
+ logger.info(
+ f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
+ f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
+ )
+
+ internlm_accelerator.empty_cache()
+
+ @staticmethod
+ def convert_internevo2hf_weights(src: str, tgt: str) -> None:
+ model_config = gpc.config.model
+ tp_mode = gpc.config.parallel.tensor["mode"]
+ row_dim = 0 if tp_mode == "isp" else 1
+
+ # load states
+ states, num_shards = Gemma.load_sharded_states(src)
+
+ # convert state_dict
+ state_dict = {}
+ embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None]
+ for layer_i in tqdm(range(model_config["num_layers"])):
+ # attn norm, mlp norm
+ state_dict.update(
+ {
+ f"model.layers.{layer_i}.input_layernorm.weight": states[0][
+ f"layers.{layer_i}.attention_norm.weight"
+ ].clone(),
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][
+ f"layers.{layer_i}.ffn_norm.weight"
+ ].clone(),
+ }
+ )
+ # attn wqkv weight and bias
+ state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)],
+ dim=0,
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)],
+ dim=0,
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)],
+ dim=0,
+ )
+ # attn wo weight
+ state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim
+ )
+
+ # mlp
+ state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
+ )
+ state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim
+ )
+ state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
+ )
+
+ # embedding, head
+ for embedding_key in embedding_key_list:
+ if embedding_key in states[0]:
+ break
+ if embedding_key is None:
+ raise KeyError("Cannot find embedding key!")
+ if model_config["embed_split_hidden"]:
+ embed_concat_dim = 1
+ tok_emb_list = [states[i][embedding_key] for i in range(num_shards)]
+ else:
+ embed_concat_dim = 0
+ _, size_1 = states[0][embedding_key].shape
+ embdim_pertp = size_1 // num_shards
+ tok_emb_list = [
+ torch.concat(
+ [
+ states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)]
+ for tp in range(num_shards)
+ ],
+ dim=0,
+ )
+ for local_rank in range(num_shards)
+ ]
+ state_dict.update(
+ {
+ "model.norm.weight": states[0]["norm.weight"],
+ "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim),
+ "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0),
+ },
+ )
+
+ # save state_dict to hf format
+ shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME)
+ for shard_file, shard in shards.items():
+ llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"})
+ if index is not None:
+ # Save the index as well
+ llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index)
diff --git a/internlm/model/modeling_qwen2.py b/internlm/model/modeling_qwen2.py
new file mode 100644
index 00000000..d3700baa
--- /dev/null
+++ b/internlm/model/modeling_qwen2.py
@@ -0,0 +1,752 @@
+# Copyright (c) InternLM. All rights reserved.
+import math
+import os
+from typing import Optional
+
+import torch
+from torch import nn
+from tqdm import tqdm
+
+from internlm.accelerator import get_accelerator
+from internlm.core.context import ParallelMode
+from internlm.core.context.parallel_context import global_context as gpc
+from internlm.initialize.initialize_tensor import (
+ normal_,
+ scaled_init_method_normal,
+ scaled_init_method_uniform,
+ uniform_,
+)
+from internlm.model.base_model import BaseModel
+from internlm.model.modules.embedding import Embedding1D
+from internlm.model.modules.linear import new_linear
+from internlm.model.modules.mha import SWA
+from internlm.model.modules.mlp import new_feed_forward
+from internlm.model.modules.norm import new_layer_norm
+from internlm.model.utils import (
+ convert_attn_args_to_kwargs,
+ convert_attn_kwargs_to_args,
+)
+from internlm.solver.activation_checkpoint import activation_checkpoint
+from internlm.utils.logger import get_logger
+from internlm.utils.storage_manager import get_fns, llm_load, llm_save
+from transformers.modeling_utils import (
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ shard_checkpoint,
+)
+
+internlm_accelerator = get_accelerator()
+logger = get_logger(__file__)
+
+
+class Qwen2Decoder(nn.Module):
+ """
+ 1D Packed Flash Qwen Layer.
+
+ Args:
+ hidden_size (int): The hidden size of model. 768 by default.
+ num_attention_heads (int): The number of attention heads. 12 by default.
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
+ attn_drop_rate (float): The dropout rate of attention module. 0 by default.
+ drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
+ dtype (torch.dtype): Type of data. torch.float by default.
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
+ checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
+ layer_idx (int): The index of current layer. 0 by default.
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
+ device (Optional[Union[str, torch.device]]): The device will be used.
+ norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
+ attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
+ attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
+ ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
+ otherwise init fc1 weight in ffn. 0.02 by default,
+ ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
+ init_type (str): Initialization type. Use uniform or normal. "normal" by default,
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+ multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ num_kv_attention_heads: int = 12,
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0,
+ drop_rate: float = 0.0,
+ max_position_embeddings: int = 2048,
+ dtype: torch.dtype = torch.float,
+ layer_norm_epsilon: float = 1e-6,
+ checkpoint: bool = False,
+ layer_idx: int = 0,
+ use_dynamic_ntk_rope: bool = False,
+ residual_in_fp32: bool = False,
+ device: Optional[torch.device] = None,
+ apply_post_layer_norm: bool = False,
+ fused_dropout_add_ln: bool = True,
+ qkv_bias=True,
+ o_bias=False,
+ mlp_bias=False,
+ norm_type: str = "rmsnorm",
+ qk_interleaved: bool = False,
+ dropout_selective_checkpoint: bool = True,
+ use_scaled_init: bool = True,
+ use_swiglu: bool = True,
+ attn_wqkv_init_std: float = 0.02,
+ attn_other_init_std: float = 0.02,
+ ffn_uplayer_init_std: float = 0.02,
+ ffn_other_init_std: float = 0.02,
+ init_type: str = "normal",
+ rope_type: str = "normal",
+ rope_base: int = 10000,
+ rope_scaling_factor: float = 1.0,
+ use_sliding_window: bool = False,
+ sliding_window: int = None,
+ mlp_layer_fusion: bool = False,
+ multiple_of: int = 256,
+ scale_attn_weights: bool = False, # Qwen1
+ use_logn_attn: bool = False, # Qwen1
+ ):
+ super().__init__()
+ self.checkpoint = checkpoint
+ # dropout selective checkpoint can only be enabled when checkpoint is disabled.
+ self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
+ self.layer_idx = layer_idx
+ self.prenorm = not apply_post_layer_norm
+ assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here"
+ self.fused_dropout_add_ln = fused_dropout_add_ln
+ self.attn_wqkv_init_std = attn_wqkv_init_std
+ self.attn_other_init_std = attn_other_init_std
+ self.ffn_uplayer_init_std = ffn_uplayer_init_std
+ self.ffn_other_init_std = ffn_other_init_std
+
+ head_dim = hidden_size // num_attention_heads
+
+ if scale_attn_weights:
+ softmax_scale = None
+ else:
+ softmax_scale = 1 / math.sqrt(head_dim)
+ self.attention = SWA(
+ embed_dim=hidden_size,
+ num_heads=num_attention_heads,
+ num_kv_heads=num_kv_attention_heads,
+ dropout=attn_drop_rate,
+ max_position_embeddings=max_position_embeddings,
+ softmax_scale=softmax_scale,
+ causal=True,
+ layer_idx=layer_idx,
+ use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+ rotary_emb_dim=head_dim,
+ rotary_emb_scale_base=0,
+ device=device,
+ dtype=dtype,
+ qk_interleaved=qk_interleaved,
+ qkv_bias=qkv_bias,
+ o_bias=o_bias,
+ rope_type=rope_type,
+ rope_base=rope_base,
+ rope_scaling_factor=rope_scaling_factor,
+ use_sliding_window=use_sliding_window,
+ sliding_window=sliding_window,
+ use_logn_attn=use_logn_attn,
+ )
+
+ self.dropout1 = nn.Dropout(drop_rate)
+ self.dropout2 = nn.Dropout(drop_rate)
+ self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
+ self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
+
+ self.feed_forward = new_feed_forward(
+ hidden_size,
+ int(hidden_size * mlp_ratio),
+ out_features=hidden_size,
+ bias=mlp_bias,
+ device=device,
+ dtype=dtype,
+ mlp_layer_fusion=mlp_layer_fusion,
+ multiple_of=multiple_of,
+ activation_type="swiglu" if use_swiglu else "gelu",
+ )
+
+ self.use_swiglu = use_swiglu
+ self.use_scaled_init = use_scaled_init
+ self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
+ self.return_residual = False
+
+ if init_type == "normal":
+ self.init_func = normal_
+ self.scaled_init_func = scaled_init_method_normal
+ else:
+ self.init_func = uniform_
+ self.scaled_init_func = scaled_init_method_uniform
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ with torch.no_grad():
+ for name, param in self.attention.named_parameters():
+ if param.ndim == 1:
+ param.data.zero_()
+ elif "wq" in name or "wk" in name or "wv" in name:
+ self.init_func(std=self.attn_wqkv_init_std)(param.data)
+ elif self.use_scaled_init: # wo
+ self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
+ else:
+ self.init_func(std=self.attn_other_init_std)(param.data)
+
+ for name, param in self.feed_forward.named_parameters():
+ if self.use_swiglu:
+ if self.use_scaled_init and "w2" in name:
+ self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
+ else:
+ # candidate: w1, w3, fused_w1_w3
+ self.init_func(
+ std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
+ )(param.data)
+ else:
+ if self.use_scaled_init and "fc1" not in name:
+ self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
+ else:
+ self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)(
+ param.data
+ )
+
+ def forward(self, hidden_states, residual=None, **kwargs):
+ if self.checkpoint and self.training:
+ args = convert_attn_kwargs_to_args(kwargs)
+ return activation_checkpoint(self._forward, False, hidden_states, residual, *args)
+ else:
+ return self._forward(hidden_states, residual, **kwargs)
+
+ def _forward(self, hidden_states, residual, *args, **kwargs):
+ r"""Pass the input through the encoder layer.
+
+ Args:
+ hidden_states: the sequence to the encoder layer (required).
+ residual: hidden_states = Attn/MLP(LN(residual))
+ cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
+ indexes: the length of index is same as hidden states, which stand for the current position
+ """
+ if self.prenorm:
+
+ def _dropout_and_norm_attn(_residual, _hidden_states):
+ _dropped = self.dropout1(_hidden_states)
+ _residual = (_dropped + _residual) if _residual is not None else _dropped
+ _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype))
+
+ return _residual, _hidden_states
+
+ if self.dropout_selective_checkpoint:
+ residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states)
+ else:
+ residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states)
+
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+
+ mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
+ hidden_states = self.attention(hidden_states, **mixer_kwargs)
+
+ if not isinstance(self.feed_forward, nn.Identity):
+ if not self.fused_dropout_add_ln:
+
+ def _dropout_and_norm_ffn(_residual, _hidden_states):
+ _dropped = self.dropout2(_hidden_states)
+ _residual = (_dropped + _residual) if _residual is not None else _dropped
+ _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))
+
+ return _residual, _hidden_states
+
+ if self.dropout_selective_checkpoint:
+ residual, hidden_states = activation_checkpoint(
+ _dropout_and_norm_ffn, False, residual, hidden_states
+ )
+ else:
+ residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
+
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+ hidden_states = self.feed_forward(hidden_states)
+
+ return hidden_states + residual
+ else:
+ assert residual is None
+
+ mixer_out = self.attention(hidden_states, **kwargs)
+ if self.return_residual: # mixer out is actually a pair here
+ mixer_out, hidden_states = mixer_out
+ hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to(
+ dtype=self.attention_norm.weight.dtype
+ )
+ if not isinstance(self.feed_forward, nn.Identity):
+ mlp_out = self.feed_forward(hidden_states)
+ if self.return_residual: # mlp out is actually a pair here
+ mlp_out, hidden_states = mlp_out
+ hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to(
+ dtype=self.ffn_norm.weight.dtype
+ )
+ return hidden_states
+
+
+class Qwen2(BaseModel):
+ """
+ 1D Packed Flash Qwen.
+
+ Args:
+ num_layers (int): The number of layer. 12 by default.
+ hidden_size (int): The size of hidden state. 768 by default.
+ num_attention_heads (int): The number of attention head. 12 by default.
+ vocab_size (int): The size of vocabulary. 50304 by default.
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
+ attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
+ drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
+ dtype (torch.dtype): The type of data. torch.float by default.
+ checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
+ first (bool): Whether input embedding layer or not. False by default.
+ last (bool): Whether output embedding layer or not. False by default.
+ embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
+ parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
+ start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
+ device (Optional[Union[str, torch.device]]): The device will be used. None by default.
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
+ norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
+ embedding_init_std (float): std used to init embedding weight. 0.02 by default,
+ attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
+ attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
+ ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
+ otherwise init fc1 weight in ffn. 0.02 by default,
+ ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
+ out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
+ init_type (str): Initialization type. Use uniform or normal. "normal" by default,
+ extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default.
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+ multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
+ """
+
+ def __init__(
+ self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ num_kv_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0.0,
+ drop_rate: float = 0.0,
+ max_position_embeddings: int = 2048,
+ dtype: torch.dtype = torch.float,
+ checkpoint: float = 1.0,
+ layer_norm_epsilon: float = 1e-5,
+ first: bool = False,
+ last: bool = False,
+ embed_grad_scale: float = 0.1,
+ parallel_output: bool = True,
+ start_layer_idx: int = 0,
+ use_dynamic_ntk_rope: bool = False,
+ device: Optional[torch.device] = None,
+ apply_post_layer_norm=False,
+ qkv_bias=True,
+ o_bias=False,
+ mlp_bias=False,
+ residual_in_fp32: bool = False,
+ norm_type: str = "rmsnorm",
+ qk_interleaved: bool = False,
+ is_reward: bool = False,
+ dropout_selective_checkpoint: bool = True,
+ use_scaled_init: bool = True,
+ use_swiglu: bool = True,
+ embedding_init_std: float = 0.02,
+ attn_wqkv_init_std: float = 0.02,
+ attn_other_init_std: float = 0.02,
+ ffn_uplayer_init_std: float = 0.02,
+ ffn_other_init_std: float = 0.02,
+ out_head_init_std: float = 0.02,
+ init_type: str = "normal",
+ extra_pred_tokens: int = 0,
+ rope_type: str = "normal",
+ rope_base: int = 10000,
+ rope_scaling_factor: float = 1.0,
+ use_sliding_window: bool = False,
+ max_window_layers: int = 0,
+ sliding_window: int = None,
+ mlp_layer_fusion: bool = False,
+ multiple_of: int = 256,
+ scale_attn_weights: bool = False, # Qwen1
+ use_logn_attn: bool = False, # Qwen1
+ ):
+ super().__init__()
+
+ self.embed_grad_scale = embed_grad_scale
+
+ checkpoint_layer_num = int(num_layers * checkpoint)
+
+ if first:
+ self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
+ for _, param in self.embed_tokens.named_parameters():
+ if init_type == "normal":
+ normal_(std=embedding_init_std)(param)
+ else:
+ uniform_(std=embedding_init_std)(param)
+
+ self.layers = nn.ModuleList(
+ [
+ Qwen2Decoder(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_kv_attention_heads=num_kv_attention_heads,
+ mlp_ratio=mlp_ratio,
+ attn_drop_rate=attn_drop_rate,
+ drop_rate=drop_rate,
+ dtype=dtype,
+ layer_norm_epsilon=layer_norm_epsilon,
+ checkpoint=lid < checkpoint_layer_num,
+ layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
+ use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+ residual_in_fp32=residual_in_fp32,
+ device=device,
+ apply_post_layer_norm=apply_post_layer_norm,
+ fused_dropout_add_ln=False,
+ qkv_bias=qkv_bias,
+ o_bias=o_bias,
+ mlp_bias=mlp_bias,
+ norm_type=norm_type,
+ dropout_selective_checkpoint=dropout_selective_checkpoint,
+ use_scaled_init=use_scaled_init,
+ use_swiglu=use_swiglu,
+ qk_interleaved=qk_interleaved,
+ attn_wqkv_init_std=attn_wqkv_init_std,
+ attn_other_init_std=attn_other_init_std,
+ ffn_uplayer_init_std=ffn_uplayer_init_std,
+ ffn_other_init_std=ffn_other_init_std,
+ init_type=init_type,
+ rope_type=rope_type,
+ rope_base=rope_base,
+ rope_scaling_factor=rope_scaling_factor,
+ use_sliding_window=use_sliding_window and lid >= max_window_layers,
+ sliding_window=sliding_window,
+ mlp_layer_fusion=mlp_layer_fusion,
+ multiple_of=multiple_of,
+ max_position_embeddings=max_position_embeddings,
+ scale_attn_weights=scale_attn_weights,
+ use_logn_attn=use_logn_attn,
+ )
+ for lid in range(num_layers)
+ ]
+ )
+
+ if last:
+ if not apply_post_layer_norm:
+ self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
+
+ self.output = new_linear(
+ name="output",
+ in_features=hidden_size,
+ out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
+ bias=False,
+ device=device,
+ dtype=dtype,
+ is_reward=is_reward,
+ weight_scale=embed_grad_scale,
+ )
+
+ for _, param in self.output.named_parameters():
+ if init_type == "normal":
+ normal_(std=out_head_init_std)(param)
+ else:
+ uniform_(std=out_head_init_std)(param)
+
+ if extra_pred_tokens > 0:
+ self.extra_pred_tokens = extra_pred_tokens
+ assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF"
+ self.extra_outputs = nn.ModuleList(
+ [
+ new_linear(
+ name="output",
+ in_features=hidden_size,
+ out_features=vocab_size,
+ bias=False,
+ device=device,
+ dtype=dtype,
+ is_reward=is_reward,
+ weight_scale=embed_grad_scale,
+ )
+ for _ in range(self.extra_pred_tokens)
+ ]
+ )
+ for _, param in self.extra_outputs.named_parameters():
+ if init_type == "normal":
+ normal_(std=out_head_init_std)(param)
+ else:
+ uniform_(std=out_head_init_std)(param)
+
+ self.parallel_output = parallel_output
+
+ def forward(self, hidden_states=None, input_ids=None, **kwargs):
+ # attention_mask: compute attention on the places where the value is 1
+ if hasattr(self, "embed_tokens"):
+ hidden_states = self.embed_tokens(input_ids)
+ if self.embed_grad_scale != 1:
+ hidden_states = (
+ self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
+ )
+
+ for _, block in enumerate(self.layers):
+ hidden_states = block(
+ hidden_states,
+ residual=None,
+ **kwargs,
+ )
+
+ if hasattr(self, "norm"):
+ hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype))
+ if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0:
+ extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)]
+ else:
+ extra_hidden_states_list = None
+ if hasattr(self, "output"):
+ hidden_states = self.output(hidden_states)
+
+ if extra_hidden_states_list is not None:
+ return (hidden_states, extra_hidden_states_list)
+
+ return hidden_states
+
+ @staticmethod
+ def load_hf_weights(folder: str, model: nn.Module) -> None:
+ assert folder is not None, "Please specify the folder of the pretrained model"
+ if gpc.is_rank_for_log():
+ logger.info(f"Loading pretrained model from {folder}")
+
+ fns = get_fns(folder)
+ model_fns = [
+ os.path.join(folder, fn)
+ for fn in fns
+ if (fn.endswith(".bin") and fn.startswith("pytorch_model"))
+ or (fn.endswith(".safetensors") and fn.startswith("model"))
+ ]
+ model_fns.sort()
+
+ state_dict = {}
+ for model_fn in model_fns:
+ state_dict.update(llm_load(model_fn, map_location="cpu"))
+
+ tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+ tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+ wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
+ wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
+ tp_mode = gpc.config.parallel.tensor["mode"]
+ split_size = wp_size if tp_mode == "isp" else tp_size
+ local_rank = wp_rank if tp_mode == "isp" else tp_rank
+ row_dim = 0 if tp_mode == "isp" else 1
+ if gpc.config.model.get("embed_split_hidden", True):
+ embed_concat_dim = 1
+ else:
+ embed_concat_dim = 0
+
+ new_state_dict = {}
+
+ # embedding
+ if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
+ new_state_dict["embed_tokens.weight"] = torch.chunk(
+ state_dict.pop("model.embed_tokens.weight"),
+ split_size,
+ dim=embed_concat_dim,
+ )[local_rank]
+
+ for idx, i in enumerate(range(model.first_layer, model.last_layer)):
+ layer_ids = i
+
+ # attn
+ state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.wq.bias"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.bias"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.wk.bias"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.bias"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.wv.bias"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.bias"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
+ split_size,
+ dim=row_dim,
+ )[local_rank]
+
+ # ffn
+ state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
+ state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
+ split_size,
+ dim=row_dim,
+ )[local_rank]
+
+ # attn norm
+ state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop(
+ f"model.layers.{layer_ids}.input_layernorm.weight"
+ )
+ # ffn norm
+ state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop(
+ f"model.layers.{layer_ids}.post_attention_layernorm.weight"
+ )
+
+ # replace value within decoder layer
+ for name in list(state_dict.keys()):
+ if name.startswith(f"layers.{i}"):
+ new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)
+
+ # output
+ if gpc.is_last_rank(ParallelMode.PIPELINE):
+ new_state_dict["output.weight"] = torch.chunk(
+ state_dict.pop("lm_head.weight"),
+ split_size,
+ dim=0,
+ )[local_rank]
+ new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")
+
+ missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
+ if len(state_dict) > 0:
+ logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.")
+
+ if gpc.get_local_rank(ParallelMode.DATA) == 0:
+ pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
+ logger.info(
+ f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
+ f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
+ )
+
+ internlm_accelerator.empty_cache()
+
+ @staticmethod
+ def convert_internevo2hf_weights(src: str, tgt: str) -> None:
+ model_config = gpc.config.model
+ tp_mode = gpc.config.parallel.tensor["mode"]
+ row_dim = 0 if tp_mode == "isp" else 1
+
+ # load states
+ states, num_shards = Qwen2.load_sharded_states(src)
+
+ # convert state_dict
+ state_dict = {}
+ embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None]
+ for layer_i in tqdm(range(model_config["num_layers"])):
+ # attn norm, mlp norm
+ state_dict.update(
+ {
+ f"model.layers.{layer_i}.input_layernorm.weight": states[0][
+ f"layers.{layer_i}.attention_norm.weight"
+ ].clone(),
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][
+ f"layers.{layer_i}.ffn_norm.weight"
+ ].clone(),
+ }
+ )
+ # attn wqkv weight and bias
+ state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)],
+ dim=0,
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.q_proj.bias"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wq.bias"] for i in range(num_shards)],
+ dim=0,
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)],
+ dim=0,
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.k_proj.bias"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wk.bias"] for i in range(num_shards)],
+ dim=0,
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)],
+ dim=0,
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.v_proj.bias"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wv.bias"] for i in range(num_shards)],
+ dim=0,
+ )
+ # attn wo weight
+ state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim
+ )
+
+ # mlp
+ state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
+ )
+ state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim
+ )
+ state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
+ [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
+ )
+
+ # embedding, head
+ for embedding_key in embedding_key_list:
+ if embedding_key in states[0]:
+ break
+ if embedding_key is None:
+ raise KeyError("Cannot find embedding key!")
+ if model_config["embed_split_hidden"]:
+ embed_concat_dim = 1
+ tok_emb_list = [states[i][embedding_key] for i in range(num_shards)]
+ else:
+ embed_concat_dim = 0
+ _, size_1 = states[0][embedding_key].shape
+ embdim_pertp = size_1 // num_shards
+ tok_emb_list = [
+ torch.concat(
+ [
+ states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)]
+ for tp in range(num_shards)
+ ],
+ dim=0,
+ )
+ for local_rank in range(num_shards)
+ ]
+ state_dict.update(
+ {
+ "model.norm.weight": states[0]["norm.weight"],
+ "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim),
+ "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0),
+ },
+ )
+
+ # save state_dict to hf format
+ shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME)
+ for shard_file, shard in shards.items():
+ llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"})
+ if index is not None:
+ # Save the index as well
+ llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index)
diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py
index cd8eaff2..8370606b 100644
--- a/internlm/model/modules/mha.py
+++ b/internlm/model/modules/mha.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import inspect
import math
from typing import Callable, Dict, Optional
@@ -75,6 +76,7 @@ def __init__(
dtype: Optional[torch.dtype] = None,
qk_interleaved: Optional[bool] = True,
enable_qkv_fusion: bool = True,
+ out_bias: bool = True,
) -> None:
super().__init__()
self.layer_idx = layer_idx
@@ -83,6 +85,7 @@ def __init__(
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim // num_heads
+ self.kv_dim = self.head_dim * num_heads # num_kv_heads equals to num_heads in MHA
self.enable_qkv_fusion = enable_qkv_fusion
self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
@@ -116,8 +119,8 @@ def __init__(
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
- # output projection always have the bias (for now)
- self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=True, **factory_kwargs)
+ # output projection always have the bias (for now) (except for baichuan2 model)
+ self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, **factory_kwargs)
def register_checkpoint_compatibility_hooks(
self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None
@@ -355,6 +358,7 @@ def __init__(
num_heads: int,
num_kv_heads: int,
max_position_embeddings: int = 2048,
+ head_dim: int = None,
bias: bool = False,
dropout: float = 0.0,
softmax_scale: float = None,
@@ -375,9 +379,15 @@ def __init__(
self.embed_dim = embed_dim
self.num_heads = num_heads
+
+ if head_dim:
+ self.head_dim = head_dim
+ q_dim = head_dim * num_heads
+ else:
+ self.head_dim = self.embed_dim // num_heads
+ q_dim = embed_dim
self.num_kv_heads = num_kv_heads
self.q_per_kv = num_heads // num_kv_heads
- self.head_dim = self.embed_dim // num_heads
self.kv_dim = self.head_dim * num_kv_heads
self.enable_qkv_fusion = enable_qkv_fusion
@@ -405,7 +415,7 @@ def __init__(
if enable_qkv_fusion:
self.wqkv = new_linear("wqkv", embed_dim, embed_dim + 2 * self.kv_dim, bias, **factory_kwargs)
else:
- self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs)
+ self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs)
self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)
@@ -416,7 +426,7 @@ def __init__(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx
)
- self.wo = new_linear("wo", embed_dim, embed_dim, bias, **factory_kwargs)
+ self.wo = new_linear("wo", q_dim, embed_dim, bias, **factory_kwargs)
def register_checkpoint_compatibility_hooks(
self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None
@@ -624,3 +634,337 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613
# wo
return self.wo(rearrange(context, "b s h d -> b s (h d)"))
+
+
+try:
+ from flash_attn import flash_attn_func
+
+ # flash_attn >= v2.3.0
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+except (ModuleNotFoundError, ImportError):
+ _flash_supports_window_size = False
+
+
+class SWA(nn.Module):
+ """
+ sliding window attention
+
+ Args:
+ embed_dim (int): The dimention of hidden state.
+ num_heads (int): The number of attention heads.
+ process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
+ sequence_process_group (torch.distributed.ProcessGroup): The process group for attention calculation.
+ bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and
+ output projection. True by default.
+ dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
+ softmax_scale (float): The temperature to use for the softmax attention.
+ causal (boolean): Whether to apply causal attention mask. False by default.
+ layer_idx (int): The index of current layer. None by default.
+ rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default.
+ rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
+ XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
+ device (Optional[Union[str, torch.device]]): The device will be used.
+ dtype (Optional[torch.dtype]): The type of data.
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+ tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"],
+ "mtp" by default.
+
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ num_kv_heads: int,
+ qkv_bias: bool = True,
+ o_bias: bool = False,
+ max_position_embeddings: int = 2048,
+ dropout: float = 0.0,
+ softmax_scale: float = None,
+ causal: bool = False,
+ layer_idx: int = None,
+ use_dynamic_ntk_rope: bool = False,
+ rope_type: str = "normal",
+ rope_base: int = 10000,
+ rope_scaling_factor: float = 1.0,
+ rotary_emb_dim: int = 0,
+ rotary_emb_scale_base: int = 0,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ use_sliding_window: bool = False,
+ sliding_window: int = None,
+ tp_mode: str = "mtp",
+ qk_interleaved: Optional[bool] = True,
+ use_logn_attn: bool = False, # Qwen1
+ ) -> None:
+ assert embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads"
+ assert (not use_sliding_window) or (
+ sliding_window is not None
+ ), "Must set `sliding windows` size when `use_sliding_window` is True."
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+
+ self.head_dim = self.embed_dim // num_heads
+ self.num_kv_heads = num_kv_heads
+ self.kv_dim = self.head_dim * num_kv_heads
+ self.causal = causal
+ self.layer_idx = layer_idx
+ self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
+ self.rotary_emb_dim = rotary_emb_dim
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window
+ self.dtype = dtype
+ self.tp_mode = tp_mode
+ self.rope_type = rope_type
+ self.use_logn_attn = use_logn_attn
+ self.interleaved = qk_interleaved
+
+ factory_kwargs = {"device": device, "dtype": dtype}
+
+ assert self.use_dynamic_ntk_rope is False, "Not support dynamic ntk rope yet."
+ assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads"
+
+ if self.rotary_emb_dim > 0:
+ self.rotary_emb = new_rotary_embedding(
+ self.rotary_emb_dim,
+ base=rope_base,
+ scale_base=rotary_emb_scale_base,
+ device=device,
+ max_position_embeddings=max_position_embeddings,
+ scaling_factor=rope_scaling_factor,
+ rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native",
+ )
+
+ # notice here should change bias=True
+ self.wq = new_linear(
+ "wq",
+ embed_dim,
+ embed_dim,
+ qkv_bias,
+ **factory_kwargs,
+ )
+ self.wk = new_linear(
+ "wk",
+ embed_dim,
+ self.kv_dim,
+ qkv_bias,
+ **factory_kwargs,
+ )
+ self.wv = new_linear(
+ "wv",
+ embed_dim,
+ self.kv_dim,
+ qkv_bias,
+ **factory_kwargs,
+ )
+
+ self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
+ self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
+
+ self.inner_cross_attn_causal = causal
+ self.inner_cross_attn_softmax_scale = softmax_scale
+ self.inner_cross_attn_dropout = dropout
+
+ self.wo = new_linear(
+ "wo",
+ embed_dim,
+ embed_dim,
+ o_bias,
+ **factory_kwargs,
+ )
+
+ def forward(self, x, inference_params=None, **kwargs):
+ if inference_params is None:
+ return self._training(x=x, **kwargs)
+ else:
+ return self._inference(x=x, inference_params=inference_params, **kwargs)
+
+ def _training(self, x, **kwargs):
+ q, k, v = self.wq(x), self.wk(x), self.wv(x)
+ q = rearrange(q, "b t (h d) -> b t h d", d=self.head_dim)
+ k = rearrange(k, "b t (h d) -> b t h d", d=self.head_dim)
+ v = rearrange(v, "b t (h d) -> b t h d", d=self.head_dim)
+
+ kv_seq_len = k.size(0)
+ use_window_circumstance = (
+ _flash_supports_window_size
+ and self.use_sliding_window
+ and self.sliding_window
+ and kv_seq_len > self.sliding_window
+ )
+
+ kwargs = _convert_cu_seqlens_for_qksplited(kwargs)
+
+ # rotary embedding
+ if self.rotary_emb_dim > 0:
+ indexes = kwargs.pop("indexes", 0)
+ max_seqlen_q = kwargs.get("max_seqlen_q", None)
+ max_seqlen_k = kwargs.get("max_seqlen_k", None)
+
+ q = self.rotary_emb(
+ q, offsets=indexes, max_seqlen=max_seqlen_q, cache_type="query", interleaved=self.interleaved
+ )
+ k = self.rotary_emb(
+ k, offsets=indexes, max_seqlen=max_seqlen_k, cache_type="key", interleaved=self.interleaved
+ )
+
+ kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
+
+ if use_window_circumstance:
+ kwargs["window_size"] = (self.sliding_window, 0)
+
+ # self attention
+ context = self.inner_attn(q, kv, **kwargs)
+
+ # wo
+ return self.wo(rearrange(context, "b s h d -> b s (h d)"))
+
+ def _convert_unpacked_qkv_to_packed(
+ self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor
+ ):
+ cu_seqlens = torch.concat(
+ [
+ torch.tensor([0], dtype=torch.int32, device=attention_mask.device),
+ attention_mask.sum(dim=-1).to(dtype=torch.int32),
+ ],
+ dim=0,
+ ).cumsum(dim=0, dtype=torch.int32)
+
+ cu_seqlens_q = cu_seqlens
+ cu_seqlens_k = cu_seqlens
+
+ max_seqlen_q = attention_mask.shape[-1]
+ max_seqlen_k = attention_mask.shape[-1]
+
+ q_packed = (
+ q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+ )
+ kv_packed = (
+ kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+ .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+ .unsqueeze(0)
+ )
+
+ return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
+
+ def _inference(self, x, inference_params=None, **kwargs): # pylint: disable=W0613
+ assert inference_params is not None, "inference_params is required for inference"
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+ attention_mask = inference_params.attention_mask
+ sequence_len_offset = inference_params.sequence_len_offset
+ window_size = inference_params.window_size
+
+ bsz = x.shape[0]
+
+ q, k, v = self.wq(x), self.wk(x), self.wv(x)
+ q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
+ k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
+ v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
+
+ kv_seq_len = k.size(0)
+ use_window_circumstance = (
+ _flash_supports_window_size
+ and self.use_sliding_window
+ and self.sliding_window
+ and kv_seq_len > self.sliding_window
+ )
+
+ assert self.rotary_emb_dim > 0
+ if attention_mask is None:
+ raise NotImplementedError(
+ "You should make sure you are aware that you are changing the method of generating."
+ "According to your generation function instead of inference/seq_generator_module.py, "
+ "You may implement here for normal running."
+ )
+ else:
+ if inference_params.sequence_len_offset == 0:
+ q = self.rotary_emb(
+ q, offsets=0, cache_type="query", interleaved=self.interleaved, left_padding_mask=attention_mask
+ )
+ k = self.rotary_emb(
+ k, offsets=0, cache_type="key", interleaved=self.interleaved, left_padding_mask=attention_mask
+ )
+ else:
+ empties = attention_mask[..., -1].sum(dim=-1)
+ indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties
+ indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties
+ q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved)
+ k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved)
+
+ kv = torch.stack([k, v], dim=2)
+
+ if window_size is None or window_size > sequence_len_offset:
+ kv = update_kv_cache(kv, inference_params, self.layer_idx)
+ else: # window_size <= sequence_len_offset
+ assert kv.size(1) == 1, "update kv length more than 1"
+
+ inference_params.key_value_memory_dict[self.layer_idx][
+ :, inference_params.keep_first : inference_params.window_size - 1, ...
+ ] = inference_params.key_value_memory_dict[self.layer_idx][
+ :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ...
+ ].clone()
+ inference_params.real_sequence_len_offset = inference_params.sequence_len_offset
+ inference_params.sequence_len_offset = inference_params.window_size - 1
+
+ kv = update_kv_cache(kv, inference_params, self.layer_idx)
+
+ inference_params.sequence_len_offset = inference_params.real_sequence_len_offset
+
+ # When using FP16, there is a high probability of NAN in the KV.
+ # Since NAN cannot be removed by multiplying with and 0, it needs
+ # to be removed manually here.
+ kv = torch.where(torch.isnan(kv), 0, kv)
+
+ # attention
+ if attention_mask is None:
+ context = self.inner_cross_attn(q, kv)
+ else:
+ if sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen)
+ attn_mask = attention_mask[:, None, ...]
+ attn_mask = torch.logical_or(torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask)
+ attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1)
+
+ if use_window_circumstance:
+ output = self.inner_attn(
+ *self._convert_unpacked_qkv_to_packed(q, kv, bsz, attn_mask4flsh),
+ window_size=(self.sliding_window, 0),
+ )
+ else:
+ output = self.inner_attn(*self._convert_unpacked_qkv_to_packed(q, kv, bsz, attn_mask4flsh))
+ output = output.to(x.dtype)
+
+ context = torch.zeros_like(q).masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output)
+
+ else:
+ attn_mask = attention_mask[:, -1, :].view(bsz, 1, 1, -1)
+ if window_size is not None and window_size <= sequence_len_offset:
+ attn_mask = torch.concat(
+ [
+ attn_mask[..., : inference_params.keep_first],
+ attn_mask[..., -(window_size - inference_params.keep_first) :],
+ ],
+ dim=-1,
+ )
+
+ k, v = torch.chunk(kv, 2, dim=2)
+ k = k.squeeze(2)
+ v = v.squeeze(2)
+ sp = k.shape
+ expansion = q.size(2) // k.size(2)
+ scores = torch.einsum(
+ "blhd,bnhd->bhln",
+ q,
+ k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+ ) / math.sqrt(q.size(-1))
+ scores = scores.masked_fill(attn_mask, -65000.0)
+ scores = F.softmax(scores, dim=-1) # bsz x h x L x L
+ context = torch.einsum(
+ "bhmn,bnhd->bmhd",
+ scores,
+ v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+ )
+
+ # wo
+ return self.wo(rearrange(context, "b s h d -> b s (h d)"))
diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py
index 897e1363..b836ff3d 100644
--- a/internlm/model/modules/mlp.py
+++ b/internlm/model/modules/mlp.py
@@ -7,8 +7,9 @@
from torch import nn
from internlm.model.modules.linear import new_linear
-from internlm.model.modules.utils import Silu
+from internlm.model.modules.utils import Gelu, Silu
from internlm.utils.logger import get_logger
+from internlm.utils.utils import ActivationType
logger = get_logger(__file__)
@@ -71,10 +72,13 @@ def __init__(
):
super().__init__()
- # TODO: support gelu...
- assert activation_type in ("swiglu"), f"Unsupported activation type: {activation_type}"
+ assert activation_type in (
+ ActivationType.swiglu.name,
+ ActivationType.gelu.name,
+ ), f"Unsupported activation type: {activation_type}"
self.mlp_layer_fusion = mlp_layer_fusion
+ self.activation_type = activation_type
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
@@ -98,7 +102,12 @@ def forward(self, x):
else:
fussed_out = self.fused_w1_w3(x)
w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1)
- out = self.w2(Silu(w1_o, w3_o))
+
+ if self.activation_type is ActivationType.swiglu.name:
+ out = self.w2(Silu(w1_o, w3_o))
+ else:
+ out = self.w2(Gelu(w1_o, w3_o))
+
return out
diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py
index b94cdd43..2a9700f8 100644
--- a/internlm/model/modules/norm.py
+++ b/internlm/model/modules/norm.py
@@ -2,6 +2,7 @@
layer norm modules
"""
+import inspect
from typing import List, Union
import torch
@@ -12,8 +13,12 @@
Shape = Union[int, List[int], torch.Size]
-def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5):
+def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False):
if norm_type == "rmsnorm":
- return RMSNorm(normalized_shape, eps)
+ rmsnorm_params = inspect.signature(RMSNorm).parameters
+ if "add_unit_offset" in rmsnorm_params:
+ return RMSNorm(normalized_shape, eps, add_unit_offset)
+ else:
+ return RMSNorm(normalized_shape, eps)
else: # default: layernorm
return nn.LayerNorm(normalized_shape, eps)
diff --git a/internlm/model/modules/utils.py b/internlm/model/modules/utils.py
index dd86cb1c..bf1ae048 100644
--- a/internlm/model/modules/utils.py
+++ b/internlm/model/modules/utils.py
@@ -20,7 +20,12 @@ def Silu(w1_o, w2_o):
return F.silu(w1_o) * w2_o
+def Gelu(w1_o, w2_o):
+ return F.gelu(w1_o) * w2_o
+
+
Silu = torch.jit.script(Silu)
+Gelu = torch.jit.script(Gelu)
def update_kv_cache(kv, inference_params, layer_idx):
diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py
index 3cd43dab..34e7c007 100644
--- a/internlm/model/ops/norm.py
+++ b/internlm/model/ops/norm.py
@@ -35,7 +35,7 @@
torchnpu_rmsnorm_impl = False
-def manual_rms_norm(my_input, weight, normalized_shape, eps):
+def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False):
# layer norm should always be calculated in float32
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
@@ -48,13 +48,16 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps):
if weight.dtype in [torch.float16, torch.bfloat16]:
my_input = my_input.to(weight.dtype)
- return weight * my_input
+ if add_unit_offset:
+ return (1 + weight) * my_input
+ else:
+ return weight * my_input
class _RMSNorm(torch.nn.Module):
"""A generic module for RMS normalization."""
- def __init__(self, normalized_shape, eps=1e-5):
+ def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False):
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
@@ -62,18 +65,22 @@ def __init__(self, normalized_shape, eps=1e-5):
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.empty(*normalized_shape))
+ self.add_unit_offset = add_unit_offset
self.reset_parameters()
def forward(self, _input: torch.Tensor):
if apex_rmsnorm_impl:
_norm_func = mixed_dtype_fused_rms_norm_affine
+ return _norm_func(_input, self.weight, self.normalized_shape, self.eps)
else:
_norm_func = manual_rms_norm
-
- return _norm_func(_input, self.weight, self.normalized_shape, self.eps)
+ return _norm_func(_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset)
def reset_parameters(self):
- init.ones_(self.weight)
+ if self.add_unit_offset:
+ init.zeros_(self.weight)
+ else:
+ init.ones_(self.weight)
def extra_repr(self):
return f"{self.normalized_shape}, eps={self.eps}, "
diff --git a/internlm/model/registry.py b/internlm/model/registry.py
index a1921ab6..c923ec20 100644
--- a/internlm/model/registry.py
+++ b/internlm/model/registry.py
@@ -3,11 +3,14 @@
from typing import Callable
+from internlm.model.modeling_baichuan2 import Baichuan2
+from internlm.model.modeling_gemma import Gemma
from internlm.model.modeling_internlm import InternLM1
from internlm.model.modeling_internlm2 import InternLM2
from internlm.model.modeling_llama import Llama2
from internlm.model.modeling_llava import Llava
from internlm.model.modeling_moe import Internlm1MoE
+from internlm.model.modeling_qwen2 import Qwen2
from internlm.utils.common import SingletonMeta
from internlm.utils.utils import ModelType
@@ -83,6 +86,9 @@ def register_model_initializer() -> None:
model_initializer.register_module(ModelType.LLAMA2.name, Llama2)
model_initializer.register_module(ModelType.INTERNLM_MoE.name, Internlm1MoE)
model_initializer.register_module(ModelType.LLAVA.name, Llava)
+ model_initializer.register_module(ModelType.QWEN2.name, Qwen2)
+ model_initializer.register_module(ModelType.BAICHUAN2.name, Baichuan2)
+ model_initializer.register_module(ModelType.GEMMA.name, Gemma)
register_model_initializer()
diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py
index 03dee6df..ca6b3215 100644
--- a/internlm/utils/utils.py
+++ b/internlm/utils/utils.py
@@ -47,6 +47,9 @@ class ModelType(Enum):
LLAMA2 = 3
INTERNLM_MoE = 4
LLAVA = 5
+ QWEN2 = 6
+ BAICHUAN2 = 7
+ GEMMA = 8
class DataType(Enum):
@@ -61,6 +64,11 @@ class TensorParallelMode(Enum):
isp = 4
+class ActivationType(Enum):
+ swiglu = 1
+ gelu = 2
+
+
def check_attention_argument(*args, **kwargs) -> str:
# self, qkv, ...
# self, q, kv, ....
|