Skip to content

Commit

Permalink
new bufferization strategy (#295)
Browse files Browse the repository at this point in the history
* new bufferization strategy

* merge stream-bufferize and snax-bufferize

* delete stream bufferize

* formatting

* change gemm kernel

* update kernels

* update

* change nb cores

* update Makefile

* typo
  • Loading branch information
jorendumoulin authored Dec 6, 2024
1 parent e7c4c34 commit d195146
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 142 deletions.
2 changes: 1 addition & 1 deletion benchmarks/dense_matmul/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ else
REMOVE_MEMREF_COPY=
endif

SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-stream,fuse-streaming-regions,stream-bufferize,snax-bufferize,alloc-to-global,set-memory-space,set-memory-layout{gemm_layout=${LAYOUT}},realize-memref-casts,${REMOVE_MEMREF_COPY}insert-sync-barrier,dispatch-regions{nb_cores=2},convert-stream-to-snax-stream,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space
SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-stream,fuse-streaming-regions,snax-bufferize,alloc-to-global,set-memory-space,set-memory-layout{gemm_layout=${LAYOUT}},realize-memref-casts,${REMOVE_MEMREF_COPY}insert-sync-barrier,dispatch-regions{nb_cores=3},convert-stream-to-snax-stream,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space

GEN_DATA_OPTS += --m=${SIZE_M}
GEN_DATA_OPTS += --n=${SIZE_N}
Expand Down
2 changes: 0 additions & 2 deletions compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from compiler.transforms.snax_copy_to_dma import SNAXCopyToDMA
from compiler.transforms.snax_lower_mcycle import SNAXLowerMCycle
from compiler.transforms.snax_to_func import SNAXToFunc
from compiler.transforms.stream_bufferize import StreamBufferize
from compiler.transforms.test.debug_to_func import DebugToFuncPass
from compiler.transforms.test.insert_debugs import InsertDebugPass
from compiler.transforms.test.test_add_mcycle_around_launch import AddMcycleAroundLaunch
Expand Down Expand Up @@ -122,7 +121,6 @@ def __init__(
super().register_pass(PreprocessMLPerfTiny.name, lambda: PreprocessMLPerfTiny)
super().register_pass(AddMcycleAroundLaunch.name, lambda: AddMcycleAroundLaunch)
super().register_pass(ConvertLinalgToStream.name, lambda: ConvertLinalgToStream)
super().register_pass(StreamBufferize.name, lambda: StreamBufferize)
super().register_pass(SnaxBufferize.name, lambda: SnaxBufferize)
super().register_pass(FuseStreamingRegions.name, lambda: FuseStreamingRegions)
super().register_pass(AllocToGlobalPass.name, lambda: AllocToGlobalPass)
Expand Down
76 changes: 74 additions & 2 deletions compiler/transforms/snax_bufferize.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,71 @@
from dataclasses import dataclass, field

from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.dialects import bufferization, builtin
from xdsl.ir import Operation, OpResult, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.transforms.mlir_opt import MLIROptPass

from compiler.dialects import stream


@dataclass
class BufferizeStreamingRegion(RewritePattern):
    """Rewrite a tensor-based ``stream.streaming_region`` to use memrefs.

    The pattern only fires when every tensor operand of the region is the
    result of a ``bufferization.to_tensor`` op. The memref feeding that op is
    then used as the operand of a new streaming region, and each output memref
    of the new region is wrapped in a fresh ``bufferization.to_tensor`` whose
    result replaces the corresponding result of the matched op.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: stream.StreamingRegionOp, rewriter: PatternRewriter
    ) -> None:
        """Replace ``op`` by a memref-based region, or leave it untouched.

        Returns without rewriting when ``op`` has no tensor operands, or when
        any tensor operand is not produced by a ``bufferization.to_tensor``.
        """
        # Collect the operands that still need to be bufferized.
        tensor_operands = tuple(
            operand
            for operand in op.operands
            if isinstance(operand.type, builtin.TensorType)
        )

        # Nothing to do if there are no tensor operands.
        if not tensor_operands:
            return

        # Every unique tensor operand must be the result of a to_tensor op;
        # remember the memref it was created from.
        tensor_to_memrefs: dict[SSAValue, SSAValue] = {}
        for operand in set(tensor_operands):
            if not isinstance(operand, OpResult):
                return
            if not isinstance(producer := operand.op, bufferization.ToTensorOp):
                return
            tensor_to_memrefs[operand] = producer.memref

        # NOTE(review): operands that are not tensors have no entry in
        # tensor_to_memrefs, so a region mixing tensor and memref operands
        # would raise KeyError below -- confirm whether that can occur.
        new_op = stream.StreamingRegionOp(
            inputs=[tensor_to_memrefs[in_tensor] for in_tensor in op.inputs],
            outputs=[tensor_to_memrefs[out_tensor] for out_tensor in op.outputs],
            patterns=op.patterns,
            body=rewriter.move_region_contents_to_new_regions(op.body),
            accelerator=op.accelerator,
        )

        # For every output, create a bufferization.to_tensor op. The ops are
        # kept in a list (not a dict keyed by memref) so that a memref used
        # as output more than once still gets one inserted op per result.
        to_tensor_ops: list[Operation] = []
        new_results: tuple[SSAValue, ...] = ()
        for output in new_op.outputs:
            wrapped = bufferization.ToTensorOp(output, restrict=True)
            to_tensor_ops.append(wrapped)
            new_results += wrapped.results

        # Replace the old operation by the new region followed by the
        # to_tensor ops, forwarding the to_tensor results to former users.
        rewriter.replace_matched_op(
            (new_op, *to_tensor_ops),
            new_results,
        )


@dataclass(frozen=True)
class SnaxBufferize(ModulePass):
Expand All @@ -15,13 +76,24 @@ class SnaxBufferize(ModulePass):

mlir_bufferization_pass = MLIROptPass(
arguments=(
"--one-shot-bufferize=bufferize-function-boundaries allow-return-allocs-from-loops"
"--one-shot-bufferize=bufferize-function-boundaries allow-return-allocs-from-loops allow-unknown-ops"
+ " function-boundary-type-conversion=identity-layout-map",
"--mlir-print-op-generic",
"--mlir-print-local-scope",
"--allow-unregistered-dialect",
)
)

mlir_canonicalization_pass = MLIROptPass(
arguments=(
"--canonicalize",
"--mlir-print-op-generic",
"--mlir-print-local-scope",
"--allow-unregistered-dialect",
)
)

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
self.mlir_bufferization_pass.apply(ctx, op)
PatternRewriteWalker(BufferizeStreamingRegion()).rewrite_module(op)
self.mlir_canonicalization_pass.apply(ctx, op)
81 changes: 0 additions & 81 deletions compiler/transforms/stream_bufferize.py

This file was deleted.

2 changes: 1 addition & 1 deletion kernels/gemm/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ MLIRPREPROCFLAGS += --mlir-print-local-scope
$(MLIROPT) $(MLIRPREPROCFLAGS) -o $@ $<


SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-stream,fuse-streaming-regions,stream-bufferize,snax-bufferize,alloc-to-global,set-memory-space,set-memory-layout,realize-memref-casts,insert-sync-barrier,dispatch-regions{nb_cores=2},convert-stream-to-snax-stream,convert-linalg-to-accfg,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,clear-memory-space
SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-stream,fuse-streaming-regions,snax-bufferize,alloc-to-global,set-memory-space,set-memory-layout,realize-memref-casts,insert-sync-barrier,dispatch-regions{nb_cores=2},convert-stream-to-snax-stream,convert-linalg-to-accfg,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,clear-memory-space

CFLAGS += -std=gnu11
CFLAGS += -Wall -Wextra
Expand Down
33 changes: 33 additions & 0 deletions tests/filecheck/transforms/snax-bufferize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: snax-opt %s -p snax-bufferize | filecheck %s
// TODO: add bufferization.clone to xDSL

// Input: a function on tensors whose body is a single stream.streaming_region.
// The region takes %arg0 twice as input and the tensor.empty value as output,
// and its tensor result is returned from the function.
"func.func"() <{"sym_name" = "test", "function_type" = (tensor<16x16xi8>) -> tensor<16x16xi32>}> ({
^0(%arg0 : tensor<16x16xi8>):
%empty = "tensor.empty"() : () -> tensor<16x16xi32>
%0 = "stream.streaming_region"(%arg0, %arg0, %empty) <{"patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], "accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>}> ({
^0(%1 : !stream.stream<i8>, %2 : !stream.stream<i8>, %3 : !stream.stream<i32>):
%4 = "stream.generic"(%1, %2) ({
^1(%in : i8, %in_1 : i8):
%5 = "test.op"(%in, %in_1) : (i8, i8) -> i32
stream.yield %5 : i32
}) : (!stream.stream<i8>, !stream.stream<i8>) -> !stream.stream<i32>
stream.yield %4 : !stream.stream<i32>
}) : (tensor<16x16xi8>, tensor<16x16xi8>, tensor<16x16xi32>) -> tensor<16x16xi32>
"func.return"(%0) : (tensor<16x16xi32>) -> ()
}) : () -> ()

// Expected output: the function signature and the streaming region operate on
// memrefs, the tensor.empty is replaced by a memref.alloc, and the streaming
// region no longer produces a result (the allocated buffer is returned instead).
// CHECK: builtin.module {
// CHECK-NEXT: func.func @test(%arg0 : memref<16x16xi8>) -> memref<16x16xi32> {
// CHECK-NEXT: %0 = memref.alloc() {"alignment" = 64 : i64} : memref<16x16xi32>
// CHECK-NEXT: "stream.streaming_region"(%arg0, %arg0, %0) <{"accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>, "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>]}> ({
// CHECK-NEXT: ^0(%arg1 : !stream.stream<i8>, %arg2 : !stream.stream<i8>, %arg3 : !stream.stream<i32>):
// CHECK-NEXT: %1 = "stream.generic"(%arg1, %arg2) ({
// CHECK-NEXT: ^1(%arg4 : i8, %arg5 : i8):
// CHECK-NEXT: %2 = "test.op"(%arg4, %arg5) : (i8, i8) -> i32
// CHECK-NEXT: stream.yield %2 : i32
// CHECK-NEXT: }) : (!stream.stream<i8>, !stream.stream<i8>) -> !stream.stream<i32>
// CHECK-NEXT: stream.yield %1 : !stream.stream<i32>
// CHECK-NEXT: }) : (memref<16x16xi8>, memref<16x16xi8>, memref<16x16xi32>) -> ()
// CHECK-NEXT: func.return %0 : memref<16x16xi32>
// CHECK-NEXT: }
// CHECK-NEXT: }
55 changes: 0 additions & 55 deletions tests/filecheck/transforms/stream-bufferize.mlir

This file was deleted.

0 comments on commit d195146

Please sign in to comment.