From d663036018b6fe985ab3bc8acdadf920690947e6 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Mon, 6 Jan 2025 12:11:57 +0100 Subject: [PATCH] remove linalg-to-library-call pass (#329) * remove linalg-to-library-call pass * update pyproject.toml * lockfile --- compiler/tools/snax_opt_main.py | 2 - compiler/transforms/linalg_to_library_call.py | 87 ------------------- kernels/alloc/Snakefile | 1 - kernels/simple_copy/Snakefile | 1 - kernels/transform_copy/Snakefile | 1 - pixi.lock | 2 +- pyproject.toml | 1 - runtime/Makefile.rules | 2 +- .../transforms/linalg-to-library-call.mlir | 29 ------- util/snake/configs.py | 1 - 10 files changed, 2 insertions(+), 125 deletions(-) delete mode 100644 compiler/transforms/linalg_to_library_call.py delete mode 100644 tests/filecheck/transforms/linalg-to-library-call.mlir diff --git a/compiler/tools/snax_opt_main.py b/compiler/tools/snax_opt_main.py index 7141fe7c..29ccc823 100644 --- a/compiler/tools/snax_opt_main.py +++ b/compiler/tools/snax_opt_main.py @@ -32,7 +32,6 @@ from compiler.transforms.fuse_streaming_regions import FuseStreamingRegions from compiler.transforms.insert_accfg_op import InsertAccOp from compiler.transforms.insert_sync_barrier import InsertSyncBarrier -from compiler.transforms.linalg_to_library_call import LinalgToLibraryCall from compiler.transforms.memref_to_snax import MemrefToSNAX from compiler.transforms.realize_memref_casts import RealizeMemrefCastsPass from compiler.transforms.reuse_memref_allocs import ReuseMemrefAllocs @@ -79,7 +78,6 @@ def __init__( self.ctx.load_dialect(Debug) self.ctx.load_dialect(Stream) super().register_pass(DispatchKernels.name, lambda: DispatchKernels) - super().register_pass(LinalgToLibraryCall.name, lambda: LinalgToLibraryCall) super().register_pass(SetMemorySpace.name, lambda: SetMemorySpace) super().register_pass(SetMemoryLayout.name, lambda: SetMemoryLayout) super().register_pass(InsertAccOp.name, lambda: InsertAccOp) diff --git a/compiler/transforms/linalg_to_library_call.py b/compiler/transforms/linalg_to_library_call.py deleted file mode 100644 index 76854f2e..00000000 --- a/compiler/transforms/linalg_to_library_call.py +++ /dev/null @@ -1,87 +0,0 @@ -from xdsl.context import MLContext -from xdsl.dialects import builtin, func, linalg -from xdsl.dialects.builtin import MemRefType -from xdsl.dialects.memref import CastOp -from xdsl.passes import ModulePass -from xdsl.pattern_rewriter import ( - PatternRewriter, -) -from xdsl.traits import SymbolTable - - -class LinalgToLibraryCall(ModulePass): - """ - This pass detects linalg operations with an external library call, and - replaces them with a function call and definition. - """ - - name = "linalg-to-library-call" - - def apply(self, ctx: MLContext, module: builtin.ModuleOp) -> None: - rewriter = PatternRewriter(module) - - for op in module.walk(): - # Op must be linalg generic - if not isinstance(op, linalg.GenericOp): - continue - - if op.library_call is None: - continue - - # all memref arguments must be cast to a new memref with a dynamic shape - # to avoid type mismatching for multiple function calls with different - # argument shapes - cast_ops_to_insert = [] - operands = [] - for operand in op.operands: - if isinstance(operand.type, MemRefType): - new_type = MemRefType( - operand.type.element_type, - [-1] * len(operand.type.shape), - operand.type.layout, - operand.type.memory_space, - ) - cast = CastOp.get(operand, new_type) - cast_ops_to_insert.append(cast) - operands.append(cast) - else: - operands.append(operand) - - func_call = func.CallOp(op.library_call.data, operands, []) - - # Replace op with function call - rewriter.replace_op(op, [*cast_ops_to_insert, func_call]) - - # Insert external function definition - - # the memref arguments must be changed to dynamic shapes - # for the input types - input_types = [arg.type for arg in op.inputs] - for i, input_type in enumerate(input_types): - if isinstance(input_type, MemRefType): - input_types[i] = MemRefType( - input_type.element_type, - [-1] * len(input_type.shape), - input_type.layout, - input_type.memory_space, - ) - # do the same for output types - output_types = [res.type for res in op.outputs] - for i, output_type in enumerate(output_types): - if isinstance(output_type, MemRefType): - output_types[i] = MemRefType( - output_type.element_type, - [-1] * len(output_type.shape), - output_type.layout, - output_type.memory_space, - ) - - # both inputs and outputs are passed as inputs for the external function - # (pass allocated output memref by reference to the external function) - func_op = func.FuncOp.external( - func_call.callee.string_value(), - [*input_types, *output_types], - [], - ) - - SymbolTable.insert_or_update(module, func_op) diff --git a/kernels/alloc/Snakefile b/kernels/alloc/Snakefile index 92c7216c..6981bdfa 100644 --- a/kernels/alloc/Snakefile +++ b/kernels/alloc/Snakefile @@ -11,7 +11,6 @@ config["snaxoptflags"] = ",".join( "reuse-memref-allocs", "insert-sync-barrier", "dispatch-regions", - "linalg-to-library-call", "snax-copy-to-dma", "memref-to-snax", "snax-to-func", diff --git a/kernels/simple_copy/Snakefile b/kernels/simple_copy/Snakefile index bc6814b8..eb222ff9 100644 --- a/kernels/simple_copy/Snakefile +++ b/kernels/simple_copy/Snakefile @@ -10,7 +10,6 @@ config["snaxoptflags"] = ",".join( "reuse-memref-allocs", "insert-sync-barrier", "dispatch-regions", - "linalg-to-library-call", "snax-copy-to-dma", "memref-to-snax", "snax-to-func", diff --git a/kernels/transform_copy/Snakefile b/kernels/transform_copy/Snakefile index 57258ab0..8ad7e6b7 100644 --- a/kernels/transform_copy/Snakefile +++ b/kernels/transform_copy/Snakefile @@ -9,7 +9,6 @@ config["snaxoptflags"] = ",".join( "realize-memref-casts", "reuse-memref-allocs", "insert-sync-barrier", - "linalg-to-library-call", "snax-copy-to-dma", "memref-to-snax", "snax-to-func", diff --git a/pixi.lock b/pixi.lock index 257faeb7..4cf22387 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2095,7 +2095,7 @@ packages: - pypi: . name: snax-mlir version: 0.2.2 - sha256: 0c91ec881341fbb891986a6a0647ae35f2dc91ef9939ae0b1d9b31f7d5c4263b + sha256: 55b922fb3119a196e110ee98278db73d1ab26ef7c3e03cbd497abac0c50d7bfa requires_dist: - xdsl @ git+https://github.com/xdslproject/xdsl.git@d72f46d92ec4b03ae05b91e70d75f93735e94393 - pre-commit ; extra == 'dev' diff --git a/pyproject.toml b/pyproject.toml index 9557c918..39b02239 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,6 @@ typeCheckingMode = "strict" "compiler/transforms/accfg_dedup.py", "compiler/transforms/convert_linalg_to_accfg.py", "compiler/transforms/frontend/preprocess_mlperf_tiny.py", - "compiler/transforms/linalg_to_library_call.py", "compiler/transforms/reuse_memref_allocs.py", "compiler/transforms/set_memory_layout.py", "compiler/transforms/snax_copy_to_dma.py", diff --git a/runtime/Makefile.rules b/runtime/Makefile.rules index 6ddb7270..0ec89e3d 100644 --- a/runtime/Makefile.rules +++ b/runtime/Makefile.rules @@ -106,7 +106,7 @@ MLIRPREPROC3FLAGS += --mlir-print-local-scope # SNAX opt -SNAXOPTFLAGS = -p dispatch-kernels,set-memory-space,set-memory-layout,realize-memref-casts,reuse-memref-allocs,insert-sync-barrier,dispatch-regions,linalg-to-library-call,snax-copy-to-dma,memref-to-snax,snax-to-func,clear-memory-space +SNAXOPTFLAGS = -p dispatch-kernels,set-memory-space,set-memory-layout,realize-memref-casts,reuse-memref-allocs,insert-sync-barrier,dispatch-regions,snax-copy-to-dma,memref-to-snax,snax-to-func,clear-memory-space %.snax-opt.mlir: %.preprocfinal.mlir $(SNAXOPT) $(SNAXOPTFLAGS) --print-op-generic -o $@ $< diff --git a/tests/filecheck/transforms/linalg-to-library-call.mlir b/tests/filecheck/transforms/linalg-to-library-call.mlir deleted file mode 100644 index 4b3ad5f3..00000000 --- a/tests/filecheck/transforms/linalg-to-library-call.mlir +++ /dev/null @@ -1,29 +0,0 @@ -// RUN: ./compiler/snax-opt %s -p linalg-to-library-call --allow-unregistered-dialect --print-op-generic | filecheck %s - -"builtin.module"() ({ - %0, %1, %2, %3 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>, memref<64xi32>) - "linalg.generic"(%0, %1, %2) <{"library_call" = "snax_hwpe_mult", "indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type], "operandSegmentSizes" = array}> ({ - ^0(%arg0 : i32, %arg1 : i32, %arg2 : i32): - %4 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32 - "linalg.yield"(%4) : (i32) -> () - }) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () - "linalg.generic"(%1, %2, %3) <{"library_call" = "snax_hwpe_mult", "indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type], "operandSegmentSizes" = array}> ({ - ^0(%arg0 : i32, %arg1 : i32, %arg2 : i32): - %5 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32 - "linalg.yield"(%5) : (i32) -> () - }) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () -}) : () -> () - -//CHECK: "builtin.module"() ({ -//CHECK-NEXT: %0, %1, %2, %3 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>, memref<64xi32>) -//CHECK-NEXT: %4 = "memref.cast"(%0) : (memref<64xi32>) -> memref -//CHECK-NEXT: %5 = "memref.cast"(%1) : (memref<64xi32>) -> memref -//CHECK-NEXT: %6 = "memref.cast"(%2) : (memref<64xi32>) -> memref -//CHECK-NEXT: "func.call"(%4, %5, %6) <{"callee" = @snax_hwpe_mult}> : (memref, memref, memref) -> () -//CHECK-NEXT: %7 = "memref.cast"(%1) : (memref<64xi32>) -> memref -//CHECK-NEXT: %8 = "memref.cast"(%2) : (memref<64xi32>) -> memref -//CHECK-NEXT: %9 = "memref.cast"(%3) : (memref<64xi32>) -> memref -//CHECK-NEXT: "func.call"(%7, %8, %9) <{"callee" = @snax_hwpe_mult}> : (memref, memref, memref) -> () -//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_hwpe_mult", "function_type" = (memref, memref, memref) -> (), "sym_visibility" = "private"}> ({ -//CHECK-NEXT: }) : () -> () -//CHECK-NEXT: }) : () -> () diff --git a/util/snake/configs.py b/util/snake/configs.py index f6eac9a5..adaff8af 100644 --- a/util/snake/configs.py +++ b/util/snake/configs.py @@ -66,7 +66,6 @@ def get_mlperf_tiny_config(): "reuse-memref-allocs", "insert-sync-barrier", "dispatch-regions", - "linalg-to-library-call", "snax-copy-to-dma", "memref-to-snax", "snax-to-func",