From f1669ed938b359483c7dbd31680cd987318c8902 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Tue, 24 Sep 2024 08:34:33 +0200 Subject: [PATCH] transforms (test-add-mcycle-around-launch) init pass --- compiler/tools/snax_opt_main.py | 2 + .../test/test_add_mcycle_around_loop.py | 40 +++++++++++++++++++ .../test/test-add-mcycle-around-launch.mlir | 33 +++++++++++++++ 3 files changed, 75 insertions(+) create mode 100644 compiler/transforms/test/test_add_mcycle_around_loop.py create mode 100644 tests/filecheck/transforms/test/test-add-mcycle-around-launch.mlir diff --git a/compiler/tools/snax_opt_main.py b/compiler/tools/snax_opt_main.py index c6ee51b1..68910430 100644 --- a/compiler/tools/snax_opt_main.py +++ b/compiler/tools/snax_opt_main.py @@ -43,6 +43,7 @@ from compiler.transforms.stream_snaxify import StreamSnaxify from compiler.transforms.test.debug_to_func import DebugToFuncPass from compiler.transforms.test.insert_debugs import InsertDebugPass +from compiler.transforms.test.test_add_mcycle_around_loop import AddMcycleAroundLaunch from compiler.transforms.test_add_mcycle_around_loop import AddMcycleAroundLoopPass from compiler.transforms.test_remove_memref_copy import RemoveMemrefCopyPass @@ -117,6 +118,7 @@ def __init__( super().register_pass(InsertDebugPass.name, lambda: InsertDebugPass) super().register_pass(DebugToFuncPass.name, lambda: DebugToFuncPass) super().register_pass(PreprocessMLPerfTiny.name, lambda: PreprocessMLPerfTiny) + super().register_pass(AddMcycleAroundLaunch.name, lambda: AddMcycleAroundLaunch) # arg handling arg_parser = argparse.ArgumentParser(description=description) diff --git a/compiler/transforms/test/test_add_mcycle_around_loop.py b/compiler/transforms/test/test_add_mcycle_around_loop.py new file mode 100644 index 00000000..edc0ecaa --- /dev/null +++ b/compiler/transforms/test/test_add_mcycle_around_loop.py @@ -0,0 +1,40 @@ +from xdsl.context import MLContext +from xdsl.dialects import builtin +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.rewriter import InsertPoint + +from compiler.dialects import accfg, snax + + +class InsertBeforeLaunch(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: accfg.LaunchOp, rewriter: PatternRewriter, /): + rewriter.insert_op(snax.MCycleOp(), InsertPoint.before(op)) + + +class InsertAfterAwait(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: accfg.AwaitOp, rewriter: PatternRewriter, /): + rewriter.insert_op(snax.MCycleOp(), InsertPoint.after(op)) + + +class AddMcycleAroundLaunch(ModulePass): + """ + Pass to insert an mcycle op before all accfg.launches and after all accfg.awaits + """ + + name = "test-add-mcycle-around-launch" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + PatternRewriteWalker( + InsertBeforeLaunch(), apply_recursively=False + ).rewrite_module(op) + PatternRewriteWalker( + InsertAfterAwait(), apply_recursively=False + ).rewrite_module(op) diff --git a/tests/filecheck/transforms/test/test-add-mcycle-around-launch.mlir b/tests/filecheck/transforms/test/test-add-mcycle-around-launch.mlir new file mode 100644 index 00000000..2474efd4 --- /dev/null +++ b/tests/filecheck/transforms/test/test-add-mcycle-around-launch.mlir @@ -0,0 +1,33 @@ +// RUN: snax-opt --split-input-file -p test-add-mcycle-around-launch %s | filecheck %s + +"accfg.accelerator"() <{ + name = @acc1, + fields = {A=0x3c0, B=0x3c1}, + launch_fields = {launch=0x3cf}, + barrier = 0x7c3 +}> : () -> () + +func.func @test() { + %one, %two = "test.op"() : () -> (i32, i32) + %zero = arith.constant 0 : i32 + %state = "accfg.setup"(%one, %two) <{ + param_names = ["A", "B"], + accelerator = "acc1", + operandSegmentSizes = array + }> {"test_attr" = 100 : i64} : (i32, i32) -> !accfg.state<"acc1"> + + %token = "accfg.launch"(%zero, %state) <{ + param_names = ["launch"], + accelerator = "acc1" + }>: (i32, !accfg.state<"acc1">) -> !accfg.token<"acc1"> + + "accfg.await"(%token) : (!accfg.token<"acc1">) -> () + + func.return +} + + +// CHECK: snax.mcycle +// CHECK-NEXT: accfg.launch +// CHECK-NEXT: accfg.await +// CHECK-NEXT: snax.mcycle