Skip to content

Commit

Permalink
Remove permutation of KV cache
Browse files Browse the repository at this point in the history
This PR removes the permutation of the key and
value cache prior to kernel invocation.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Jan 21, 2025
1 parent 14b9bd2 commit f516332
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 73 deletions.
1 change: 1 addition & 0 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,7 @@ class MMA(CustomOp):
rhs: fx.Node
acc: fx.Node
mma_type: Optional["MMAType"] = None
dimensional_mapping: Optional[dict["MMAOperand", IndexSymbol]] = None

@property
def indexing_dims(self) -> list[IndexSymbol]:
Expand Down
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ def emit_mfma(
@handle_op(mma)
def handle_mma(emitter: WaveEmitter, node: fx.Node):
try:
lhs, rhs, acc, mma_type = node.args
lhs, rhs, acc, mma_type, mma_mapping = node.args
acc = cast_vector(emitter, acc)
values = [cast_vector(emitter, val) for val in [lhs, rhs]]
except ValueError as e:
Expand Down
15 changes: 9 additions & 6 deletions iree/turbine/kernel/wave/templates/paged_decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.constraints import MMAType, MMAOperand
from iree.turbine.kernel.wave.utils import (
get_mfma_load_elems_per_thread,
get_mfma_store_elems_per_thread,
Expand Down Expand Up @@ -149,7 +149,7 @@ def get_constraints(phase: Phase) -> list[tkw.Constraint]:
outputs={S: i, B: j, N: k},
)

K2_dim = k_shape[2]
K2_dim = k_shape[1]
# Returns the key for the given token index.
k_mapping = tkw.IndexMapping(
num_iterators=4,
Expand Down Expand Up @@ -177,12 +177,15 @@ def get_constraints(phase: Phase) -> list[tkw.Constraint]:
k_layout = tkl.MemoryLayout(shape=k_shape)
v_layout = tkl.MemoryLayout(shape=v_shape)
block_table_layout = tkl.MemoryLayout(shape=block_table_shape)
mma_mapping0 = {MMAOperand.M: K2, MMAOperand.N: B, MMAOperand.K: K1}
mma_mapping1 = {MMAOperand.M: N, MMAOperand.N: B, MMAOperand.K: K2}

# The kv-cache layout here is (SEQ, HEADS, HEAD_DIM).
@tkw.wave(get_constraints(Phase.PHASE_0))
def phase_0(
q: tkl.Memory[S, B, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
k: tkl.Memory[T, BH, K2, K1, ADDRESS_SPACE, tkl.f16, k_layout],
v: tkl.Memory[T, BH, N, K2, ADDRESS_SPACE, tkl.f16, v_layout],
k: tkl.Memory[T, K2, BH, K1, ADDRESS_SPACE, tkl.f16, k_layout],
v: tkl.Memory[T, K2, BH, N, ADDRESS_SPACE, tkl.f16, v_layout],
request_indices: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32],
sequence_lengths: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32],
block_table: tkl.Memory[
Expand Down Expand Up @@ -236,7 +239,7 @@ def loop(
mapping_dynamic_vals=(block_indices,),
)
imm_reg = tkl.Register[S, K2, B, tkl.f32](0.0)
inner_acc = tkw.mma(k_reg, q_reg, imm_reg)
inner_acc = tkw.mma(k_reg, q_reg, imm_reg, dimensional_mapping=mma_mapping0)
x_j = tkw.permute(inner_acc, target_shape=[S, B, K2])
m_j = tkw.max(x_j, partial_max, dim=K2)
e_delta_max = tkw.exp2(partial_max - m_j)
Expand All @@ -251,7 +254,7 @@ def loop(
mapping_dynamic_vals=(block_indices,),
)
new_acc = acc * e_delta_max
acc = tkw.mma(v_reg, imm_f16, new_acc)
acc = tkw.mma(v_reg, imm_f16, new_acc, dimensional_mapping=mma_mapping1)
return m_j, d_j, acc

res_max, res_sum, res_mm = loop
Expand Down
5 changes: 5 additions & 0 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,11 @@ def is_mma(node):
rhs_shape = custom.rhs_type.symbolic_shape
acc_shape = custom.acc_type.symbolic_shape
k = ((set(lhs_shape) & set(rhs_shape)) - set(acc_shape)).pop()
if custom.dimensional_mapping:
m = custom.dimensional_mapping[MMAOperand.M]
n = custom.dimensional_mapping[MMAOperand.N]
k = custom.dimensional_mapping[MMAOperand.K]

if custom not in mapping:
mapping[custom] = {}
mapping[custom][m] = MMAOperand.M
Expand Down
3 changes: 2 additions & 1 deletion lit_tests/kernel/wave/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,8 +1098,9 @@ def test_paged_flash_decoding():
shape = (B, BH, K1, N, K2, S, SEQ_LEN)
num_kv_splits = 8
U = num_kv_splits
# Physical shapes for kv-cache and block table
k_shape = (S, SEQ_LEN, BH, K1)
v_shape = (S, SEQ_LEN, N, K2)
v_shape = (S, SEQ_LEN, BH, N)
block_table_shape = (S, SEQ_LEN)
mfma_variant = tkw.MMAType.F32_16x16x16_F16
(
Expand Down
68 changes: 34 additions & 34 deletions lit_tests/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,21 +309,21 @@ def test_gemm():
# CHECK-SAME: (%b, 4, None, (), None)

# CHECK-NEXT: %mma_M:0_N:0_K:0
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None, None)
# CHECK-NEXT: %mma_M:0_N:0_K:1
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None, None)
# CHECK-NEXT: %mma_M:0_N:1_K:0
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None, None)
# CHECK-NEXT: %mma_M:0_N:1_K:1
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None, None)
# CHECK-NEXT: %mma_M:1_N:0_K:0
# CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None)
# CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None, None)
# CHECK-NEXT: %mma_M:1_N:0_K:1
# CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None)
# CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None, None)
# CHECK-NEXT: %mma_M:1_N:1_K:0
# CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None)
# CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None, None)
# CHECK-NEXT: %mma_M:1_N:1_K:1
# CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None)
# CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None, None)
# CHECK-NEXT: return [mma_M:0_N:0_K:1, mma_M:0_N:1_K:1, mma_M:1_N:0_K:1, mma_M:1_N:1_K:1]

# Custom format:
Expand Down Expand Up @@ -497,21 +497,21 @@ def test_batched_gemm():
# CHECK-SAME: (%b, 4, None, (), None)

# CHECK-NEXT: %mma_M:0_N:0_K:0
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None, None)
# CHECK-NEXT: %mma_M:0_N:0_K:1
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None, None)
# CHECK-NEXT: %mma_M:0_N:1_K:0
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None, None)
# CHECK-NEXT: %mma_M:0_N:1_K:1
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None, None)
# CHECK-NEXT: %mma_M:1_N:0_K:0
# CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None)
# CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None, None)
# CHECK-NEXT: %mma_M:1_N:0_K:1
# CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None)
# CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None, None)
# CHECK-NEXT: %mma_M:1_N:1_K:0
# CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None)
# CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None, None)
# CHECK-NEXT: %mma_M:1_N:1_K:1
# CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None)
# CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None, None)
# CHECK-NEXT: return [mma_M:0_N:0_K:1, mma_M:0_N:1_K:1, mma_M:1_N:0_K:1, mma_M:1_N:1_K:1]

# Custom format:
Expand Down Expand Up @@ -613,21 +613,21 @@ def test_gemm_non_direct_acc():
# CHECK: %add_M:1_N:1_K:0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_M:1_N:1_K:0, %acc_M:1_N:1_K:0), kwargs = {})
# CHECK: %mma_M:0_N:0_K:0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %add_M:0_N:0_K:0, None), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %add_M:0_N:0_K:0, None, None), kwargs = {})
# CHECK: %mma_M:0_N:0_K:1
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None, None), kwargs = {})
# CHECK: %mma_M:0_N:1_K:0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %add_M:0_N:1_K:0, None), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %add_M:0_N:1_K:0, None, None), kwargs = {})
# CHECK: %mma_M:0_N:1_K:1
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None, None), kwargs = {})
# CHECK: %mma_M:1_N:0_K:0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %add_M:1_N:0_K:0, None), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %add_M:1_N:0_K:0, None, None), kwargs = {})
# CHECK: %mma_M:1_N:0_K:1
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None, None), kwargs = {})
# CHECK: %mma_M:1_N:1_K:0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %add_M:1_N:1_K:0, None), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %add_M:1_N:1_K:0, None, None), kwargs = {})
# CHECK: %mma_M:1_N:1_K:1
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None, None), kwargs = {})


@tkw.wave_trace_only()
Expand Down Expand Up @@ -744,9 +744,9 @@ def test_gemm_reduction_expansion_only():
# CHECK-SAME: (%b, 4, None, (), None)

# CHECK-NEXT: %mma_M:0_N:0_K:0
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None, None)
# CHECK-NEXT: %mma_M:0_N:0_K:1
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None)
# CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None, None)

# CHECK-NEXT: return [mma_M:0_N:0_K:1]

Expand Down Expand Up @@ -1075,13 +1075,13 @@ def test_chained_gemm_32x32x8():
# CHECK: %read_shared_M:0_K2:0_K1:3
# CHECK-SAME: (args = (%k, 4, None, (), None)
# CHECK: %mma_M:0_K2:0_K1:0
# CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:0, %read_M:0_K2:0_K1:0, %register_M:0_K2:0_K1:0, None)
# CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:0, %read_M:0_K2:0_K1:0, %register_M:0_K2:0_K1:0, None, None)
# CHECK: %mma_M:0_K2:0_K1:1
# CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:1, %read_M:0_K2:0_K1:1, %mma_M:0_K2:0_K1:0, None)
# CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:1, %read_M:0_K2:0_K1:1, %mma_M:0_K2:0_K1:0, None, None)
# CHECK: %mma_M:0_K2:0_K1:2
# CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:2, %read_M:0_K2:0_K1:2, %mma_M:0_K2:0_K1:1, None)
# CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:2, %read_M:0_K2:0_K1:2, %mma_M:0_K2:0_K1:1, None, None)
# CHECK: %mma_M:0_K2:0_K1:3
# CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:3, %read_M:0_K2:0_K1:3, %mma_M:0_K2:0_K1:2, None)
# CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:3, %read_M:0_K2:0_K1:3, %mma_M:0_K2:0_K1:2, None, None)
# CHECK: %permute_M:0_K2:0
# CHECK-SAME: (args = (%mma_M:0_K2:0_K1:3, [B, M, K2])
# CHECK: %cast_M:0_K2:0
Expand All @@ -1104,13 +1104,13 @@ def test_chained_gemm_32x32x8():
# CHECK: %reshape_M:0_N:0_K2:3
# CHECK-SAME: (args = ([%cast_M:0_K2:0], {K2: 32, M: 32, K1: 8, B: 0})
# CHECK: %mma_M:0_N:0_K2:0
# CHECK-SAME: (args = (%reshape_M:0_N:0_K2:0, %read_shared_M:0_N:0_K2:0, %acc_M:0_N:0_K2:0, None)
# CHECK-SAME: (args = (%reshape_M:0_N:0_K2:0, %read_shared_M:0_N:0_K2:0, %acc_M:0_N:0_K2:0, None, None)
# CHECK: %mma_M:0_N:0_K2:1
# CHECK-SAME: (args = (%reshape_M:0_N:0_K2:1, %read_shared_M:0_N:0_K2:1, %mma_M:0_N:0_K2:0, None)
# CHECK-SAME: (args = (%reshape_M:0_N:0_K2:1, %read_shared_M:0_N:0_K2:1, %mma_M:0_N:0_K2:0, None, None)
# CHECK: %mma_M:0_N:0_K2:2
# CHECK-SAME: (args = (%reshape_M:0_N:0_K2:2, %read_shared_M:0_N:0_K2:2, %mma_M:0_N:0_K2:1, None)
# CHECK-SAME: (args = (%reshape_M:0_N:0_K2:2, %read_shared_M:0_N:0_K2:2, %mma_M:0_N:0_K2:1, None, None)
# CHECK: %mma_M:0_N:0_K2:3
# CHECK-SAME: (args = (%reshape_M:0_N:0_K2:3, %read_shared_M:0_N:0_K2:3, %mma_M:0_N:0_K2:2, None)
# CHECK-SAME: (args = (%reshape_M:0_N:0_K2:3, %read_shared_M:0_N:0_K2:3, %mma_M:0_N:0_K2:2, None, None)
# CHECK: return [mma_M:0_N:0_K2:3]


Expand Down
2 changes: 1 addition & 1 deletion lit_tests/kernel/wave/promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_gemm():
# CHECK-NEXT: %read_3
# CHECK-SAME: (%allocate_1, 4, None, (), [%write_1])
# CHECK-NEXT: %mma
# CHECK-SAME: (%read_2, %read_3, %acc, None)
# CHECK-SAME: (%read_2, %read_3, %acc, None, None)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit f516332

Please sign in to comment.