Describe the bug
Hi! I'm currently developing a custom fused kernel based on the Flash Attention implementation from the Triton tutorials. While doing this, I found what might be a bug in the backward pass when causal = False. Simply put, I compared the Triton kernel's forward/backward outputs against a naive attention implementation (without modifying the kernel at all). When causal is True everything matches closely, but when causal is False the gradients of q, k, and v are far from the naive ones.
Python code
'''
# Triton fused attention code, unmodified from the official tutorial:
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
'''
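For reference, here is a minimal sketch of the comparison script. It assumes the tutorial code above is saved locally as fused_attention.py exposing its attention function (the module name and the naive_attention helper are my own, added just for this repro), and that the attention call signature matches the Triton 3.1.0 tutorial:

'''
import torch

# Hypothetical import: assumes the tutorial code linked above is saved
# locally as fused_attention.py, exposing its `attention` function.
from fused_attention import attention

def naive_attention(q, k, v, causal, sm_scale):
    # Plain-PyTorch reference implementation of (masked) attention.
    scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    if causal:
        n = q.shape[-2]
        mask = torch.tril(torch.ones(n, n, device=q.device, dtype=torch.bool))
        scores = scores.masked_fill(~mask, float("-inf"))
    p = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(p, v)

torch.manual_seed(0)
B, H, N, D = 2, 4, 1024, 64
causal = False  # flip to True: gradients then match the naive reference
sm_scale = D ** -0.5

q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))
do = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)

# Triton path: forward, then backward with a fixed upstream gradient.
out_tri = attention(q, k, v, causal, sm_scale)
out_tri.backward(do)
dq_tri, dk_tri, dv_tri = q.grad.clone(), k.grad.clone(), v.grad.clone()
q.grad = k.grad = v.grad = None

# Naive path: same inputs, same upstream gradient.
out_ref = naive_attention(q, k, v, causal, sm_scale)
out_ref.backward(do)

print("[Forward] max diff =", (out_tri - out_ref).abs().max().item())
print("[Backward] dQ diff =", (dq_tri - q.grad).abs().max().item())
print("[Backward] dK diff =", (dk_tri - k.grad).abs().max().item())
print("[Backward] dV diff =", (dv_tri - v.grad).abs().max().item())
'''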
=== Results ===
(when causal = True)
[Forward] max diff = 0.001953125
[Backward] dQ diff = 0.03125
[Backward] dK diff = 0.02734375
[Backward] dV diff = 0.0234375
(when causal = False)
[Forward] max diff = 0.001953125
[Backward] dQ diff = 9.2890625
[Backward] dK diff = 11.0546875
[Backward] dV diff = 5.125
Environment details
Triton: 3.1.0
GPU: A100
Python: 3.11.11
torch: 2.5.1+cu124