Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 22, 2024
1 parent b101cb5 commit 45127b1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
22 changes: 14 additions & 8 deletions compiler/transforms/snax_copy_to_dma.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from xdsl.dialects import arith, builtin, func, memref, scf
from xdsl.dialects import arith, builtin, func, scf
from xdsl.dialects.arith import Addi, Constant, Muli
from xdsl.dialects.builtin import IndexType, IntegerType, NoneAttr
from xdsl.dialects.memref import CopyOp, Dim, ExtractAlignedPointerAsIndexOp, MemRefType
Expand Down Expand Up @@ -278,17 +278,23 @@ class SNAXCopyToDMA(ModulePass):
name = "snax-copy-to-dma"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
contains_copies = any(
isinstance(op_in_module, memref.CopyOp) for op_in_module in op.walk()
)

if contains_copies:
PatternRewriteWalker(Match1DDMA()).rewrite_module(op)
PatternRewriteWalker(TransformDMA()).rewrite_module(op)
PatternRewriteWalker(Match1DDMA()).rewrite_module(op)
if any(
isinstance(op_in_module, func.Call)
and op_in_module.callee.root_reference.data == "snax_dma_1d_transfer"
for op_in_module in op.walk()
):
func_decl = func.FuncOp.external(
"snax_dma_1d_transfer", 3 * [builtin.IndexType()], []
)
SymbolTable.insert_or_update(op, func_decl)

PatternRewriteWalker(TransformDMA()).rewrite_module(op)
if any(
isinstance(op_in_module, func.Call)
and op_in_module.callee.root_reference.data == "snax_dma_2d_transfer"
for op_in_module in op.walk()
):
func_decl = func.FuncOp.external(
"snax_dma_2d_transfer", 6 * [builtin.IndexType()], []
)
Expand Down
9 changes: 8 additions & 1 deletion tests/filecheck/transforms/copy_to_dma.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@
//CHECK: "builtin.module"() ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<?x?xi32>, memref<?x?xi32>) -> (), "sym_visibility" = "public"}> ({
//CHECK-NEXT: ^0(%arg0 : memref<?x?xi32>, %arg1 : memref<?x?xi32>):
//CHECK-NEXT: "memref.copy"(%arg0, %arg1) : (memref<?x?xi32>, memref<?x?xi32>) -> ()
//CHECK-NEXT: %0 = "arith.constant"() <{"value" = 0 : index}> : () -> index
//CHECK-NEXT: %1 = "memref.dim"(%arg0, %0) : (memref<?x?xi32>, index) -> index
//CHECK-NEXT: %2 = "arith.constant"() <{"value" = 1 : index}> : () -> index
//CHECK-NEXT: %3 = "memref.dim"(%arg0, %2) : (memref<?x?xi32>, index) -> index
//CHECK-NEXT: %4 = "arith.muli"(%1, %3) : (index, index) -> index
//CHECK-NEXT: %5 = "memref.extract_aligned_pointer_as_index"(%arg0) : (memref<?x?xi32>) -> index
//CHECK-NEXT: %6 = "memref.extract_aligned_pointer_as_index"(%arg1) : (memref<?x?xi32>) -> index
//CHECK-NEXT: "func.call"(%5, %6, %4) <{"callee" = @snax_dma_1d_transfer}> : (index, index, index) -> ()
//CHECK-NEXT: "func.return"() : () -> ()
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_dma_1d_transfer", "function_type" = (index, index, index) -> (), "sym_visibility" = "private"}> ({
Expand Down

0 comments on commit 45127b1

Please sign in to comment.