Skip to content

Commit

Permalink
fix assert_close
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Nov 5, 2024
1 parent 761968c commit 60045d1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ with torch.no_grad(), torch.cuda.device(rank):
is_causal=is_causal,
)

torch.testing.assert_close(out_slice, out_slice_ref)
torch.testing.assert_close(out_slice, out_slice_ref, rtol=1e-5, atol=1e-3 * world_size)

dist.destroy_process_group()
```
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _test_attn_func(self, dtype, device, B, H, S_Q, S_KV, D, is_causal, compile)
is_causal=is_causal,
).chunk(self.world_size, dim=-2)[self.rank]

torch.testing.assert_close(out_slice, out_slice_ref)
torch.testing.assert_close(out_slice, out_slice_ref, rtol=1e-5, atol=1e-3 * self.world_size)

def _test_attn_mode(self, dtype, device, B, H, S_Q, S_KV, D, is_causal, compile):
if is_causal and S_Q != S_KV:
Expand Down Expand Up @@ -115,7 +115,7 @@ def func(*args, **kwargs):
is_causal=is_causal,
).chunk(self.world_size, dim=-2)[self.rank]

torch.testing.assert_close(out_slice, out_slice_ref)
torch.testing.assert_close(out_slice, out_slice_ref, rtol=1e-5, atol=1e-3 * self.world_size)


class RingAttnTest(ParallelAttnTest):
Expand Down

0 comments on commit 60045d1

Please sign in to comment.