Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

transforms: (set_memory_layout) only apply on streaming_regions #334

Merged
merged 2 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 29 additions & 48 deletions compiler/transforms/set_memory_layout.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import builtin, linalg
from xdsl.dialects import builtin
from xdsl.parser import MemRefType
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
Expand All @@ -22,23 +22,26 @@

class AddMemoryLayoutSIMD(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, linalg_op: linalg.GenericOp, rewriter: PatternRewriter):
def match_and_rewrite(
self, op: stream.StreamingRegionOp, rewriter: PatternRewriter
):
# check if operation is dispatched via library call, as set by e.g.
# the dispatch-kernels pass
if linalg_op.library_call is None:

if op.accelerator is None:
return
else:
library_call = linalg_op.library_call.data
library_call = op.accelerator.data

# check for library call
if library_call == "snax_gemmx_stream":
if not isinstance(linalg_op.body.block.first_op, RescaleOp):
if library_call == "snax_gemmx":
if not isinstance(op.body.block.first_op.body.block.first_op, RescaleOp):
return

shaped_operands: list[MemRefType] = [
op.type
for op in linalg_op.operands
if isinstance(op.type, builtin.MemRefType)
operand.type
for operand in op.operands
if isinstance(operand.type, builtin.MemRefType)
]

m = shaped_operands[0].get_shape()[0]
Expand Down Expand Up @@ -83,27 +86,17 @@ def match_and_rewrite(self, linalg_op: linalg.GenericOp, rewriter: PatternRewrit

# insert layout_cast ops
new_input_a = LayoutCast.from_type_and_target_layout(
linalg_op.inputs[0], tsl_input
op.inputs[0], tsl_input
)

new_output = LayoutCast.from_type_and_target_layout(
linalg_op.outputs[0], tsl_output
)

new_linalg_op = linalg.GenericOp(
inputs=[new_input_a.dest],
outputs=[new_output.dest],
body=rewriter.move_region_contents_to_new_regions(linalg_op.regions[0]),
indexing_maps=linalg_op.indexing_maps,
iterator_types=linalg_op.iterator_types,
doc=linalg_op.doc,
library_call=linalg_op.library_call,
op.outputs[0], tsl_output
)

rewriter.insert_op_before_matched_op([new_input_a, new_output])
rewriter.replace_op(linalg_op, new_linalg_op)
rewriter.insert_op([new_input_a, new_output], InsertPoint.before(op))

pass
op.operands[0] = new_input_a.dest
op.operands[1] = new_output.dest


class GemmLayout(StrEnum):
Expand All @@ -127,41 +120,29 @@ class AddMemoryLayout(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: linalg.GenericOp | stream.StreamingRegionOp, rewriter: PatternRewriter
self, op: stream.StreamingRegionOp, rewriter: PatternRewriter
):
# check if operation is dispatched via library call, as set by e.g.
# the dispatch-kernels pass

if isinstance(op, linalg.GenericOp):
if op.library_call is None:
return
else:
library_call = op.library_call.data
elif isinstance(op, stream.StreamingRegionOp):
if op.accelerator is None:
return
else:
library_call = op.accelerator.data
if op.accelerator is None:
return
else:
library_call = op.accelerator.data

has_add_c = False

# check for library call
if library_call == "snax_gemmx" or library_call == "snax_gemmx_stream":
# only do so for qmac kernels
if isinstance(op, linalg.GenericOp):
if not isinstance(op.body.block.first_op, QMacOp):
return
elif isinstance(op, stream.StreamingRegionOp):
assert isinstance(
generic_op := op.body.block.first_op, stream.GenericOp
)
if not isinstance(generic_op.body.block.first_op, QMacOp):
return
assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp)
jorendumoulin marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(generic_op.body.block.first_op, QMacOp):
return

if isinstance(generic_op.next_op, stream.GenericOp):
if isinstance(generic_op.next_op.body.block.first_op, AddOp):
# gemm
has_add_c = True
if isinstance(generic_op.next_op, stream.GenericOp):
if isinstance(generic_op.next_op.body.block.first_op, AddOp):
# gemm
has_add_c = True

# the layout should be as static as the memref is. no more, no less
# get m, n, k
Expand Down
2 changes: 1 addition & 1 deletion kernels/rescale/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ config["snaxoptflags"] = ",".join(
"insert-accfg-op{accelerator=snax_gemmx}",
"dispatch-kernels",
"set-memory-space",
"convert-linalg-to-stream",
"set-memory-layout",
"realize-memref-casts",
"insert-sync-barrier",
"dispatch-regions{nb_cores=2}",
"convert-linalg-to-stream",
"convert-stream-to-snax-stream",
"convert-linalg-to-accfg",
"convert-accfg-to-csr",
Expand Down
24 changes: 0 additions & 24 deletions tests/filecheck/transforms/set-memory-layout.mlir
Original file line number Diff line number Diff line change
@@ -1,30 +1,6 @@
// RUN: ./compiler/snax-opt --split-input-file %s -p set-memory-layout --print-op-generic | filecheck %s
// RUN: ./compiler/snax-opt --split-input-file %s -p set-memory-layout{gemm_layout=banked} --print-op-generic | filecheck %s --check-prefix=BANKED

builtin.module {
func.func @mnist(%arg0 : memref<?x128xi8, 1 : i32>, %arg1 : memref<128x128xi8, 1 : i32>, %arg2 : memref<?x128xi32, 1 : i32>) -> memref<?x128xi32, 1 : i32> {
%0 = arith.constant 1 : i32
%1 = arith.constant 1 : i32
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], library_call = "snax_gemmx"} ins(%arg0, %arg1, %0, %1 : memref<?x128xi8, 1 : i32>, memref<128x128xi8, 1 : i32>, i32, i32) outs(%arg2 : memref<?x128xi32, 1 : i32>) {
^0(%arg3 : i8, %arg4 : i8, %arg5 : i32, %arg6 : i32, %arg7 : i32):
%2 = kernel.qmac %arg3, %arg4 zp_lhs : %arg5 zp_rhs : %arg6 : i8, i8, i32, i32 -> i32
linalg.yield %2 : i32
}
func.return %arg2 : memref<?x128xi32, 1 : i32>
}
}


//CHECK: %2 = "snax.layout_cast"(%arg0) : (memref<?x128xi8, 1 : i32>) -> memref<?x128xi8, #tsl.tsl<[?, 8] -> (1024, 8), [16, 8] -> (64, 1)>, 1 : i32>
//CHECK-NEXT: %3 = "snax.layout_cast"(%arg1) : (memref<128x128xi8, 1 : i32>) -> memref<128x128xi8, #tsl.tsl<[16, 8] -> (64, 1), [16, 8] -> (1024, 8)>, 1 : i32>
//CHECK-NEXT: %4 = "snax.layout_cast"(%arg2) : (memref<?x128xi32, 1 : i32>) -> memref<?x128xi32, #tsl.tsl<[?, 8] -> (1024, 8), [16, 8] -> (64, 1)>, 1 : i32>

//BANKED: %2 = "snax.layout_cast"(%arg0) : (memref<?x128xi8, 1 : i32>) -> memref<?x128xi8, #tsl.tsl<[?, 8] -> (4096, 8), [16, 8] -> (256, 1)>, 1 : i32>
//BANKED-NEXT: %3 = "snax.layout_cast"(%arg1) : (memref<128x128xi8, 1 : i32>) -> memref<128x128xi8, #tsl.tsl<[16, 8] -> (256, 1), [16, 8] -> (4096, 8)>, 1 : i32>
//BANKED-NEXT: %4 = "snax.layout_cast"(%arg2) : (memref<?x128xi32, 1 : i32>) -> memref<?x128xi32, #tsl.tsl<[?, 8] -> (1024, 8), [16, 8] -> (64, 1)>, 1 : i32>

// -----

func.func @gemm(%arg0 : memref<16x16xi8, "L1">, %arg1 : memref<16x16xi8, "L1">, %arg2 : memref<16x16xi32, "L1">) -> () {
%0 = arith.constant 0 : i32
%1 = arith.constant 0 : i32
Expand Down
Loading