feat(pipeline): Zero Bubble V Shape Memory Efficient Edition (#357)
li126com authored Oct 29, 2024
1 parent 2bae28f commit 330c03a
Showing 26 changed files with 1,482 additions and 522 deletions.
3 changes: 2 additions & 1 deletion configs/7B_internlm2.py
@@ -174,14 +174,15 @@
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
3. mode: str, the pipeline parallel mode, should be in ['1f1b', 'zbh1', 'zbv']. The default is 1f1b.
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=2, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True, mode="1f1b"),
weight=dict(size=2, overlap=True),
)

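For readers trying out the new mode: a minimal config sketch (not part of this diff) that selects the ZBV scheduler. The pipeline size is illustrative, and the `num_chunks = 2` requirement is an assumption based on the V shape pairing two model chunks per pipeline rank.

```python
# Illustrative only -- adapt sizes to your topology.
parallel = dict(
    zero1=dict(size=-1),
    tensor=dict(size=2, mode="isp"),
    pipeline=dict(size=4, interleaved_overlap=True, mode="zbv"),
    weight=dict(size=2, overlap=True),
)
model = dict(num_chunks=2)  # assumption: ZBV places two chunks on each rank
```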
3 changes: 1 addition & 2 deletions configs/7B_sft.py
@@ -186,15 +186,14 @@
1. size: int, the size of pipeline parallel (Default is 1F1B).
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
3. zero_bubble: bool, enable/disable zero bubble pipeline parallelism (ZB-H1), defaults to False.
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True, zero_bubble=False),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True),
)

8 changes: 4 additions & 4 deletions doc/code-docs/locales/en/LC_MESSAGES/parallel.po
@@ -575,8 +575,8 @@ msgstr ""
msgid "返回类型"
msgstr "Return type"

#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler.forward_backward_step:19
#: internlm.core.scheduler.pipeline_scheduler.PipelineScheduler.forward_backward_step:19
#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler.forward_backward_step:19
#: internlm.core.scheduler.pipeline_scheduler_1f1b.PipelineScheduler.forward_backward_step:19
#: of
msgid "Tuple[:class:`torch.Tensor`]"
msgstr ""
@@ -591,11 +591,11 @@ msgstr ""
"To use interleaved pipeline scheduler, users need to set "
"``model.num_chunks > 1`` in the config file."

#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler:1 of
#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler:1 of
msgid "Interleaved Pipeline Scheduler."
msgstr ""

#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler.forward_backward_step:1
#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler.forward_backward_step:1
#: of
msgid ""
"Run interleaved 1F1B schedule (model split into model chunks), with "
4 changes: 2 additions & 2 deletions doc/code-docs/source/parallel.rst
@@ -137,14 +137,14 @@ InternEvo 在流水线并行中使用 `1F1B <https://arxiv.org/pdf/2104.04473.pd
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
如果要使用非交错式调度, 需要设置 ``model.num_chunks = 1`` 。

.. autoclass:: internlm.core.scheduler.pipeline_scheduler.PipelineScheduler
.. autoclass:: internlm.core.scheduler.pipeline_scheduler_1f1b.PipelineScheduler
:members:

交错式流水线调度
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
如果要使用交错式调度, 需要设置 ``model.num_chunks > 1`` 。

.. autoclass:: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler
.. autoclass:: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler
:members:

值得注意的是,在使用交错式流水线调度器时可启用通信优化功能,即在 1F1B 阶段启用异步通信,以充分利用上行/下行带宽并实现通信与计算重叠。
3 changes: 2 additions & 1 deletion doc/en/structure.md
@@ -12,7 +12,8 @@ The system code file structure is shown below:
│ │ │ └── process_group_initializer.py
│ │ ├── scheduler # Scheduling module, which manages schedulers for parallel training, including non-pipeline and pipeline parallel schedulers
│ │ │ ├── no_pipeline_scheduler.py
│ │ │ └── pipeline_scheduler.py
│ │ │ ├── pipeline_scheduler_1f1b.py
│ │ │ └── pipeline_scheduler_zb.py
│ │ ├── engine.py # Responsible for managing the training and evaluation process of the model
│ │ └── trainer.py # Responsible for managing the training engine and scheduler
│ ├── data # Data module, responsible for managing dataset generation and processing
3 changes: 2 additions & 1 deletion doc/structure.md
@@ -12,7 +12,8 @@
│ │ │ └── process_group_initializer.py
│ │ ├── scheduler # 调度模块,管理并行训练的调度器,包括非流水线并行调度器和流水线并行调度器
│ │ │ ├── no_pipeline_scheduler.py
│ │ │ └── pipeline_scheduler.py
│ │ │ ├── pipeline_scheduler_1f1b.py
│ │ │ └── pipeline_scheduler_zb.py
│ │ ├── engine.py # 负责管理模型的训练和评估过程
│ │ └── trainer.py # 负责管理训练引擎和调度器
│ ├── data # 数据模块,负责管理数据集生成和处理
18 changes: 15 additions & 3 deletions internlm/core/context/parallel_context.py
@@ -163,6 +163,7 @@ def __init__(self):
self.virtual_pipeline_parallel_rank = None
self._expert_parallel_group_names = []
self.is_evaluating = False
self.v_shape = False

@property
def config(self):
@@ -292,8 +293,13 @@ def is_rank_for_log(self):
and self.is_first_rank(ParallelMode.WEIGHT)
and self.is_first_rank(ParallelMode.DATA)
and self.is_first_rank(ParallelMode.WEIGHT_DATA)
and self.is_last_rank(ParallelMode.PIPELINE)
)

if not self.v_shape:
is_log_rank = is_log_rank and self.is_last_rank(ParallelMode.PIPELINE)
else:
is_log_rank = is_log_rank and self.is_first_rank(ParallelMode.PIPELINE)

return is_log_rank

def is_last_rank(self, parallel_mode: ParallelMode):
@@ -327,11 +333,17 @@ def is_pipeline_last_stage(self, ignore_virtual=False):
and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1
):
return False
return self.is_last_rank(ParallelMode.PIPELINE)
if not self.v_shape:
return self.is_last_rank(ParallelMode.PIPELINE)
else:
return self.is_first_rank(ParallelMode.PIPELINE)

def is_no_pp_or_last_stage(self):
# NOTICE!!!, this will ignore virtual stage
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE)
if not self.v_shape:
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE)
else:
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_first_rank(ParallelMode.PIPELINE)

def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`.
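Why the rank checks flip under `v_shape`: the second model chunk runs back down the pipeline, so the last model stage lives on pipeline rank 0 rather than on the last rank. A hypothetical illustration, not taken from this diff:

```python
# ZBV stage placement for an assumed pipeline size of 4.
# Chunk 0 ascends ranks 0..3; chunk 1 descends ranks 3..0 (the "V").
pp_size = 4
chunk0 = list(range(pp_size))    # stages 0..3 -> ranks 0,1,2,3
chunk1 = list(reversed(chunk0))  # stages 4..7 -> ranks 3,2,1,0
for stage, rank in enumerate(chunk0 + chunk1):
    print(f"model stage {stage} -> pipeline rank {rank}")
# The final stage (7) sits on rank 0, hence the is_first_rank checks above.
```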
6 changes: 6 additions & 0 deletions internlm/core/parallel/comm/isp.py
@@ -692,6 +692,12 @@ def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W06
if self._isp_communicator and self._isp_communicator.overlap:
self._zero_optim.accumulate_left_grads_after_backward()

if (
getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper() in ["ZBV", "ZBH1"]
and not self._zero_optim.skip_grad_reduce
):
self._zero_optim.reduce_left_grads_after_backward()

def post_helper_func(self, scheduler, outputs, label) -> None: # pylint: disable=W0613
pass

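This extra reduction hook exists because zero-bubble schedules split backward into a B step (activation gradients, on the critical path) and a deferred W step (weight gradients, scheduled into bubbles), so some weight gradients only become ready after the usual reduce point. A self-contained sketch of that B/W split in plain PyTorch, assuming nothing about InternEvo's internals:

```python
import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
loss = layer(x).sum()

# B step: compute only the input gradient the previous stage is waiting for.
(dx,) = torch.autograd.grad(loss, x, retain_graph=True)

# W step: weight/bias gradients, computed later to fill a pipeline bubble.
dw, db = torch.autograd.grad(loss, (layer.weight, layer.bias))
```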
14 changes: 10 additions & 4 deletions internlm/core/parallel/shard.py
@@ -192,10 +192,16 @@ def partition_uniform(num_items: int, pipeline_parallel_size: int, num_chunks: i
if chunk_size == 0:
raise ValueError("Some nodes in Pipeline have no requests")

for p in range(pipeline_parallel_size):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))
if getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper() == "ZBV" and idx == 1:
for p in range(pipeline_parallel_size - 1, -1, -1):
st = base_idx
base_idx += chunk_size + ((pipeline_parallel_size - p - 1) >= left)
parts[p].append((st, base_idx))
else:
for p in range(pipeline_parallel_size):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))

indexes = []
for _parts in parts:
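The reversed loop for the second chunk (`idx == 1`) is what produces the V shape: rank 0 ends up holding both the first and the last slice of layers. A simplified, standalone rendering under assumed sizes (8 layers, 4 ranks, 2 chunks; the real `partition_uniform` also distributes leftover layers):

```python
num_layers, pp_size = 8, 4
chunk = num_layers // (pp_size * 2)  # two chunks per rank, uniform case
parts = [[] for _ in range(pp_size)]
base = 0
for p in range(pp_size):             # chunk 0: ascending ranks
    parts[p].append((base, base + chunk))
    base += chunk
for p in range(pp_size - 1, -1, -1):  # chunk 1: descending ranks (the "V")
    parts[p].append((base, base + chunk))
    base += chunk
print(parts)
# [[(0, 1), (7, 8)], [(1, 2), (6, 7)], [(2, 3), (5, 6)], [(3, 4), (4, 5)]]
```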
7 changes: 4 additions & 3 deletions internlm/core/scheduler/__init__.py
@@ -1,9 +1,9 @@
from .base_scheduler import BaseScheduler
from .no_pipeline_scheduler import NonPipelineScheduler
from .pipeline_scheduler import (
InterleavedPipelineScheduler,
PipelineScheduler,
from .pipeline_scheduler_1f1b import InterleavedPipelineScheduler, PipelineScheduler
from .pipeline_scheduler_zb import (
ZeroBubblePipelineScheduler,
ZeroBubblePipelineVShapeScheduler,
)

__all__ = [
@@ -12,4 +12,5 @@
"InterleavedPipelineScheduler",
"PipelineScheduler",
"ZeroBubblePipelineScheduler",
"ZeroBubblePipelineVShapeScheduler",
]
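A hedged usage note: per the config docs earlier in this commit, the scheduler is selected via `parallel.pipeline.mode = "zbv"` rather than constructed by hand, but the new class is exported at package level:

```python
from internlm.core.scheduler import ZeroBubblePipelineVShapeScheduler
```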
6 changes: 2 additions & 4 deletions internlm/core/scheduler/comm/__init__.py
@@ -1,13 +1,12 @@
from .p2p import (
AsynCommunicator,
fused_send_recv_tensor,
recv_backward,
recv_forward,
send_backward,
send_backward_and_recv_next_backward_async,
send_backward_recv_backward,
send_backward_recv_forward,
send_forward,
send_forward_and_recv_next_forward_async,
send_forward_backward_recv_forward_backward,
send_forward_recv_backward,
send_forward_recv_forward,
@@ -26,7 +25,6 @@
"recv_forward",
"send_obj_meta",
"recv_obj_meta",
"send_backward_and_recv_next_backward_async",
"send_forward_and_recv_next_forward_async",
"AsynCommunicator",
"fused_send_recv_tensor",
]
