diff --git a/internlm/data/tokenized/dummy_dataset.py b/internlm/data/tokenized/dummy_dataset.py
index dcb6c027..1e64e00a 100644
--- a/internlm/data/tokenized/dummy_dataset.py
+++ b/internlm/data/tokenized/dummy_dataset.py
@@ -4,7 +4,7 @@
 import numpy as np
 from torch.utils.data import Dataset
 
-# from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context.parallel_context import global_context as gpc
 
 
 class RandomDataset(Dataset):
@@ -30,7 +30,7 @@ def __init__(self, num_samples=10000, max_len=1024, fixed_seqlen: bool = False)
             while len(d) < max_len:
                 r *= 2
                 d = list(range(n)) * r
-            # r = r % gpc.config.model.vocab_size
+            r = r % gpc.config.model.vocab_size
             d = [n, r] + d
             d = d[:max_len]
             data.append(d)
diff --git a/internlm/model/ops/ring_flash_attn/utils.py b/internlm/model/ops/ring_flash_attn/utils.py
index c15f074a..1875a903 100644
--- a/internlm/model/ops/ring_flash_attn/utils.py
+++ b/internlm/model/ops/ring_flash_attn/utils.py
@@ -1,11 +1,13 @@
 # Adapted from https://github.com/zhuzilin/ring-flash-attention/blob/main/ring_flash_attn/utils.py
+from functools import partial
 from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 
+
 __all__ = ["update_out_and_lse", "RingComm"]
 
 
@@ -59,9 +61,11 @@ class RingComm:
     def __init__(self, process_group: dist.ProcessGroup):
         self._process_group = process_group
         self._ops = []
+        self._funcs = []
         self.rank = dist.get_rank(self._process_group)
         self.world_size = dist.get_world_size(self._process_group)
         self._reqs = None
+        self._handles = None
 
         self.send_rank = (self.rank + 1) % self.world_size
         self.recv_rank = (self.rank - 1) % self.world_size
@@ -69,29 +73,59 @@ def __init__(self, process_group: dist.ProcessGroup):
         if process_group is not None:
             self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
             self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
+            self.global_rank = dist.get_rank()
         # print(f'rank:{self.rank},send_rank:{self.send_rank},recv_rank:{self.recv_rank}')
 
     def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if self.send_rank == self.recv_rank == self.global_rank:
+            return to_send
+
         if recv_tensor is None:
             res = torch.empty_like(to_send)
         else:
             res = recv_tensor
 
-        send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
-        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
-        self._ops.append(send_op)
-        self._ops.append(recv_op)
+        # send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
+        # recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
+        # self._ops.append(send_op)
+        # self._ops.append(recv_op)
+
+        send_func = partial(dist.isend, tensor=to_send, dst=self.send_rank, group=self._process_group)
+        recv_func = partial(dist.irecv, tensor=res, src=self.recv_rank, group=self._process_group)
+
+        if self.rank % 2 == 0:
+            self._funcs.append(send_func)
+            self._funcs.append(recv_func)
+        else:
+            self._funcs.append(recv_func)
+            self._funcs.append(send_func)
+
         return res
 
     def commit(self):
-        if self._reqs is not None:
+        # if self._reqs is not None:
+        #     raise RuntimeError("commit called twice")
+        # self._reqs = dist.batch_isend_irecv(self._ops)
+
+        if self._handles is not None:
             raise RuntimeError("commit called twice")
-        self._reqs = dist.batch_isend_irecv(self._ops)
+        self._handles = []
+
+        for _func in self._funcs:
+            _handle = _func()
+            self._handles.append(_handle)
 
     def wait(self):
-        if self._reqs is None:
+        # if self._reqs is None:
+        #     raise RuntimeError("wait called before commit")
+        # for req in self._reqs:
+        #     req.wait()
+        # self._reqs = None
+        # self._ops = []
+
+        if self._handles is None:
            raise RuntimeError("wait called before commit")
-        for req in self._reqs:
-            req.wait()
-        self._reqs = None
-        self._ops = []
+        for _handle in self._handles:
+            _handle.wait()
+        self._handles = None
+        self._funcs = []
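
For context on the second file: the patch replaces the single dist.batch_isend_irecv call with individually issued dist.isend/dist.irecv calls, deferred via functools.partial until commit(), and orders them by rank parity (even ranks post the send first, odd ranks the receive first) so neighbouring ranks pair their send and receive in opposite order and do not all block in a send at the same time. The following is a minimal standalone sketch of that same rank-parity ring exchange, not part of the patch; it assumes a torchrun launch and the gloo backend, and the names in it (ring_exchange, recv_buf, and so on) are illustrative.

# Minimal sketch (assumptions as stated above): the rank-parity ordered ring
# exchange used by the new RingComm, shown as an immediate, self-contained call.
import torch
import torch.distributed as dist


def ring_exchange(tensor: torch.Tensor) -> torch.Tensor:
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    send_rank = (rank + 1) % world_size
    recv_rank = (rank - 1) % world_size

    # Single-rank case: nothing to exchange, mirroring the early return
    # that the patch adds to send_recv().
    if send_rank == recv_rank == rank:
        return tensor

    recv_buf = torch.empty_like(tensor)

    # Even ranks post the send first, odd ranks post the receive first, so
    # adjacent ranks of opposite parity issue matching point-to-point calls
    # in complementary order.
    if rank % 2 == 0:
        send_req = dist.isend(tensor, dst=send_rank)
        recv_req = dist.irecv(recv_buf, src=recv_rank)
    else:
        recv_req = dist.irecv(recv_buf, src=recv_rank)
        send_req = dist.isend(tensor, dst=send_rank)

    send_req.wait()
    recv_req.wait()
    return recv_buf


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    local = torch.full((4,), float(dist.get_rank()))
    received = ring_exchange(local)
    print(f"rank {dist.get_rank()} received {received.tolist()}")
    dist.destroy_process_group()

Launched with something like torchrun --nproc_per_node=4 ring_demo.py, each rank should print a tensor filled with its left neighbour's rank. Unlike this sketch, the patched RingComm only records the partials in send_recv(), issues them later in commit(), and drains the returned work handles in wait().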