Skip to content

Commit

Permalink
transforms (test-add-mcycle-around-launch) init pass
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Sep 24, 2024
1 parent 7cc59b1 commit f1669ed
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
2 changes: 2 additions & 0 deletions compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from compiler.transforms.stream_snaxify import StreamSnaxify
from compiler.transforms.test.debug_to_func import DebugToFuncPass
from compiler.transforms.test.insert_debugs import InsertDebugPass
from compiler.transforms.test.test_add_mcycle_around_loop import AddMcycleAroundLaunch
from compiler.transforms.test_add_mcycle_around_loop import AddMcycleAroundLoopPass
from compiler.transforms.test_remove_memref_copy import RemoveMemrefCopyPass

Expand Down Expand Up @@ -117,6 +118,7 @@ def __init__(
super().register_pass(InsertDebugPass.name, lambda: InsertDebugPass)
super().register_pass(DebugToFuncPass.name, lambda: DebugToFuncPass)
super().register_pass(PreprocessMLPerfTiny.name, lambda: PreprocessMLPerfTiny)
super().register_pass(AddMcycleAroundLaunch.name, lambda: AddMcycleAroundLaunch)

# arg handling
arg_parser = argparse.ArgumentParser(description=description)
Expand Down
40 changes: 40 additions & 0 deletions compiler/transforms/test/test_add_mcycle_around_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint

from compiler.dialects import accfg, snax


class InsertBeforeLaunch(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: accfg.LaunchOp, rewriter: PatternRewriter, /):
rewriter.insert_op(snax.MCycleOp(), InsertPoint.before(op))


class InsertAfterAwait(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: accfg.AwaitOp, rewriter: PatternRewriter, /):
rewriter.insert_op(snax.MCycleOp(), InsertPoint.after(op))


class AddMcycleAroundLaunch(ModulePass):
"""
Pass to insert an mcycle op before all accfg.launches and after all accfg.awaits
"""

name = "test-add-mcycle-around-launch"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
InsertBeforeLaunch(), apply_recursively=False
).rewrite_module(op)
PatternRewriteWalker(
InsertAfterAwait(), apply_recursively=False
).rewrite_module(op)
33 changes: 33 additions & 0 deletions tests/filecheck/transforms/test/test-add-mcycle-around-launch.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: snax-opt --split-input-file -p test-add-mcycle-around-launch %s | filecheck %s

"accfg.accelerator"() <{
name = @acc1,
fields = {A=0x3c0, B=0x3c1},
launch_fields = {launch=0x3cf},
barrier = 0x7c3
}> : () -> ()

func.func @test() {
%one, %two = "test.op"() : () -> (i32, i32)
%zero = arith.constant 0 : i32
%state = "accfg.setup"(%one, %two) <{
param_names = ["A", "B"],
accelerator = "acc1",
operandSegmentSizes = array<i32: 2, 0>
}> {"test_attr" = 100 : i64} : (i32, i32) -> !accfg.state<"acc1">

%token = "accfg.launch"(%zero, %state) <{
param_names = ["launch"],
accelerator = "acc1"
}>: (i32, !accfg.state<"acc1">) -> !accfg.token<"acc1">

"accfg.await"(%token) : (!accfg.token<"acc1">) -> ()

func.return
}


// CHECK: snax.mcycle
// CHECK-NEXT: accfg.launch
// CHECK-NEXT: accfg.await
// CHECK-NEXT: snax.mcycle

0 comments on commit f1669ed

Please sign in to comment.