Commit

initial refactor
season0528 committed Aug 22, 2024
1 parent 81fb735 commit 24d960a
Showing 18 changed files with 333 additions and 313 deletions.
2 changes: 1 addition & 1 deletion doc/en/usage.md
@@ -11,7 +11,7 @@ If you are using the HuggingFace datasets for on-the-fly streaming load and tokenization
 ```python
 TRAIN_FOLDER = "roneneldan/TinyStories"
 data = dict(
-    type="hf",
+    type="streaming",
     tokenizer_path="internlm/internlm-7b",
 )
 ```
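For context on the renamed option, `type="streaming"` selects the on-the-fly streaming load and tokenization path described above. As a rough standalone sketch of that idea using the public `datasets` and `transformers` APIs (illustrative only, not code from this commit):

```python
# Illustrative sketch only: approximates on-the-fly streaming load + tokenization;
# InternEvo's own loader is driven by the `data` dict shown in the diff above.
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-7b", trust_remote_code=True)

for sample in dataset.take(2):  # pull a couple of samples from the stream
    input_ids = tokenizer(sample["text"])["input_ids"]
    print(len(input_ids))
```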
4 changes: 2 additions & 2 deletions doc/usage.md
@@ -11,7 +11,7 @@
 ```python
 TRAIN_FOLDER = "roneneldan/TinyStories"
 data = dict(
-    type="hf",
+    type="streaming",
     tokenizer_path="internlm/internlm-7b",
 )
 ```
@@ -327,7 +327,7 @@ train_folder is set to a dataset on HuggingFace that can be downloaded directly via load_dataset
 TRAIN_FOLDER = "roneneldan/TinyStories"
 SEQ_LEN = 2048
 data = dict(
-    type="hf",
+    type="streaming",
     tokenizer_path="internlm/internlm-7b",
     seq_len=SEQ_LEN,  # length of each data sample, default value 2048
     micro_num=1,  # micro_num is the number of micro_batches processed in one model parameter update, default value 1
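Per the translated comments, one optimizer step consumes `micro_num` micro-batches of `seq_len`-token samples. A back-of-the-envelope sketch of the resulting token count per step and per data-parallel rank, assuming a per-micro-batch sample count `micro_bsz` that is not shown in the snippet above:

```python
# Rough illustration only; micro_bsz is an assumed value, not part of the diff above.
SEQ_LEN = 2048
micro_num = 1  # micro-batches accumulated per optimizer step
micro_bsz = 2  # assumed samples per micro-batch

tokens_per_step_per_rank = micro_num * micro_bsz * SEQ_LEN
print(tokens_per_step_per_rank)  # 4096
```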
23 changes: 11 additions & 12 deletions internlm/core/parallel/comm/isp.py
@@ -892,41 +892,40 @@ def _(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs) -> torch.Tensor
         return context
 
 
-def auto_wrap_distributed_attention(cls: nn.Module) -> Callable[[bool, Any, float], nn.Module]:
+def auto_wrap_distributed_attention(attn_impl: nn.Module) -> Callable[[bool, Any, float], nn.Module]:
     """
     Wrap a local attention module to a distributed one, which will be used in the ISP parallelism.
     """
 
     # should we impl distributed attention as a metaclass?
-    def _attetion_constructor(
-        local_attn_cls: type, causal=False, softmax_scale=None, attention_dropout=0.0
-    ) -> nn.Module:
+    def _attetion_constructor(attn_impl: type, causal=False, softmax_scale=None, attention_dropout=0.0) -> nn.Module:
         tp_mode = gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name)
 
         if tp_mode != TensorParallelMode.isp.name:
-            return local_attn_cls(causal, softmax_scale, attention_dropout)
+            return attn_impl(causal, softmax_scale, attention_dropout)
         else:
             return DistributedAttention(
-                local_attention=local_attn_cls(causal, softmax_scale, attention_dropout),
+                local_attention=attn_impl(causal, softmax_scale, attention_dropout),
                 sequence_process_group=gpc.get_group(ParallelMode.TENSOR),
             )
 
-    return partial(_attetion_constructor, local_attn_cls=cls)
+    return partial(_attetion_constructor, attn_impl=attn_impl)
 
 
-def auto_wrap_func_distributed_attention(func: Callable) -> Callable[..., Callable]:
+def auto_wrap_func_distributed_attention(attn_impl: Callable) -> Callable[..., Callable]:
     """
     Wrap a local attention function to a distributed one, which will be used in the ISP parallelism.
     """
 
-    def _attention_func_constructor(*args, local_attn_func=None, **kwargs) -> Callable:
+    # should we impl distributed attention as a metaclass?
+    def _attetion_constructor(*args, attn_impl: type, **kwargs) -> Callable:
         tp_mode = gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name)
 
         if tp_mode != TensorParallelMode.isp.name:
-            return local_attn_func(*args, **kwargs)
+            return attn_impl(*args, **kwargs)
         else:
             return DistributedAttention(
-                local_attention=local_attn_func, sequence_process_group=gpc.get_group(ParallelMode.TENSOR)
+                local_attention=attn_impl, sequence_process_group=gpc.get_group(ParallelMode.TENSOR)
             )(*args, **kwargs)
 
-    return partial(_attention_func_constructor, local_attn_func=func)
+    return partial(_attetion_constructor, attn_impl=attn_impl)
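To see how the renamed parameter flows through, here is a hedged usage sketch; `LocalSelfAttention` is a hypothetical stand-in, and running the wrapper requires an initialized InternEvo global context (`gpc`):

```python
# Sketch only: LocalSelfAttention is hypothetical; only auto_wrap_distributed_attention
# comes from this file. The wrapper returns a constructor that builds either the local
# attention module (non-ISP tensor-parallel modes) or a DistributedAttention around it
# (ISP mode), based on gpc.config.parallel["tensor"]["mode"].
from torch import nn

from internlm.core.parallel.comm.isp import auto_wrap_distributed_attention


class LocalSelfAttention(nn.Module):  # hypothetical local attention implementation
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal, self.softmax_scale, self.dropout = causal, softmax_scale, attention_dropout


SelfAttention = auto_wrap_distributed_attention(LocalSelfAttention)

# At model-build time (with gpc initialized) this resolves to either
# LocalSelfAttention(...) or DistributedAttention(local_attention=LocalSelfAttention(...), ...).
attn = SelfAttention(causal=True, softmax_scale=None, attention_dropout=0.1)
```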
8 changes: 1 addition & 7 deletions internlm/core/parallel/shard.py
@@ -11,7 +11,7 @@
 from internlm.core.context import global_context as gpc
 from internlm.core.parallel.comm.utils import _split
 from internlm.utils.logger import get_logger
-from internlm.utils.utils import ModelType, TensorParallelMode
+from internlm.utils.utils import TensorParallelMode
 
 logger = get_logger(__file__)
 
@@ -31,12 +31,6 @@ def split_data_sequence_parallel(data, label):
         and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name
     ):
         data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim)
-    if (
-        gpc.config.model_type == ModelType.HF.name
-        and "position_ids" in data
-        and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name
-    ):
-        data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_indexes_seq_dim)
 
     data["input_ids"] = _split(data["input_ids"], ParallelMode.TENSOR, dim=_seq_dim)
     label = _split(label, ParallelMode.TENSOR, dim=_seq_dim)
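For intuition on what the remaining `_split` calls do conceptually, here is a standalone sketch of splitting a batch along the sequence dimension across ranks; it is not the repository's `_split` implementation, just the usual chunk-by-rank idea:

```python
# Standalone illustration (not InternEvo's _split): sequence parallelism typically
# chunks a tensor along the sequence dimension and keeps only this rank's slice.
import torch


def split_along_seq(tensor: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    chunks = torch.chunk(tensor, world_size, dim=dim)
    return chunks[rank].contiguous()


input_ids = torch.arange(16).reshape(1, 16)  # (batch=1, seq_len=16)
print(split_along_seq(input_ids, rank=0, world_size=4))  # tokens 0-3
print(split_along_seq(input_ids, rank=3, world_size=4))  # tokens 12-15
```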