initial refactor
season0528 committed Aug 22, 2024
1 parent 81fb735 commit 4446ecf
Showing 21 changed files with 348 additions and 358 deletions.
8 changes: 1 addition & 7 deletions doc/code-docs/source/checkpoint.rst
@@ -36,7 +36,7 @@ CheckpointManager
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
load_ckpt_info=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"),
load_ckpt_info=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internevo"),
auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_info'.
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only works for boto3, volc and oss2 ckpt)
@@ -95,7 +95,6 @@ load_ckpt_info consists of three fields: ``path``, ``content``, and ``ckpt_type``
- ``internevo``: the checkpoint storage format defined by internevo.
- ``llama``: the checkpoint storage format defined by llama.
- ``hf_llama``: the checkpoint storage format defined by huggingface llama.
- ``hf_model``: a checkpoint storage format that can load any huggingface model.

Two examples are given below:

@@ -107,10 +106,6 @@ load_ckpt_info consists of three fields: ``path``, ``content``, and ``ckpt_type``
# Load all states from the relative file-storage path ckpt_model, suitable for resuming training from a breakpoint
load_ckpt_info = dict(path="local:ckpt_model", content=("all",), ckpt_type="internevo")
# Download the specified model from huggingface and load its checkpoint
load_ckpt_info = dict(path="internlm/internlm-7b", content=("model",), ckpt_type="hf_model")
.. _asyncupload:

Asynchronous Upload
@@ -179,4 +174,3 @@ Related parameters in config.ckpt:
# If the stored step > 0, the job exits automatically after saving the ckpt
# If the stored step < 0, the job continues training after saving the ckpt
echo "999" > ./llm_alter/1006_pr.log
10 changes: 5 additions & 5 deletions doc/en/usage.md
@@ -11,7 +11,7 @@ If you are using the HuggingFace datasets for on-the-fly streaming load and tokenization
```python
TRAIN_FOLDER = "roneneldan/TinyStories"
data = dict(
type="hf",
type="streaming",
tokenizer_path="internlm/internlm-7b",
)
```
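For reference, the `streaming` type corresponds to streaming the HuggingFace dataset and tokenizing it on the fly. A rough standalone sketch of that behavior, assuming the dataset exposes a `text` field (the actual loader is implemented inside InternEvo):

```python
# Illustrative only: stream the dataset and tokenize samples on the fly.
from datasets import load_dataset
from transformers import AutoTokenizer

ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-7b", trust_remote_code=True)

for sample in ds:
    input_ids = tokenizer(sample["text"])["input_ids"]
    break  # one sample is enough for the illustration
```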
@@ -117,7 +117,7 @@ ckpt = dict(
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded, now only 'normal' type is supported.
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
@@ -295,8 +295,8 @@ ckpt = dict(
# When resuming training from a breakpoint:
# (1) 'path' is the path of the loaded checkpoint.
# (2) 'content' indicates which state will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# (3) 'ckpt_type' indicates which type ckpt will be loaded, currently supported: "internlm"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
# (3) 'ckpt_type' indicates which type ckpt will be loaded, currently supported: "internevo"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
)
```

@@ -427,4 +427,4 @@ model = dict(
Regarding the principle of Dynamic NTK, please refer to

1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
2. https://kexue.fm/archives/9675
2. https://kexue.fm/archives/9675
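For readers who skip the links: dynamic NTK enlarges the RoPE base as the sequence grows beyond the training length. A sketch of the commonly used formula from the references above (an assumption about the method in general, not necessarily InternEvo's exact code):

```python
# Dynamic NTK-style RoPE base rescaling, as described in the linked posts (illustrative formula).
def dynamic_ntk_base(base: float, dim: int, seq_len: int, max_train_len: int, scaling_factor: float = 1.0) -> float:
    if seq_len <= max_train_len:
        return base
    alpha = (scaling_factor * seq_len / max_train_len) - (scaling_factor - 1)
    return base * alpha ** (dim / (dim - 2))
```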
14 changes: 6 additions & 8 deletions doc/usage.md
@@ -11,7 +11,7 @@
```python
TRAIN_FOLDER = "roneneldan/TinyStories"
data = dict(
type="hf",
type="streaming",
tokenizer_path="internlm/internlm-7b",
)
```
@@ -106,7 +106,7 @@ ckpt = dict(
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internevo", "llama", "hf_llama", "hf_model".
# 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internevo", "llama", "hf_llama".
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
@@ -327,7 +327,7 @@ Set train_folder to a dataset on huggingface that can be downloaded directly via load_dataset
TRAIN_FOLDER = "roneneldan/TinyStories"
SEQ_LEN = 2048
data = dict(
type="hf",
type="streaming",
tokenizer_path="internlm/internlm-7b",
seq_len=SEQ_LEN, # length of each data sample, default is 2048
micro_num=1, # micro_num is the number of micro-batches processed in one model parameter update, default is 1
@@ -353,10 +353,8 @@ ckpt = dict(
checkpoint_every=float("inf"), # save a checkpoint every this many steps, default is inf
# When resuming from a breakpoint, the path from which model and optimizer weights are loaded; training resumes from the specified step
# content indicates which states will be loaded, supports: "model", "sampler", "optimizer", "scheduler", "all"
# ckpt_type indicates the type of checkpoint to load, currently supports: "internevo", "llama", "hf_llama", "hf_model"
# The "hf_model" type means the model checkpoint is downloaded from huggingface; MODEL_ONLY_FOLDER must be set to
# a model path that can be loaded directly via AutoModel, e.g. "internlm/internlm-7b"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
# ckpt_type indicates the type of checkpoint to load, currently supports: "internevo", "llama", "hf_llama"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when training is interrupted/hung
# due to hardware failures, relying on the scheduling system (e.g. k8s/slurm) to automatically restart training.
# Note that if auto_resume is not set (its default is True), the checkpoint path specified in load_ckpt_info will not be loaded by default.
@@ -510,4 +508,4 @@ generation = dict(

For the principle of Dynamic NTK, please refer to
1. [dynamically_scaled_rope_further_increases](https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases)
2. [https://kexue.fm/archives/9675](https://kexue.fm/archives/9675)
2. [https://kexue.fm/archives/9675](https://kexue.fm/archives/9675)
15 changes: 0 additions & 15 deletions internlm/checkpoint/load_funcs.py
@@ -10,7 +10,6 @@
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import get_fns, llm_load
from internlm.utils.utils import ModelType
from transformers import AutoModelForCausalLM

logger = get_logger(__file__)
internlm_accelerator = get_accelerator()
@@ -300,22 +299,8 @@ def load_internlm_with_dynamic_parallel_size(folder, model):
)


def load_hf_model_pretrained_weights(folder, model):
    """NOTE: when loading huggingface's model pretrained weights, you should set `adapt_hf=True` in your config."""
    assert folder is not None, "Please specify the folder of the pretrained model"
    if gpc.is_rank_for_log():
        logger.info(f"Loading pretrained model from {folder}")

    pretrained_model = AutoModelForCausalLM.from_pretrained(folder, trust_remote_code=True)
    model.load_state_dict(pretrained_model.state_dict(), strict=False)

    if gpc.is_rank_for_log():
        logger.info("Pretrained weights loaded successfully")


LOAD_FUNC_DICT = {
    "llama": load_llama_pretrained_weights,
    "hf_llama": load_hf_llama_pretrained_weights,
    "internlm_test": load_internlm_with_dynamic_parallel_size,
    "hf_model": load_hf_model_pretrained_weights,
}
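LOAD_FUNC_DICT maps each supported ckpt_type to its loader, so removing the "hf_model" entry drops that loading path. Roughly, the loader is looked up by key and called with the checkpoint folder and the model, as in this hypothetical sketch (the real call site is elsewhere in the checkpoint code, and `model` stands for an already-built model):

```python
# Hypothetical dispatch sketch: pick the loader registered for the configured ckpt_type.
load_ckpt_info = dict(path="local:llm_ckpts/llama", content=("model",), ckpt_type="hf_llama")
load_func = LOAD_FUNC_DICT[load_ckpt_info["ckpt_type"]]  # e.g. load_hf_llama_pretrained_weights
load_func(folder=load_ckpt_info["path"], model=model)  # (folder, model) signature as in the loaders above
```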
23 changes: 11 additions & 12 deletions internlm/core/parallel/comm/isp.py
@@ -892,41 +892,40 @@ def _(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs) -> torch.Tensor:
return context


def auto_wrap_distributed_attention(cls: nn.Module) -> Callable[[bool, Any, float], nn.Module]:
def auto_wrap_distributed_attention(attn_impl: nn.Module) -> Callable[[bool, Any, float], nn.Module]:
    """
    Wrap a local attention module to a distributed one, which will be used in the ISP parallelism.
    """

    # should we impl distributed attention as a metaclass?
    def _attetion_constructor(
        local_attn_cls: type, causal=False, softmax_scale=None, attention_dropout=0.0
    ) -> nn.Module:
    def _attetion_constructor(attn_impl: type, causal=False, softmax_scale=None, attention_dropout=0.0) -> nn.Module:
        tp_mode = gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name)

        if tp_mode != TensorParallelMode.isp.name:
            return local_attn_cls(causal, softmax_scale, attention_dropout)
            return attn_impl(causal, softmax_scale, attention_dropout)
        else:
            return DistributedAttention(
                local_attention=local_attn_cls(causal, softmax_scale, attention_dropout),
                local_attention=attn_impl(causal, softmax_scale, attention_dropout),
                sequence_process_group=gpc.get_group(ParallelMode.TENSOR),
            )

    return partial(_attetion_constructor, local_attn_cls=cls)
    return partial(_attetion_constructor, attn_impl=attn_impl)


def auto_wrap_func_distributed_attention(func: Callable) -> Callable[..., Callable]:
def auto_wrap_func_distributed_attention(attn_impl: Callable) -> Callable[..., Callable]:
    """
    Wrap a local attention function to a distributed one, which will be used in the ISP parallelism.
    """

    def _attention_func_constructor(*args, local_attn_func=None, **kwargs) -> Callable:
    # should we impl distributed attention as a metaclass?
    def _attetion_constructor(*args, attn_impl: type, **kwargs) -> Callable:
        tp_mode = gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name)

        if tp_mode != TensorParallelMode.isp.name:
            return local_attn_func(*args, **kwargs)
            return attn_impl(*args, **kwargs)
        else:
            return DistributedAttention(
                local_attention=local_attn_func, sequence_process_group=gpc.get_group(ParallelMode.TENSOR)
                local_attention=attn_impl, sequence_process_group=gpc.get_group(ParallelMode.TENSOR)
            )(*args, **kwargs)

    return partial(_attention_func_constructor, local_attn_func=func)
    return partial(_attetion_constructor, attn_impl=attn_impl)
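After the rename from `cls`/`func` to `attn_impl`, both wrappers still return a constructor (a functools.partial) rather than an instance. A hedged usage sketch, with `SelfAttention` standing in for whatever local attention class gets wrapped:

```python
# Illustrative usage: SelfAttention is a stand-in name for a local attention class.
# Outside ISP tensor-parallel mode the wrapped constructor builds the local module directly;
# in ISP mode it wraps the local module in DistributedAttention over the TENSOR process group.
SelfAttention = auto_wrap_distributed_attention(SelfAttention)
attn = SelfAttention(causal=True, softmax_scale=None, attention_dropout=0.1)
```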
8 changes: 1 addition & 7 deletions internlm/core/parallel/shard.py
@@ -11,7 +11,7 @@
from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm.utils import _split
from internlm.utils.logger import get_logger
from internlm.utils.utils import ModelType, TensorParallelMode
from internlm.utils.utils import TensorParallelMode

logger = get_logger(__file__)

@@ -31,12 +31,6 @@ def split_data_sequence_parallel(data, label):
        and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name
    ):
        data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim)
    if (
        gpc.config.model_type == ModelType.HF.name
        and "position_ids" in data
        and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name
    ):
        data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_indexes_seq_dim)

    data["input_ids"] = _split(data["input_ids"], ParallelMode.TENSOR, dim=_seq_dim)
    label = _split(label, ParallelMode.TENSOR, dim=_seq_dim)
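For intuition, splitting along the sequence dimension hands each tensor-parallel rank a contiguous slice of the sequence. A minimal standalone sketch of that idea (a hypothetical helper, not InternEvo's `_split`, which also handles the process-group plumbing):

```python
# Minimal illustration of sequence-dimension splitting across a process group (hypothetical helper).
import torch
import torch.distributed as dist

def split_along_dim(tensor: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    chunks = torch.chunk(tensor, world_size, dim=dim)  # assumes the size along dim is divisible
    return chunks[rank].contiguous()
```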