fix README
chengzeyi committed Nov 6, 2024
1 parent a098e23 commit f4814a8
Showing 2 changed files with 20 additions and 21 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -1,6 +1,7 @@
# ParaAttention

-Context parallel attention that works with torch.compile
+Context parallel attention that works with torch.compile,
+supporting both [**Ulysses Style**](https://arxiv.org/abs/2309.14509) and [**Ring Style**](https://arxiv.org/abs/2310.01889) parallelism.

This aims to provide:

@@ -31,7 +32,7 @@ pre-commit run --all-files

# Usage

-## Run Unified Attention (Ulysses Style and Ring Style) with `torch.compile`
+## Run Unified Attention (Hybrid Ulysses Style and Ring Style) with `torch.compile`

``` python
import torch
...
```
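
The README example above is cut off in this diff view. Purely as a sketch of how `ring_attn_func` from `src/para_attn/para_attn_interface.py` (the second file in this commit) might be driven, with the import path, tensor shapes, keyword defaults, and the `torch.compile` wrapping all assumed for illustration rather than taken from the README:

``` python
# Hypothetical usage sketch, launched with e.g.:
#   torchrun --nproc_per_node=2 example.py
# `ring_attn_func` and its parameter names come from para_attn_interface.py in this
# commit; everything else (import path, shapes, dtype, compile wrapping) is assumed.
import torch
import torch.distributed as dist

from para_attn.para_attn_interface import ring_attn_func  # assumed import path

def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Each rank holds its own shard of the sequence dimension: (batch, heads, seq, dim).
    query = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
    key = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
    value = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

    # With mesh left as None, the default process group is used; on a single rank the
    # function simply falls back to scaled_dot_product_attention (see the change below).
    attn = torch.compile(ring_attn_func)
    out = attn(query, key, value, is_causal=False)
    print(rank, out.shape)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```
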
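For background on the two styles named at the top of the README: Ulysses-style parallelism uses all-to-all exchanges so each worker attends over the full sequence for a subset of heads, while ring-style parallelism streams K/V blocks between workers and merges partial results. The single-process sketch below is not ParaAttention's implementation; it only illustrates the log-sum-exp merge that makes the ring-style decomposition exact:

``` python
# Illustrative only: full attention can be rebuilt from per-chunk partial outputs plus
# their log-sum-exp statistics, which is what lets ring-style parallelism pass K/V
# blocks around a ring of ranks. All names and shapes here are made up.
import torch
import torch.nn.functional as F

def attention_with_lse(q, k, v, scale):
    # Partial attention over one K/V chunk, returning the output and the
    # log-sum-exp of the scores so results can be merged later.
    scores = q @ k.transpose(-2, -1) * scale
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    return torch.softmax(scores, dim=-1) @ v, lse

def ring_style_reference(q, kv_chunks, scale):
    out, lse = None, None
    for k, v in kv_chunks:
        o_i, lse_i = attention_with_lse(q, k, v, scale)
        if out is None:
            out, lse = o_i, lse_i
        else:
            new_lse = torch.logaddexp(lse, lse_i)
            # Rescale both partial results into the combined softmax normalization.
            out = out * (lse - new_lse).exp() + o_i * (lse_i - new_lse).exp()
            lse = new_lse
    return out

q = torch.randn(2, 8, 128, 64)  # (batch, heads, seq, dim)
k = torch.randn(2, 8, 256, 64)
v = torch.randn(2, 8, 256, 64)
scale = 64 ** -0.5
chunks = list(zip(k.chunk(4, dim=2), v.chunk(4, dim=2)))

ref = F.scaled_dot_product_attention(q, k, v, scale=scale)
ring = ring_style_reference(q, chunks, scale)
print((ref - ring).abs().max())  # should be ~1e-6
```
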
36 changes: 17 additions & 19 deletions src/para_attn/para_attn_interface.py
@@ -131,25 +131,6 @@ def forward(
        scale,
        mesh,
    ):
-        if mesh is None:
-            mesh = c10d._get_default_group()
-        if isinstance(mesh, dist.ProcessGroup):
-            pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
-        else:
-            pg = mesh.get_group()
-        assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
-        world_size = dist.get_world_size(pg)
-        if world_size <= 1:
-            return F.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                attn_mask=attn_mask,
-                dropout_p=dropout_p,
-                is_causal=is_causal,
-                scale=scale,
-            )
-
        out, lse = _templated_ring_attention(
            mesh,
            para_attn_ops.attention_forward_with_lse,
@@ -183,6 +164,23 @@ def ring_attn_func(

    if mesh is None:
        mesh = c10d._get_default_group()
+    if isinstance(mesh, dist.ProcessGroup):
+        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
+    else:
+        pg = mesh.get_group()
+    assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
+    world_size = dist.get_world_size(pg)
+    if world_size <= 1:
+        return F.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+        )
+
    return RingAttnFunc.apply(
        query,
        key,
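
This commit moves the process-group inspection and the `world_size <= 1` early return out of `RingAttnFunc.forward` and into `ring_attn_func`, so a single-rank call returns plain `F.scaled_dot_product_attention` before the autograd function is ever reached. A minimal single-process sanity check of that behaviour is sketched below; the `gloo` backend, the TCP init method, the CPU tensors, and the keyword defaults of `ring_attn_func` are assumptions, not part of this commit:

``` python
# Hypothetical check that on a world of size 1, ring_attn_func reduces to plain SDPA.
# Assumes the package imports as para_attn.para_attn_interface and that a default
# process group can be created with gloo on CPU.
import torch
import torch.distributed as dist
import torch.nn.functional as F

from para_attn.para_attn_interface import ring_attn_func  # assumed import path

def main():
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=0,
        world_size=1,
    )

    q = torch.randn(1, 4, 64, 32)
    k = torch.randn(1, 4, 64, 32)
    v = torch.randn(1, 4, 64, 32)

    # world_size == 1, so the early return added above should be taken and the
    # result should match scaled_dot_product_attention.
    expected = F.scaled_dot_product_attention(q, k, v)
    actual = ring_attn_func(q, k, v)
    print(torch.allclose(expected, actual))

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```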
