Fused Kernel of Conv1d and MM Gets Wrong Results. #5630

LKJacky opened this issue Jan 16, 2025 · 0 comments
LKJacky commented Jan 16, 2025

Describe the bug

I am facing a strange bug. I wrote a fused kernel that combines a conv1d and a matrix multiplication. It produces the correct result when TRITON_INTERPRET=1 is enabled, but a wrong result when it is disabled.

Another point: when I remove the matrix multiplication from the kernel, it always produces the correct results.
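
For reference, interpreter mode is toggled with the TRITON_INTERPRET environment variable, which needs to be set before the kernel is defined (the commented-out line in the script below), e.g.:

import os
os.environ["TRITON_INTERPRET"] = "1"  # run the kernel through the Triton interpreter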

import torch
from torch import Tensor
from torch.autograd import Function
import triton
from triton import language as tl
import torch.nn.functional as F
import einops
import os


# os.environ["TRITON_INTERPRET"] = "1"


def custom_conv(x, wc):
    """
    x: [B, L, D]
    wc: [K, D]
    """
    K = wc.shape[0]
    y = 0
    for i in range(K - 1, -1, -1):
        y = y + x * wc[i : i + 1, :]
        x = torch.roll(x, shifts=1, dims=1)
        x[:, 0] = 0
    return y


def naive_func(x, w, wc):
    """
    q: [B, H, L, D]
    w: [H, D, ND]
    wc: [HND,K]
    """
    x = custom_conv(x, wc)
    x = x @ w
    return x


@triton.jit
def _fused_expand_kernel(
    x_ptr,  # [B, L, D]
    w_ptr,  # [D, D]
    wc_ptr,  # [K, D]
    out_ptr,  # [B, L, D]
    tmp_ptr,  # [B, 2*BLOCK_SIZE, D]
    s_bld_b,
    s_bld_l,
    s_bld_d,
    s_dd_d1,
    s_dd_d2,
    s_kd_k,
    s_kd_d,
    s_btd_b,
    s_btd_l,
    s_btd_d,
    B: tl.constexpr,
    L: tl.constexpr,
    D: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    i_b = tl.program_id(0)
    if i_b >= B:
        return

    NL = triton.cdiv(L, BLOCK_SIZE)

    x_ptrs = tl.make_block_ptr(x_ptr + i_b * s_bld_b, [L, D], [s_bld_l, s_bld_d], [0, 0], [BLOCK_SIZE, D], [0, 1])
    w_ptrs = tl.make_block_ptr(w_ptr, [D, D], [s_dd_d1, s_dd_d2], [0, 0], [D, D], [0, 1])
    o_ptrs = tl.make_block_ptr(out_ptr + i_b * s_bld_b, [L, D], [s_bld_l, s_bld_d], [0, 0], [BLOCK_SIZE, D], [0, 1])
    wc_ptrs = tl.make_block_ptr(wc_ptr, [K, D], [s_kd_k, s_kd_d], [0, 0], [1, D], [0, 1])
    tmp_pre_ptrs = tl.make_block_ptr(
        tmp_ptr + i_b * s_btd_b, [2 * BLOCK_SIZE, D], [s_btd_l, s_btd_d], [0, 0], [BLOCK_SIZE, D], [0, 1]
    )

    tmp_cur_ptrs = tl.advance(tmp_pre_ptrs, [BLOCK_SIZE, 0])

    wc0 = tl.load(tl.advance(wc_ptrs, [3, 0]))
    wc1 = tl.load(tl.advance(wc_ptrs, [2, 0]))
    wc2 = tl.load(tl.advance(wc_ptrs, [1, 0]))
    wc3 = tl.load(wc_ptrs)
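    # the four loads above hard-code the number of conv taps to K = 4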

    w = tl.load(w_ptrs)
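    # walk this batch element in NL blocks of BLOCK_SIZE rows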
    for i in range(NL):
        x = tl.load(x_ptrs)

        # conv1d
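        # stage the current block into the second half of tmp, then read it back at
        # row offsets -1..-3 so the first rows of the block see the previous block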
        tl.debug_barrier()
        tl.store(tmp_cur_ptrs, x)
        tl.debug_barrier()

        x_c = (x * wc0).to(tl.float32)
        x_c += tl.load(tl.advance(tmp_cur_ptrs, [-1, 0])) * wc1
        x_c += tl.load(tl.advance(tmp_cur_ptrs, [-2, 0])) * wc2
        x_c += tl.load(tl.advance(tmp_cur_ptrs, [-3, 0])) * wc3
        tl.debug_barrier()
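        # keep this block around as the "previous" block for the next iteration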
        tl.store(tmp_pre_ptrs, x)
        tl.debug_barrier()

        y = x_c
        y = tl.dot(y, w, allow_tf32=False)

        # store
        tl.store(o_ptrs, y)

        # advance ptr
        x_ptrs = tl.advance(x_ptrs, [BLOCK_SIZE, 0])
        o_ptrs = tl.advance(o_ptrs, [BLOCK_SIZE, 0])


def kernel(x, w, wc):
    B, L, D = x.shape
    K = wc.shape[0]  # wc is [K, D]
    BLOCK_SIZE = 16
    out = torch.empty_like(x)
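    # per-batch scratch: the first BLOCK_SIZE rows hold the previous block, the next BLOCK_SIZE rows the current one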
    tmp = x.new_empty([B, 2 * BLOCK_SIZE, D])

    _fused_expand_kernel[(B,)](
        x.contiguous(),
        w.contiguous(),
        wc.contiguous(),
        out,
        tmp,
        x.stride(0),
        x.stride(1),
        x.stride(2),
        w.stride(0),
        w.stride(1),
        wc.stride(0),
        wc.stride(1),
        tmp.stride(0),
        tmp.stride(1),
        tmp.stride(2),
        B,
        L,
        D,
        K,
        BLOCK_SIZE,
    )
    return out


if __name__ == "__main__":
    B, L, D = 4, 4096, 64
    K = 4
    x = torch.normal(0, 1, [B, L, D]).cuda()
    w = torch.normal(0, 1, [D, D]).cuda()
    wc = torch.normal(0, 1, [K, D]).cuda()
    o1 = naive_func(x, w, wc)
    o2 = kernel(x, w, wc)
    print((o1 - o2).abs().max())
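
For reference, custom_conv should be equivalent to a depthwise causal conv1d. Below is a minimal PyTorch-only sketch for sanity-checking the eager reference path (conv1d_reference is just an illustrative name, not part of the repro; it assumes x: [B, L, D] and wc: [K, D]):

def conv1d_reference(x, wc):
    K, D = wc.shape
    weight = wc.t().unsqueeze(1)  # [D, 1, K]: one filter per channel
    y = F.conv1d(x.transpose(1, 2), weight, padding=K - 1, groups=D)
    return y[..., : x.shape[1]].transpose(1, 2)  # keep only the causal part

Comparing it against custom_conv with torch.allclose should confirm that the eager reference computes a standard causal depthwise convolution.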

Environment details

Triton 3.1.0
GPU: 3090

LKJacky added the bug label Jan 16, 2025