FusedAttention Bug (at backward with causal=False) #5660

Open
SanggeunParrk commented Jan 21, 2025

Describe the bug

Hi! I'm currently developing a custom fused kernel based on the flash attention code introduced in the Triton tutorials. While doing this, I found what looks like a bug in the backward pass when causal = False. Simply put, I compare Triton's forward/backward outputs with those of a naive attention implementation (I don't change the kernels at all). When causal is True everything is fine, but the gradients of q/k/v are far from the naive ones when causal is False.

Python code

import torch

# Triton fused attention code (kernels and the `attention` wrapper), used unmodified:
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html

# 1) Naive Reference Implementation of Attention
def naive_attention(q, k, v, causal: bool, sm_scale: float):
    """
    A simple PyTorch reference for checking correctness.

    Shapes:
      - q, k, v: [B, H, N_CTX, D]
    """
    B, H, N, D = q.shape
    # 1) Compute scaled dot-product: [B,H,N,D] x [B,H,D,N] -> [B,H,N,N]
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # 2) Causal mask if needed (lower-triangular)
    if causal:
        # Typical approach: fill the upper-triangle with -inf
        mask = torch.triu(torch.ones(N, N, device=q.device, dtype=torch.bool), diagonal=1)
        logits = logits.masked_fill(mask, float('-inf'))

    # 3) Softmax over last dim
    attn = torch.softmax(logits, dim=-1).to(q.dtype)

    # 4) Multiply by V -> shape [B,H,N,D]
    out = torch.matmul(attn.float(), v.float()).to(q.dtype)
    return out
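As a sanity check on the reference itself (not part of the repro), naive_attention can be cross-checked against PyTorch's built-in scaled_dot_product_attention; a quick sketch:

import torch.nn.functional as F

def sanity_check_reference(causal=False, sm_scale=1.0):
    # Cross-check the naive reference against F.scaled_dot_product_attention.
    # is_causal= and scale= are standard SDPA arguments (scale= exists since torch 2.1).
    q = torch.randn(2, 8, 128, 16, dtype=torch.float16, device='cuda')
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    ref = naive_attention(q, k, v, causal, sm_scale)
    sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
    print("naive vs SDPA max diff:", (ref - sdpa).abs().max().item())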

# 2) Test Function

def test_attention():
    """
    - Creates small test tensors for q/k/v.
    - Compares the forward & backward pass with the naive reference.
    - Prints out the differences.
    """
    # Small example shapes
    B, H, N_CTX, D = 2, 8, 128, 16
    causal = True  # flip to False to reproduce the gradient mismatch reported below
    sm_scale = 1.0

    torch.manual_seed(0)
    # Create random q/k/v/bias for the *reference* pass
    q_ref = torch.randn((B, H, N_CTX, D), dtype=torch.float16, device='cuda', requires_grad=True)
    k_ref = torch.randn_like(q_ref, requires_grad=True)
    v_ref = torch.randn_like(q_ref, requires_grad=True)

    # Create *separate copies* for the Triton-based kernel
    q_tri = q_ref.clone().detach().requires_grad_()
    k_tri = k_ref.clone().detach().requires_grad_()
    v_tri = v_ref.clone().detach().requires_grad_()

    # 2a) Forward pass: reference
    ref_out = naive_attention(q_ref, k_ref, v_ref, causal, sm_scale)

    # 2b) Forward pass: Triton kernel
    #     (`attention` is the wrapper defined in the Triton tutorial code above.)
    tri_out = attention(q_tri, k_tri, v_tri, causal, sm_scale)

    # Compare forward results
    fwd_diff = (ref_out - tri_out).abs().max().item()
    print(f"[Forward] max diff = {fwd_diff}")

    #
    # 3) Backward pass
    #
    # We'll accumulate gradient from a random 'dout'
    grad_out = torch.randn_like(ref_out)
    grad_out2 = grad_out.clone().detach()
    ref_out.backward(grad_out, retain_graph=True)
    tri_out.backward(grad_out2, retain_graph=True)

    # Compare gradients
    dq_diff    = (q_ref.grad    - q_tri.grad).abs().max().item()
    dk_diff    = (k_ref.grad    - k_tri.grad).abs().max().item()
    dv_diff    = (v_ref.grad    - v_tri.grad).abs().max().item()

    print(f"[Backward] dQ diff    = {dq_diff}")
    print(f"[Backward] dK diff    = {dk_diff}")
    print(f"[Backward] dV diff    = {dv_diff}")

if __name__ == "__main__":
    test_attention()
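The results below cover both settings; the script above hardcodes causal = True, so the flag has to be flipped by hand. A small driver along these lines (assuming test_attention is changed to accept causal as a parameter, which it doesn't above) could run both in one go:

# Sketch only: assumes test_attention is modified to take `causal` as a
# parameter, e.g. def test_attention(causal: bool): ..., instead of hardcoding it.
def run_both():
    for causal in (True, False):
        print(f"=== causal = {causal} ===")
        test_attention(causal)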

=== Results ===
(when causal = True)
[Forward] max diff = 0.001953125
[Backward] dQ diff = 0.03125
[Backward] dK diff = 0.02734375
[Backward] dV diff = 0.0234375

(when causal = False)
[Forward] max diff = 0.001953125
[Backward] dQ diff = 9.2890625
[Backward] dK diff = 11.0546875
[Backward] dV diff = 5.125
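The same comparison can be phrased as a tolerance check to make the pass/fail explicit; a sketch (the tolerances are rough fp16-level guesses, not reference values):

# Sketch only: torch.testing.assert_close raises if |tri - ref| exceeds
# atol + rtol * |ref| anywhere. atol/rtol below are rough fp16-level guesses.
def check_grads(ref, tri, name, atol=1e-1, rtol=1e-2):
    try:
        torch.testing.assert_close(tri, ref, atol=atol, rtol=rtol)
        print(f"{name}: OK (within atol={atol}, rtol={rtol})")
    except AssertionError as e:
        print(f"{name}: MISMATCH\n{e}")

# e.g. after the backward passes in test_attention():
# check_grads(q_ref.grad, q_tri.grad, "dQ")
# check_grads(k_ref.grad, k_tri.grad, "dK")
# check_grads(v_ref.grad, v_tri.grad, "dV")

Given the numbers above, the causal = True diffs (roughly 0.02 to 0.03) are in the range one would expect from fp16 rounding, while the causal = False diffs (5 to 11) are far beyond anything rounding could explain.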

Environment details

Triton: 3.1.0
GPU: A100
Python: 3.11.11
torch: 2.5.1+cu124
