diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 6e3a41d1b..a3cdd8f16 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -941,6 +941,7 @@ class MMA(CustomOp): rhs: fx.Node acc: fx.Node mma_type: Optional["MMAType"] = None + dimensional_mapping: Optional[dict["MMAOperand", IndexSymbol]] = None @property def indexing_dims(self) -> list[IndexSymbol]: diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 468cf0d19..2e8519f0d 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -1004,7 +1004,7 @@ def emit_mfma( @handle_op(mma) def handle_mma(emitter: WaveEmitter, node: fx.Node): try: - lhs, rhs, acc, mma_type = node.args + lhs, rhs, acc, mma_type, mma_mapping = node.args acc = cast_vector(emitter, acc) values = [cast_vector(emitter, val) for val in [lhs, rhs]] except ValueError as e: diff --git a/iree/turbine/kernel/wave/templates/paged_decode_attention.py b/iree/turbine/kernel/wave/templates/paged_decode_attention.py index c1f01cfc5..6427b2bcd 100644 --- a/iree/turbine/kernel/wave/templates/paged_decode_attention.py +++ b/iree/turbine/kernel/wave/templates/paged_decode_attention.py @@ -7,7 +7,7 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.lang.global_symbols import * -from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.constraints import MMAType, MMAOperand from iree.turbine.kernel.wave.utils import ( get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, @@ -149,7 +149,7 @@ def get_constraints(phase: Phase) -> list[tkw.Constraint]: outputs={S: i, B: j, N: k}, ) - K2_dim = k_shape[2] + K2_dim = k_shape[1] # Returns the key for the given token index. k_mapping = tkw.IndexMapping( num_iterators=4, @@ -177,12 +177,15 @@ def get_constraints(phase: Phase) -> list[tkw.Constraint]: k_layout = tkl.MemoryLayout(shape=k_shape) v_layout = tkl.MemoryLayout(shape=v_shape) block_table_layout = tkl.MemoryLayout(shape=block_table_shape) + mma_mapping0 = {MMAOperand.M: K2, MMAOperand.N: B, MMAOperand.K: K1} + mma_mapping1 = {MMAOperand.M: N, MMAOperand.N: B, MMAOperand.K: K2} + # The kv-cache layout here is (SEQ, HEADS, HEAD_DIM). @tkw.wave(get_constraints(Phase.PHASE_0)) def phase_0( q: tkl.Memory[S, B, K1, GLOBAL_ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[T, BH, K2, K1, ADDRESS_SPACE, tkl.f16, k_layout], - v: tkl.Memory[T, BH, N, K2, ADDRESS_SPACE, tkl.f16, v_layout], + k: tkl.Memory[T, K2, BH, K1, ADDRESS_SPACE, tkl.f16, k_layout], + v: tkl.Memory[T, K2, BH, N, ADDRESS_SPACE, tkl.f16, v_layout], request_indices: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32], sequence_lengths: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32], block_table: tkl.Memory[ @@ -236,7 +239,7 @@ def loop( mapping_dynamic_vals=(block_indices,), ) imm_reg = tkl.Register[S, K2, B, tkl.f32](0.0) - inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + inner_acc = tkw.mma(k_reg, q_reg, imm_reg, dimensional_mapping=mma_mapping0) x_j = tkw.permute(inner_acc, target_shape=[S, B, K2]) m_j = tkw.max(x_j, partial_max, dim=K2) e_delta_max = tkw.exp2(partial_max - m_j) @@ -251,7 +254,7 @@ def loop( mapping_dynamic_vals=(block_indices,), ) new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) + acc = tkw.mma(v_reg, imm_f16, new_acc, dimensional_mapping=mma_mapping1) return m_j, d_j, acc res_max, res_sum, res_mm = loop diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index c8cb890fb..676f39a9a 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -314,6 +314,11 @@ def is_mma(node): rhs_shape = custom.rhs_type.symbolic_shape acc_shape = custom.acc_type.symbolic_shape k = ((set(lhs_shape) & set(rhs_shape)) - set(acc_shape)).pop() + if custom.dimensional_mapping: + m = custom.dimensional_mapping[MMAOperand.M] + n = custom.dimensional_mapping[MMAOperand.N] + k = custom.dimensional_mapping[MMAOperand.K] + if custom not in mapping: mapping[custom] = {} mapping[custom][m] = MMAOperand.M diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index 39ccd5ca2..27034e8e4 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -1098,8 +1098,9 @@ def test_paged_flash_decoding(): shape = (B, BH, K1, N, K2, S, SEQ_LEN) num_kv_splits = 8 U = num_kv_splits + # Physical shapes for kv-cache and block table k_shape = (S, SEQ_LEN, BH, K1) - v_shape = (S, SEQ_LEN, N, K2) + v_shape = (S, SEQ_LEN, BH, N) block_table_shape = (S, SEQ_LEN) mfma_variant = tkw.MMAType.F32_16x16x16_F16 ( diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index e0aa960fe..dfecc6621 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -309,21 +309,21 @@ def test_gemm(): # CHECK-SAME: (%b, 4, None, (), None) # CHECK-NEXT: %mma_M:0_N:0_K:0 - # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None, None) # CHECK-NEXT: %mma_M:0_N:0_K:1 - # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None, None) # CHECK-NEXT: %mma_M:0_N:1_K:0 - # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None, None) # CHECK-NEXT: %mma_M:0_N:1_K:1 - # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None, None) # CHECK-NEXT: %mma_M:1_N:0_K:0 - # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None) + # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None, None) # CHECK-NEXT: %mma_M:1_N:0_K:1 - # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None) + # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None, None) # CHECK-NEXT: %mma_M:1_N:1_K:0 - # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None) + # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None, None) # CHECK-NEXT: %mma_M:1_N:1_K:1 - # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None) + # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None, None) # CHECK-NEXT: return [mma_M:0_N:0_K:1, mma_M:0_N:1_K:1, mma_M:1_N:0_K:1, mma_M:1_N:1_K:1] # Custom format: @@ -497,21 +497,21 @@ def test_batched_gemm(): # CHECK-SAME: (%b, 4, None, (), None) # CHECK-NEXT: %mma_M:0_N:0_K:0 - # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None, None) # CHECK-NEXT: %mma_M:0_N:0_K:1 - # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None, None) # CHECK-NEXT: %mma_M:0_N:1_K:0 - # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None, None) # CHECK-NEXT: %mma_M:0_N:1_K:1 - # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None, None) # CHECK-NEXT: %mma_M:1_N:0_K:0 - # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None) + # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None, None) # CHECK-NEXT: %mma_M:1_N:0_K:1 - # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None) + # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None, None) # CHECK-NEXT: %mma_M:1_N:1_K:0 - # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None) + # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None, None) # CHECK-NEXT: %mma_M:1_N:1_K:1 - # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None) + # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None, None) # CHECK-NEXT: return [mma_M:0_N:0_K:1, mma_M:0_N:1_K:1, mma_M:1_N:0_K:1, mma_M:1_N:1_K:1] # Custom format: @@ -613,21 +613,21 @@ def test_gemm_non_direct_acc(): # CHECK: %add_M:1_N:1_K:0 # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_M:1_N:1_K:0, %acc_M:1_N:1_K:0), kwargs = {}) # CHECK: %mma_M:0_N:0_K:0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %add_M:0_N:0_K:0, None), kwargs = {}) + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %add_M:0_N:0_K:0, None, None), kwargs = {}) # CHECK: %mma_M:0_N:0_K:1 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None), kwargs = {}) + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None, None), kwargs = {}) # CHECK: %mma_M:0_N:1_K:0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %add_M:0_N:1_K:0, None), kwargs = {}) + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %add_M:0_N:1_K:0, None, None), kwargs = {}) # CHECK: %mma_M:0_N:1_K:1 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None), kwargs = {}) + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None, None), kwargs = {}) # CHECK: %mma_M:1_N:0_K:0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %add_M:1_N:0_K:0, None), kwargs = {}) + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %add_M:1_N:0_K:0, None, None), kwargs = {}) # CHECK: %mma_M:1_N:0_K:1 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None), kwargs = {}) + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None, None), kwargs = {}) # CHECK: %mma_M:1_N:1_K:0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %add_M:1_N:1_K:0, None), kwargs = {}) + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %add_M:1_N:1_K:0, None, None), kwargs = {}) # CHECK: %mma_M:1_N:1_K:1 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None), kwargs = {}) + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None, None), kwargs = {}) @tkw.wave_trace_only() @@ -744,9 +744,9 @@ def test_gemm_reduction_expansion_only(): # CHECK-SAME: (%b, 4, None, (), None) # CHECK-NEXT: %mma_M:0_N:0_K:0 - # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None, None) # CHECK-NEXT: %mma_M:0_N:0_K:1 - # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None) + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None, None) # CHECK-NEXT: return [mma_M:0_N:0_K:1] @@ -1075,13 +1075,13 @@ def test_chained_gemm_32x32x8(): # CHECK: %read_shared_M:0_K2:0_K1:3 # CHECK-SAME: (args = (%k, 4, None, (), None) # CHECK: %mma_M:0_K2:0_K1:0 - # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:0, %read_M:0_K2:0_K1:0, %register_M:0_K2:0_K1:0, None) + # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:0, %read_M:0_K2:0_K1:0, %register_M:0_K2:0_K1:0, None, None) # CHECK: %mma_M:0_K2:0_K1:1 - # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:1, %read_M:0_K2:0_K1:1, %mma_M:0_K2:0_K1:0, None) + # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:1, %read_M:0_K2:0_K1:1, %mma_M:0_K2:0_K1:0, None, None) # CHECK: %mma_M:0_K2:0_K1:2 - # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:2, %read_M:0_K2:0_K1:2, %mma_M:0_K2:0_K1:1, None) + # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:2, %read_M:0_K2:0_K1:2, %mma_M:0_K2:0_K1:1, None, None) # CHECK: %mma_M:0_K2:0_K1:3 - # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:3, %read_M:0_K2:0_K1:3, %mma_M:0_K2:0_K1:2, None) + # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:3, %read_M:0_K2:0_K1:3, %mma_M:0_K2:0_K1:2, None, None) # CHECK: %permute_M:0_K2:0 # CHECK-SAME: (args = (%mma_M:0_K2:0_K1:3, [B, M, K2]) # CHECK: %cast_M:0_K2:0 @@ -1104,13 +1104,13 @@ def test_chained_gemm_32x32x8(): # CHECK: %reshape_M:0_N:0_K2:3 # CHECK-SAME: (args = ([%cast_M:0_K2:0], {K2: 32, M: 32, K1: 8, B: 0}) # CHECK: %mma_M:0_N:0_K2:0 - # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:0, %read_shared_M:0_N:0_K2:0, %acc_M:0_N:0_K2:0, None) + # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:0, %read_shared_M:0_N:0_K2:0, %acc_M:0_N:0_K2:0, None, None) # CHECK: %mma_M:0_N:0_K2:1 - # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:1, %read_shared_M:0_N:0_K2:1, %mma_M:0_N:0_K2:0, None) + # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:1, %read_shared_M:0_N:0_K2:1, %mma_M:0_N:0_K2:0, None, None) # CHECK: %mma_M:0_N:0_K2:2 - # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:2, %read_shared_M:0_N:0_K2:2, %mma_M:0_N:0_K2:1, None) + # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:2, %read_shared_M:0_N:0_K2:2, %mma_M:0_N:0_K2:1, None, None) # CHECK: %mma_M:0_N:0_K2:3 - # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:3, %read_shared_M:0_N:0_K2:3, %mma_M:0_N:0_K2:2, None) + # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:3, %read_shared_M:0_N:0_K2:3, %mma_M:0_N:0_K2:2, None, None) # CHECK: return [mma_M:0_N:0_K2:3] diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py index 7784885c8..1eccec311 100644 --- a/lit_tests/kernel/wave/promotion.py +++ b/lit_tests/kernel/wave/promotion.py @@ -210,7 +210,7 @@ def test_gemm(): # CHECK-NEXT: %read_3 # CHECK-SAME: (%allocate_1, 4, None, (), [%write_1]) # CHECK-NEXT: %mma - # CHECK-SAME: (%read_2, %read_3, %acc, None) + # CHECK-SAME: (%read_2, %read_3, %acc, None, None) if __name__ == "__main__": diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index 7b4405eb8..d4a10b86a 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -124,7 +124,7 @@ def test_gemm_pipelined(): # CHECK-NEXT: %rotating_reg_5 # CHECK-NEXT: %rotating_reg_6 # CHECK-NEXT: %mma_M_1_N_1_K_1 - # CHECK-SAME: (%rotating_reg_1, %rotating_reg_4, %rotating_reg_6, None) + # CHECK-SAME: (%rotating_reg_1, %rotating_reg_4, %rotating_reg_6, None, None) # CHECK-NEXT: %read_shared_M_0_N_0_K_0 # CHECK-NEXT: %read_shared_M_0_N_0_K_1 # CHECK-NEXT: %read_20 @@ -134,29 +134,29 @@ def test_gemm_pipelined(): # CHECK-NEXT: %read_shared_M_1_N_0_K_0 # CHECK-NEXT: %read_shared_M_1_N_0_K_1 # CHECK-NEXT: %mma_M_0_N_0_K_0 - # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %read_shared_M_0_N_0_K_1, %acc_m_0_n_0_k_0, None) + # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %read_shared_M_0_N_0_K_1, %acc_m_0_n_0_k_0, None, None) # CHECK-NEXT: %mma_M_0_N_1_K_0 - # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %rotating_reg_3, %acc_m_0_n_1_k_0, None) + # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %rotating_reg_3, %acc_m_0_n_1_k_0, None, None) # CHECK-NEXT: %scheduling_group_barrier # CHECK-SAME: ({Operation.READ_SHARED: 2, Operation.MMA: 2}, 0) # CHECK-NEXT: %mma_M_0_N_0_K_1 - # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_M_0_N_0_K_0, None) + # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_M_0_N_0_K_0, None, None) # CHECK-NEXT: %mma_M_1_N_0_K_0 - # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %read_shared_M_0_N_0_K_1, %acc_m_1_n_0_k_0, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %read_shared_M_0_N_0_K_1, %acc_m_1_n_0_k_0, None, None) # CHECK-NEXT: %write_10 # CHECK-NEXT: %write_11 # CHECK-NEXT: %scheduling_group_barrier # CHECK-SAME: ({Operation.MMA: 2, Operation.WRITE_SHARED: 2}, 0) # CHECK-NEXT: %mma_M_0_N_1_K_1 - # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_M_0_N_1_K_0, None) + # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_M_0_N_1_K_0, None, None) # CHECK-NEXT: %mma_M_1_N_0_K_1 - # CHECK-SAME: (%read_shared_M_1_N_0_K_1, %rotating_reg_2, %mma_M_1_N_0_K_0, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_1, %rotating_reg_2, %mma_M_1_N_0_K_0, None, None) # CHECK-NEXT: %read_shared_M_0_N_1_K_0 # CHECK-NEXT: %read_shared_M_0_N_1_K_1 # CHECK-NEXT: %scheduling_group_barrier # CHECK-SAME: ({Operation.MMA: 2, Operation.READ_SHARED: 2}, 0) # CHECK-NEXT: %mma_M_1_N_1_K_0 - # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %rotating_reg_3, %mma_M_1_N_1_K_1, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %rotating_reg_3, %mma_M_1_N_1_K_1, None, None) # CHECK-NEXT: %read_shared_M_0_N_0_K_2 # CHECK-NEXT: %read_shared_M_0_N_0_K_3 # CHECK-NEXT: %scheduling_group_barrier @@ -188,23 +188,23 @@ def test_gemm_pipelined(): # CHECK-NEXT: %read_shared_M_1_N_0_K_0 # CHECK-NEXT: %read_shared_M_1_N_0_K_1 # CHECK-NEXT: %mma_M_0_N_0_K_0 - # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %read_shared_M_0_N_0_K_3, %register_M:0_N:0_K:0, None) + # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %read_shared_M_0_N_0_K_3, %register_M:0_N:0_K:0, None, None) # CHECK-NEXT: %mma_M_0_N_1_K_0 - # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %read_shared_M_0_N_1_K_0, %register_M:0_N:1_K:0, None) + # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %read_shared_M_0_N_1_K_0, %register_M:0_N:1_K:0, None, None) # CHECK-NEXT: %mma_M_0_N_0_K_1 - # CHECK-SAME: (%read_shared_M_0_N_0_K_1, %read_shared_M_0_N_0_K_2, %mma_M_0_N_0_K_0, None) + # CHECK-SAME: (%read_shared_M_0_N_0_K_1, %read_shared_M_0_N_0_K_2, %mma_M_0_N_0_K_0, None, None) # CHECK-NEXT: %mma_M_1_N_0_K_0 - # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %read_shared_M_0_N_0_K_3, %register_M:1_N:0_K:0, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %read_shared_M_0_N_0_K_3, %register_M:1_N:0_K:0, None, None) # CHECK-NEXT: %write_12 # CHECK-NEXT: %write_13 # CHECK-NEXT: %mma_M_0_N_1_K_1 - # CHECK-SAME: (%read_shared_M_0_N_0_K_1, %read_shared_M_0_N_1_K_1, %mma_M_0_N_1_K_0, None) + # CHECK-SAME: (%read_shared_M_0_N_0_K_1, %read_shared_M_0_N_1_K_1, %mma_M_0_N_1_K_0, None, None) # CHECK-NEXT: %mma_M_1_N_0_K_1 - # CHECK-SAME: (%read_shared_M_1_N_0_K_1, %read_shared_M_0_N_0_K_2, %mma_M_1_N_0_K_0, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_1, %read_shared_M_0_N_0_K_2, %mma_M_1_N_0_K_0, None, None) # CHECK-NEXT: %read_shared_M_0_N_1_K_2 # CHECK-NEXT: %read_shared_M_0_N_1_K_3 # CHECK-NEXT: %mma_M_1_N_1_K_0 - # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %read_shared_M_0_N_1_K_0, %register_M:1_N:1_K:0, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %read_shared_M_0_N_1_K_0, %register_M:1_N:1_K:0, None, None) # CHECK-NEXT: %read_shared_M_0_N_0_K_4 # CHECK-NEXT: %read_shared_M_0_N_0_K_5 # CHECK-NEXT: %reduction_1 @@ -220,27 +220,27 @@ def test_gemm_pipelined(): # CHECK-NEXT: %get_result_14 # CHECK-NEXT: %get_result_15 # CHECK-NEXT: %mma_M_1_N_1_K_1 - # CHECK-SAME: (%get_result_10, %get_result_13, %get_result_15, None) + # CHECK-SAME: (%get_result_10, %get_result_13, %get_result_15, None, None) # CHECK-NEXT: %read_shared_M_0_N_0_K_6 # CHECK-NEXT: %read_shared_M_0_N_0_K_7 # CHECK-NEXT: %read_shared_M_1_N_0_K_2 # CHECK-NEXT: %read_shared_M_1_N_0_K_3 # CHECK-NEXT: %mma_M_0_N_0_K_2 - # CHECK-SAME: (%read_shared_M_0_N_0_K_6, %read_shared_M_0_N_0_K_7, %getresult_M:0_N:0_K:0, None) + # CHECK-SAME: (%read_shared_M_0_N_0_K_6, %read_shared_M_0_N_0_K_7, %getresult_M:0_N:0_K:0, None, None) # CHECK-NEXT: %mma_M_0_N_1_K_2 - # CHECK-SAME: (%read_shared_M_0_N_0_K_6, %get_result_12, %getresult_M:0_N:1_K:0, None) + # CHECK-SAME: (%read_shared_M_0_N_0_K_6, %get_result_12, %getresult_M:0_N:1_K:0, None, None) # CHECK-NEXT: %mma_M_0_N_0_K_3 - # CHECK-SAME: (%get_result_9, %get_result_11, %mma_M_0_N_0_K_2, None) + # CHECK-SAME: (%get_result_9, %get_result_11, %mma_M_0_N_0_K_2, None, None) # CHECK-NEXT: %mma_M_1_N_0_K_2 - # CHECK-SAME: (%read_shared_M_1_N_0_K_2, %read_shared_M_0_N_0_K_7, %getresult_M:1_N:0_K:0, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_2, %read_shared_M_0_N_0_K_7, %getresult_M:1_N:0_K:0, None, None) # CHECK-NEXT: %mma_M_0_N_1_K_3 - # CHECK-SAME: (%get_result_9, %get_result_14, %mma_M_0_N_1_K_2, None) + # CHECK-SAME: (%get_result_9, %get_result_14, %mma_M_0_N_1_K_2, None, None) # CHECK-NEXT: %mma_M_1_N_0_K_3 - # CHECK-SAME: (%read_shared_M_1_N_0_K_3, %get_result_11, %mma_M_1_N_0_K_2, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_3, %get_result_11, %mma_M_1_N_0_K_2, None, None) # CHECK-NEXT: %mma_M_1_N_1_K_2 - # CHECK-SAME: (%read_shared_M_1_N_0_K_2, %get_result_12, %mma_M_1_N_1_K_1, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_2, %get_result_12, %mma_M_1_N_1_K_1, None, None) # CHECK-NEXT: %mma_M_1_N_1_K_3 - # CHECK-SAME: (%read_shared_M_1_N_0_K_3, %get_result_14, %mma_M_1_N_1_K_2, None) + # CHECK-SAME: (%read_shared_M_1_N_0_K_3, %get_result_14, %mma_M_1_N_1_K_2, None, None) # CHECK-NEXT: %write_M:0_N:0_K:0 # CHECK-NEXT: %write_M:0_N:1_K:0 # CHECK-NEXT: %write_M:1_N:0_K:0 diff --git a/tests/kernel/wave/attention/paged_attention_test.py b/tests/kernel/wave/attention/paged_attention_test.py index 69cb6b1ee..90a8ad605 100644 --- a/tests/kernel/wave/attention/paged_attention_test.py +++ b/tests/kernel/wave/attention/paged_attention_test.py @@ -204,12 +204,8 @@ def testPagedFlashDecoding( num_kv_heads = key_cache.shape[2] head_size_kv = value_cache.shape[3] - permuted_key_cache = key_cache.view(num_seqs, -1, num_kv_heads, head_size).permute( - [0, 2, 1, 3] - ) - permuted_value_cache = value_cache.view( - num_seqs, -1, num_kv_heads, head_size_kv - ).permute([0, 2, 3, 1]) + permuted_key_cache = key_cache.view(num_seqs, -1, num_kv_heads, head_size) + permuted_value_cache = value_cache.view(num_seqs, -1, num_kv_heads, head_size_kv) # Run the wave kernel. (