Merge pull request InternLM#1 from zhhsplendid/hwh/feat/hetero
Add hetero feat to InternEvo
zhhsplendid authored Nov 5, 2024
2 parents 62ff685 + 67fe1ed commit 24f8dd9
Showing 4 changed files with 211 additions and 3 deletions.
35 changes: 34 additions & 1 deletion internlm/core/context/process_group_initializer.py
@@ -4,13 +4,46 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context

import math
import os
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps

import torch.distributed as dist

from internlm.utils.timeout import LLM_NCCL_TIMEOUT

hetero = os.environ.get("HETERO", "0") == "1"


if hetero is True:
    from internlm.accelerator import get_accelerator
    internlm_accelerator = get_accelerator()


    def wrapper_new_group(fn):

        @wraps(fn)
        def decorated(*args, **kwargs):
            no_backend = True
            be = internlm_accelerator._communication_backend_name
            if args:
                args_new = list(args)
                for idx, arg in enumerate(args_new):
                    if isinstance(arg, str) and (be in arg or "gloo" in arg):
                        no_backend = False
                args = args_new
            if kwargs:
                backend = kwargs.get('backend', None)
                if isinstance(backend, str) and (be in backend or "gloo" in backend):
                    no_backend = False
            if no_backend is True:
                kwargs['backend'] = internlm_accelerator._communication_backend_name
            return fn(*args, **kwargs)

        return decorated
    dist.new_group = wrapper_new_group(dist.new_group)


# parallel modes
class ParallelMode(Enum):
@@ -204,7 +237,7 @@ def init_dist_group(self, use_cpu: bool = False):
        for i in range(self.num_pp_group):
            ranks = [i + j * self.num_pp_group for j in range(self.pipeline_parallel_size)]
            pipe_group_size = len(ranks)
-           pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
+           pipe_group = dist.new_group(ranks, backend="gloo" if hetero is True else None, timeout=LLM_NCCL_TIMEOUT)
            if use_cpu:
                group_cpu = (
                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
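For illustration, here is a minimal standalone sketch of the backend-injection pattern above (not InternEvo code; inject_default_backend, DEFAULT_BACKEND, and fake_new_group are made-up stand-ins for the wrapper, the accelerator backend name, and dist.new_group). The decorator only fills in backend=... when the caller has not already named one, which is what lets the hetero path supply a backend for every new process group.

import functools

DEFAULT_BACKEND = "nccl"  # stand-in for internlm_accelerator._communication_backend_name


def inject_default_backend(fn):
    @functools.wraps(fn)
    def decorated(*args, **kwargs):
        candidates = [arg for arg in args if isinstance(arg, str)] + [kwargs.get("backend")]
        has_backend = any(isinstance(v, str) and (DEFAULT_BACKEND in v or "gloo" in v) for v in candidates)
        if not has_backend:
            # No backend named by the caller: fall back to the default.
            kwargs["backend"] = DEFAULT_BACKEND
        return fn(*args, **kwargs)

    return decorated


@inject_default_backend
def fake_new_group(ranks, backend=None, timeout=None):
    return f"group(ranks={ranks}, backend={backend})"


print(fake_new_group([0, 1]))                  # backend injected: nccl
print(fake_new_group([0, 1], backend="gloo"))  # caller's explicit choice is kept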
6 changes: 6 additions & 0 deletions internlm/core/scheduler/comm/p2p.py
@@ -196,6 +196,12 @@ def _communicate(
    return tensor_recv_prev, tensor_recv_next


from internlm.core.context.process_group_initializer import hetero
if hetero is True:
    from .p2p_gloo import _communicate_by_gloo
    _communicate = _communicate_by_gloo


def recv_forward(
    input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
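The rebinding above works because other functions in the same module resolve _communicate by name at call time, so they pick up the gloo variant once the flag is set. A minimal sketch of the idea (illustrative only, not InternEvo code):

import os


def _communicate(msg):
    return f"default backend path: {msg}"


def _communicate_by_gloo(msg):
    return f"gloo backend path: {msg}"


def recv_forward(msg):
    # Looks up the module-level name when called, so it follows any later rebinding.
    return _communicate(msg)


if os.environ.get("HETERO", "0") == "1":
    _communicate = _communicate_by_gloo

print(recv_forward("activation"))  # prints the gloo path when HETERO=1 is set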
169 changes: 169 additions & 0 deletions internlm/core/scheduler/comm/p2p_gloo.py
@@ -0,0 +1,169 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import List, Tuple, Union

import torch
import torch.distributed as dist

from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc

from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks

TensorShape = Union[torch.Size, List[int], Tuple[int]]
internlm_accelerator = get_accelerator()


def create_recv_buffer_with_shapes_by_gloo(recv_shapes, dtype, scatter_gather_tensors):
    from .p2p import _get_tensor_shape
    if isinstance(recv_shapes, torch.Size):
        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
        buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device="cpu", dtype=dtype)
        return buffer_recv, recv_split
    buffer_recv = []
    for recv_shape in recv_shapes:
        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
        tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device="cpu", dtype=dtype)
        buffer_recv.append(tensor_recv)
    return buffer_recv, recv_split


def process_object_to_send_by_gloo(object_send, scatter_gather_tensors):
    from .p2p import _get_tensor_shape
    if isinstance(object_send, torch.Tensor):
        send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
        if send_split:
            object_send = split_tensor_into_1d_equal_chunks(object_send)
        object_send = object_send.cpu()
        return object_send

    object_send_list = []
    for tensor_send in object_send:
        send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
        if send_split:
            object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send).cpu())
        else:
            object_send_list.append(tensor_send.cpu())
    object_send = tuple(object_send_list)

    return object_send

def _communicate_by_gloo(
    object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
    object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
    recv_prev: bool = False,
    recv_next: bool = False,
    recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
    recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
    prev_rank: int = None,
    next_rank: int = None,
    dtype: torch.dtype = None,
    scatter_gather_tensors: bool = False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
    """
    Adapted from megatron.p2p_communication.
    Communicate tensors between stages. Used as a helper method in other
    communication methods that are used in pipeline schedule.
    Takes the following arguments:
        object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank
            (no tensor sent if set to None).
        object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank
            (no tensor sent if set to None).
        recv_prev (bool): boolean for whether tensor should be received from
            previous rank.
        recv_next (bool): boolean for whether tensor should be received from
            next rank.
        recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
            from the previous stage, defaults to None.
        recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
            from the next stage, defaults to None.
        prev_rank (int): the rank of the previous pipeline stage, defaults to None.
        next_rank (int): the rank of the next pipeline stage, defaults to None.
        dtype (torch.dtype): data type of intermediate buffers, defaults to None.
        scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False.
    Returns:
        Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
    """

    from .p2p import filling_ops_queue

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None

    if recv_prev:
        assert recv_prev_shape is not None
        tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes_by_gloo(
            recv_prev_shape, dtype, scatter_gather_tensors
        )

    if recv_next:
        assert recv_next_shape is not None
        tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes_by_gloo(
            recv_next_shape, dtype, scatter_gather_tensors
        )

    if object_send_prev is not None or recv_prev:
        if prev_rank is None:
            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

    if object_send_next is not None or recv_next:
        if next_rank is None:
            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

    if object_send_prev is not None:
        object_send_prev = process_object_to_send_by_gloo(object_send_prev, scatter_gather_tensors)

    if object_send_next is not None:
        object_send_next = process_object_to_send_by_gloo(object_send_next, scatter_gather_tensors)

    ops = []
    if object_send_prev is not None:
        filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

    if tensor_recv_prev is not None:
        filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

    if tensor_recv_next is not None:
        filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)

    if object_send_next is not None:
        filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        internlm_accelerator.synchronize()

    if tensor_recv_prev is not None:
        tensor_recv_prev = tensor_recv_prev.cuda()
    if tensor_recv_next is not None:
        tensor_recv_next = tensor_recv_next.cuda()
    if recv_prev and recv_prev_split:
        if isinstance(tensor_recv_prev, torch.Tensor):
            tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
        else:
            for index in range(len(tensor_recv_prev)):
                tensor_recv_prev[index] = (
                    gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
                )
                tensor_recv_prev[index] = tensor_recv_prev[index].cuda()

    if recv_next and recv_next_split:
        if isinstance(tensor_recv_next, torch.Tensor):
            tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
        else:
            for index in range(len(tensor_recv_next)):
                tensor_recv_next[index] = (
                    gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
                )
                tensor_recv_next[index] = tensor_recv_next[index].cuda()


    return tensor_recv_prev, tensor_recv_next
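The new file stages every send/receive buffer on the CPU so that gloo can carry the traffic, then moves received tensors back to the accelerator. A rough self-contained sketch of that staging pattern (not InternEvo code; the two-process setup, port, and tensor size are made up, and it assumes a PyTorch build whose batch_isend_irecv supports the gloo backend):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"  # assumed free local port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        # Stage the payload on CPU before handing it to gloo.
        payload = torch.arange(4, dtype=torch.float32)
        ops = [dist.P2POp(dist.isend, payload, peer=1)]
    else:
        buffer_recv = torch.empty(4, dtype=torch.float32)  # CPU receive buffer
        ops = [dist.P2POp(dist.irecv, buffer_recv, peer=0)]

    for req in dist.batch_isend_irecv(ops):
        req.wait()

    if rank == 1:
        # In the hetero path this is where the tensor would move back to the accelerator.
        print("rank 1 received:", buffer_recv)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)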
4 changes: 2 additions & 2 deletions internlm/initialize/launch.py
@@ -11,7 +11,7 @@
from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import Config
from internlm.core.context import global_context as gpc
-from internlm.core.context.process_group_initializer import ParallelMode
+from internlm.core.context.process_group_initializer import ParallelMode, hetero
from internlm.utils.common import get_master_node
from internlm.utils.gputest import warmup_process_group
from internlm.utils.logger import get_logger
@@ -714,7 +714,7 @@ def initialize_distributed_env(
        master_port (str): The master port for distributed training. 8888 by default.
        seed (int, optional): Specified random seed for every process. 1024 by default.
    """
-    backend = internlm_accelerator._communication_backend_name
+    backend = "gloo" if hetero is True else internlm_accelerator._communication_backend_name

    if launcher == "torch":
        launch_from_torch(config=config, seed=seed, backend=backend)
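With this change, setting HETERO=1 forces gloo as the global communication backend at initialization. A minimal sketch of that selection (illustrative only, not InternEvo code; the host, port, and "nccl" fallback are placeholders for the accelerator's backend):

import os

import torch.distributed as dist

hetero = os.environ.get("HETERO", "0") == "1"
backend = "gloo" if hetero else "nccl"  # "nccl" stands in for internlm_accelerator._communication_backend_name

dist.init_process_group(
    backend=backend,
    init_method="tcp://127.0.0.1:29512",  # placeholder master address/port
    rank=int(os.environ.get("RANK", "0")),
    world_size=int(os.environ.get("WORLD_SIZE", "1")),
)
print(f"initialized process group with backend={dist.get_backend()}")
dist.destroy_process_group()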
