Skip to content

Commit

Permalink
test(ring_flash_attn): test ring flash attn use non batch p2p
Browse files Browse the repository at this point in the history
  • Loading branch information
huangting4201 committed Dec 2, 2024
1 parent 4a6b453 commit 02dcacb
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
4 changes: 2 additions & 2 deletions internlm/data/tokenized/dummy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from torch.utils.data import Dataset

# from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.context.parallel_context import global_context as gpc


class RandomDataset(Dataset):
Expand All @@ -30,7 +30,7 @@ def __init__(self, num_samples=10000, max_len=1024, fixed_seqlen: bool = False)
while len(d) < max_len:
r *= 2
d = list(range(n)) * r
# r = r % gpc.config.model.vocab_size
r = r % gpc.config.model.vocab_size
d = [n, r] + d
d = d[:max_len]
data.append(d)
Expand Down
56 changes: 45 additions & 11 deletions internlm/model/ops/ring_flash_attn/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Adapted from https://github.com/zhuzilin/ring-flash-attention/blob/main/ring_flash_attn/utils.py

from functools import partial
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F


__all__ = ["update_out_and_lse", "RingComm"]


Expand Down Expand Up @@ -59,39 +61,71 @@ class RingComm:
def __init__(self, process_group: dist.ProcessGroup):
    """Ring-style point-to-point communicator over ``process_group``.

    Each rank sends to ``(rank + 1) % world_size`` and receives from
    ``(rank - 1) % world_size``. Transfers are queued by ``send_recv()``,
    launched by ``commit()`` as individual isend/irecv calls (non-batched
    P2P), and completed by ``wait()``.

    Args:
        process_group: the group forming the ring; ``None`` means the
            default (world) group.
    """
    self._process_group = process_group
    self._ops = []  # retained for backward compatibility with the batched-P2P path
    self._funcs = []  # queued partial(isend/irecv) calls; launched in commit()
    self.rank = dist.get_rank(self._process_group)
    self.world_size = dist.get_world_size(self._process_group)
    self._reqs = None  # retained for backward compatibility with the batched-P2P path
    self._handles = None  # in-flight work handles; None means "not committed"

    self.send_rank = (self.rank + 1) % self.world_size
    self.recv_rank = (self.rank - 1) % self.world_size

    # Always record the global rank: send_recv() compares it against
    # send_rank/recv_rank unconditionally. The original only set this
    # inside the branch below, so process_group=None (default group)
    # raised AttributeError on the first send_recv() call.
    self.global_rank = dist.get_rank()

    if process_group is not None:
        # isend/irecv address peers by global rank, so translate the
        # group-local neighbor ranks to global ranks.
        self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
        self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.send_rank == self.recv_rank == self.global_rank:
return to_send

if recv_tensor is None:
res = torch.empty_like(to_send)
else:
res = recv_tensor

send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
self._ops.append(send_op)
self._ops.append(recv_op)
# send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
# recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
# self._ops.append(send_op)
# self._ops.append(recv_op)

send_func = partial(dist.isend, tensor=to_send, dst=self.send_rank, group=self._process_group)
recv_func = partial(dist.irecv, tensor=res, src=self.recv_rank, group=self._process_group)

if self.rank % 2 == 0:
self._funcs.append(send_func)
self._funcs.append(recv_func)
else:
self._funcs.append(recv_func)
self._funcs.append(send_func)

return res

def commit(self):
if self._reqs is not None:
# if self._reqs is not None:
# raise RuntimeError("commit called twice")
# self._reqs = dist.batch_isend_irecv(self._ops)

if self._handles is not None:
raise RuntimeError("commit called twice")
self._reqs = dist.batch_isend_irecv(self._ops)
self._handles = []

for _func in self._funcs:
_handle = _func()
self._handles.append(_handle)

def wait(self):
if self._reqs is None:
# if self._reqs is None:
# raise RuntimeError("wait called before commit")
# for req in self._reqs:
# req.wait()
# self._reqs = None
# self._ops = []

if self._handles is None:
raise RuntimeError("wait called before commit")
for req in self._reqs:
req.wait()
self._reqs = None
self._ops = []
for _handle in self._handles:
_handle.wait()
self._handles = None
self._funcs = []

0 comments on commit 02dcacb

Please sign in to comment.