diff --git a/tests/filecheck/transforms/stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir
similarity index 99%
rename from tests/filecheck/transforms/stencil-to-csl-stencil.mlir
rename to tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir
index f4e6927c6d..d7b53644a0 100644
--- a/tests/filecheck/transforms/stencil-to-csl-stencil.mlir
+++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir
@@ -1,4 +1,4 @@
-// RUN: xdsl-opt %s -p "stencil-to-csl-stencil{num_chunks=2}" | filecheck %s
+// RUN: xdsl-opt %s -p "convert-stencil-to-csl-stencil{num_chunks=2}" | filecheck %s
 
 builtin.module {
 // CHECK-NEXT: builtin.module {
diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir
index cc9e7047a1..0cea8e5dde 100644
--- a/tests/filecheck/transforms/lower-csl-stencil.mlir
+++ b/tests/filecheck/transforms/lower-csl-stencil.mlir
@@ -95,46 +95,46 @@ builtin.module {
 // CHECK-NEXT: "csl.export"(%36) <{"var_name" = "arg1", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> ()
 // CHECK-NEXT: "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> ()
 // CHECK-NEXT: csl.func @gauss_seidel_func() {
-// CHECK-NEXT: %arg4 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
+// CHECK-NEXT: %accumulator = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
 // CHECK-NEXT: %37 = arith.constant 2 : i16
-// CHECK-NEXT: %38 = "csl.addressof_fn"() <{"fn_name" = @chunk_reduce_cb0}> : () -> !csl.ptr<(i16) -> (), #csl, #csl>
-// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @post_process_cb0}> : () -> !csl.ptr<() -> (), #csl, #csl>
+// CHECK-NEXT: %38 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb0}> : () -> !csl.ptr<(i16) -> (), #csl, #csl>
+// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb0}> : () -> !csl.ptr<() -> (), #csl, #csl>
 // CHECK-NEXT: "csl.member_call"(%34, %arg0, %37, %38, %39) <{"field" = "communicate"}> : (!csl.imported_module, memref<512xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> ()
 // CHECK-NEXT: csl.return
 // CHECK-NEXT: }
-// CHECK-NEXT: csl.func @chunk_reduce_cb0(%40 : i16) {
-// CHECK-NEXT: %arg3 = arith.index_cast %40 : i16 to index
-// CHECK-NEXT: %41 = arith.constant 1 : i16
-// CHECK-NEXT: %42 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction
-// CHECK-NEXT: %43 = "csl.member_call"(%34, %42, %41) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl
-// CHECK-NEXT: %44 = builtin.unrealized_conversion_cast %43 : !csl to memref<255xf32>
-// CHECK-NEXT: %45 = arith.constant 1 : i16
-// CHECK-NEXT: %46 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction
-// CHECK-NEXT: %47 = "csl.member_call"(%34, %46, %45) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl
-// CHECK-NEXT: %48 = builtin.unrealized_conversion_cast %47 : !csl to memref<255xf32>
-// CHECK-NEXT: %49 = arith.constant 1 : i16
-// CHECK-NEXT: %50 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction
-// CHECK-NEXT: %51 = "csl.member_call"(%34, %50, %49) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl
-// CHECK-NEXT: %52 = builtin.unrealized_conversion_cast %51 : !csl to memref<255xf32>
-// CHECK-NEXT: %53 = arith.constant 1 : i16
-// CHECK-NEXT: %54 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction
-// CHECK-NEXT: %55 = "csl.member_call"(%34, %54, %53) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl
-// CHECK-NEXT: %56 = builtin.unrealized_conversion_cast %55 : !csl to memref<255xf32>
-// CHECK-NEXT: %57 = memref.subview %arg4[%arg3] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
-// CHECK-NEXT: "csl.fadds"(%57, %56, %52) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> ()
-// CHECK-NEXT: "csl.fadds"(%57, %57, %48) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
-// CHECK-NEXT: "csl.fadds"(%57, %57, %44) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
+// CHECK-NEXT: csl.func @receive_chunk_cb0(%offset : i16) {
+// CHECK-NEXT: %offset_1 = arith.index_cast %offset : i16 to index
+// CHECK-NEXT: %40 = arith.constant 1 : i16
+// CHECK-NEXT: %41 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction
+// CHECK-NEXT: %42 = "csl.member_call"(%34, %41, %40) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl
+// CHECK-NEXT: %43 = builtin.unrealized_conversion_cast %42 : !csl to memref<255xf32>
+// CHECK-NEXT: %44 = arith.constant 1 : i16
+// CHECK-NEXT: %45 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction
+// CHECK-NEXT: %46 = "csl.member_call"(%34, %45, %44) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl
+// CHECK-NEXT: %47 = builtin.unrealized_conversion_cast %46 : !csl to memref<255xf32>
+// CHECK-NEXT: %48 = arith.constant 1 : i16
+// CHECK-NEXT: %49 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction
+// CHECK-NEXT: %50 = "csl.member_call"(%34, %49, %48) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl
+// CHECK-NEXT: %51 = builtin.unrealized_conversion_cast %50 : !csl to memref<255xf32>
+// CHECK-NEXT: %52 = arith.constant 1 : i16
+// CHECK-NEXT: %53 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction
+// CHECK-NEXT: %54 = "csl.member_call"(%34, %53, %52) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl
+// CHECK-NEXT: %55 = builtin.unrealized_conversion_cast %54 : !csl to memref<255xf32>
+// CHECK-NEXT: %56 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
+// CHECK-NEXT: "csl.fadds"(%56, %55, %51) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> ()
+// CHECK-NEXT: "csl.fadds"(%56, %56, %47) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
+// CHECK-NEXT: "csl.fadds"(%56, %56, %43) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
 // CHECK-NEXT: csl.return
 // CHECK-NEXT: }
-// CHECK-NEXT: csl.func @post_process_cb0() {
-// CHECK-NEXT: %58 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>>
-// CHECK-NEXT: %59 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>>
-// CHECK-NEXT: "csl.fadds"(%arg4, %arg4, %59) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> ()
-// CHECK-NEXT: "csl.fadds"(%arg4, %arg4, %58) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> ()
-// CHECK-NEXT: %60 = arith.constant 1.666600e-01 : f32
-// CHECK-NEXT: "csl.fmuls"(%arg4, %arg4, %60) : (memref<510xf32>, memref<510xf32>, f32) -> ()
-// CHECK-NEXT: %61 = memref.subview %arg1[1] [510] [1] : memref<512xf32> to memref<510xf32>
-// CHECK-NEXT: "memref.copy"(%arg4, %61) : (memref<510xf32>, memref<510xf32>) -> ()
+// CHECK-NEXT: csl.func @done_exchange_cb0() {
+// CHECK-NEXT: %57 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>>
+// CHECK-NEXT: %58 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>>
+// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %58) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> ()
+// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %57) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> ()
+// CHECK-NEXT: %59 = arith.constant 1.666600e-01 : f32
+// CHECK-NEXT: "csl.fmuls"(%accumulator, %accumulator, %59) : (memref<510xf32>, memref<510xf32>, f32) -> ()
+// CHECK-NEXT: %60 = memref.subview %arg1[1] [510] [1] : memref<512xf32> to memref<510xf32>
+// CHECK-NEXT: "memref.copy"(%accumulator, %60) : (memref<510xf32>, memref<510xf32>) -> ()
 // CHECK-NEXT: csl.return
 // CHECK-NEXT: }
 // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> ()
diff --git a/tests/filecheck/transforms/csl_wrapper_to_csl.mlir b/tests/filecheck/transforms/lower-csl-wrapper.mlir
similarity index 99%
rename from tests/filecheck/transforms/csl_wrapper_to_csl.mlir
rename to tests/filecheck/transforms/lower-csl-wrapper.mlir
index d4515a3bb4..c890f41e22 100644
--- a/tests/filecheck/transforms/csl_wrapper_to_csl.mlir
+++ b/tests/filecheck/transforms/lower-csl-wrapper.mlir
@@ -1,4 +1,4 @@
-// RUN: xdsl-opt -p csl-wrapper-to-csl %s | filecheck --match-full-lines %s
+// RUN: xdsl-opt -p lower-csl-wrapper %s | filecheck --match-full-lines %s
 
 builtin.module {
 "csl_wrapper.module"() <{
diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py
index 80b88042e5..e8bd1b88eb 100644
--- a/xdsl/dialects/csl/csl_stencil.py
+++ b/xdsl/dialects/csl/csl_stencil.py
@@ -156,10 +156,10 @@ class ApplyOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
     @classmethod
     def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
         from xdsl.transforms.canonicalization_patterns.csl_stencil import (
-            RedundantIterArgInitialisation,
+            RedundantAccumulatorInitialisation,
         )
 
-        return (RedundantIterArgInitialisation(),)
+        return (RedundantAccumulatorInitialisation(),)
 
 
 @irdl_op_definition
@@ -169,15 +169,15 @@ class ApplyOp(IRDLOperation):
     with a `stencil.apply` (a stencil function plus parameters and applies the stencil function to the output temp).
 
     As communication may be done in chunks, this operation provides two regions for computation:
-    - the `chunk_reduce` region to reduce a chunk of data received from several neighbours to one chunk of data.
+    - the `receive_chunk` region to reduce a chunk of data received from several neighbours to one chunk of data.
       this region is invoked once per communicated chunk and effectively acts as a loop body.
-      It uses `iter_arg` to concatenate the chunks
-    - the `post_process` region (invoked once when communication has finished) that takes the concatenated
-      chunk of the `chunk_reduce` region and applies any further processing here - for instance, it may handle
+      It uses `accumulator` to concatenate the chunks
+    - the `done_exchange` region (invoked once when communication has finished) that takes the concatenated
+      chunk of the `receive_chunk` region and applies any further processing here - for instance, it may handle
       the computation of 'own' (non-communicated) or otherwise prefetched data
 
     Further fields:
-    - `communicated_stencil` - the stencil to communicate (send and receive)
+    - `field` - the stencil field to communicate (send and receive)
     - `args` - arguments to the stencil computation, may include other prefetched buffers
     - `topo` - as received from `csl_stencil.prefetch`/`dmp.swap`
    - `num_chunks` - number of chunks into which to slice the communication
@@ -187,32 +187,32 @@
     Function signatures:
 
     Before lowering (from `csl_stencil.prefetch` and `stencil.apply`):
-      %pref = csl_stencil.prefetch(%communicated_stencil : stencil.Temp)
-      stencil.apply( ..some args.. , %communicated_stencil, ..some more args.., %pref)
+      %pref = csl_stencil.prefetch(%field : stencil.Temp)
+      stencil.apply( ..some args.. , %field, ..some more args.., %pref)
 
     After lowering:
-      op: csl_stencil.apply(%communicated_stencil, %iter_arg, chunk_reduce_args..., post_process_args...)
-      chunk_reduce: block_args(slice of type(%pref), %offset, %iter_arg, args...)
-      post_process: block_args(%communicated_stencil, %iter_arg, args...)
+      op: csl_stencil.apply(%field, %accumulator, receive_chunk_args..., done_exchange_args...)
+      receive_chunk: block_args(slice of type(%pref), %offset, %accumulator, args...)
+      done_exchange: block_args(%field, %accumulator, args...)
 
     Note that %pref can be dropped (as communication is done by the op rather than before the op),
-    and that a new %iter_arg is required, an empty tensor which is filled by `chunk_reduce` and
-    consumed by `post_process`
+    and that a new %accumulator is required, an empty tensor which is filled by `receive_chunk` and
+    consumed by `done_exchange`
     """

     name = "csl_stencil.apply"

-    communicated_stencil = operand_def(
+    field = operand_def(
         base(stencil.StencilType[Attribute]) | base(memref.MemRefType[Attribute])
     )

-    iter_arg = operand_def(TensorType[Attribute] | memref.MemRefType[Attribute])
+    accumulator = operand_def(TensorType[Attribute] | memref.MemRefType[Attribute])

     args = var_operand_def(Attribute)

     dest = var_operand_def(stencil.FieldType | memref.MemRefType[Attribute])

-    chunk_reduce = region_def()
-    post_process = region_def()
+    receive_chunk = region_def()
+    done_exchange = region_def()

     swaps = prop_def(builtin.ArrayAttr[ExchangeDeclarationAttr])
@@ -243,7 +243,7 @@ def print_arg(arg: SSAValue):
         printer.print("(")

         # args required by function signature, plus optional args for regions
-        args = [self.communicated_stencil, self.iter_arg, *self.args]
+        args = [self.field, self.accumulator, *self.args]
         printer.print_list(args, print_arg)

         if self.dest:
@@ -259,9 +259,9 @@ def print_arg(arg: SSAValue):
         printer.print("> ")
         printer.print_op_attributes(self.attributes, print_keyword=True)
         printer.print("(")
-        printer.print_region(self.chunk_reduce, print_entry_block_args=True)
+        printer.print_region(self.receive_chunk, print_entry_block_args=True)
         printer.print(", ")
-        printer.print_region(self.post_process, print_entry_block_args=True)
+        printer.print_region(self.done_exchange, print_entry_block_args=True)
         printer.print(")")
         if self.bounds is not None:
             printer.print(" to ")
@@ -298,9 +298,9 @@ def parse_args():
         if attrs is not None:
             attrs = attrs.data
         parser.parse_punctuation("(")
-        chunk_reduce = parser.parse_region()
+        receive_chunk = parser.parse_region()
         parser.parse_punctuation(",")
-        post_process = parser.parse_region()
+        done_exchange = parser.parse_region()
         parser.parse_punctuation(")")
         if parser.parse_optional_keyword("to"):
             props["bounds"] = stencil.StencilBoundsAttr.new(
@@ -309,7 +309,7 @@ def parse_args():
         return cls(
             operands=[operands[0], operands[1], operands[2:], destinations],
             result_types=[result_types],
-            regions=[chunk_reduce, post_process],
+            regions=[receive_chunk, done_exchange],
             properties=props,
             attributes=attrs,
         )
@@ -317,15 +317,15 @@ def parse_args():
     def verify_(self) -> None:
         # typecheck op arguments
         if (
-            len(self.chunk_reduce.block.args) < 3
-            or len(self.post_process.block.args) < 2
+            len(self.receive_chunk.block.args) < 3
+            or len(self.done_exchange.block.args) < 2
         ):
             raise VerifyException("Missing required block args on region")
         op_args = (
-            self.post_process.block.args[0],
-            self.chunk_reduce.block.args[2],
-            *self.chunk_reduce.block.args[3:],
-            *self.post_process.block.args[2:],
+            self.done_exchange.block.args[0],
+            self.receive_chunk.block.args[2],
+            *self.receive_chunk.block.args[3:],
+            *self.done_exchange.block.args[2:],
         )
         for operand, argument in zip(self.operands, op_args):
             if operand.type != argument.type:
@@ -335,36 +335,36 @@ def verify_(self) -> None:

         # typecheck required (only) block arguments
         assert isa(
-            self.iter_arg.type, TensorType[Attribute] | memref.MemRefType[Attribute]
+            self.accumulator.type, TensorType[Attribute] | memref.MemRefType[Attribute]
         )
-        chunk_reduce_req_types = [
-            type(self.iter_arg.type)(
-                self.iter_arg.type.get_element_type(),
+        chunk_region_req_types = [
+            type(self.accumulator.type)(
+                self.accumulator.type.get_element_type(),
                 (
                     len(self.swaps),
-                    self.iter_arg.type.get_shape()[0] // self.num_chunks.value.data,
+                    self.accumulator.type.get_shape()[0] // self.num_chunks.value.data,
                 ),
             ),
             IndexType(),
-            self.iter_arg.type,
+            self.accumulator.type,
         ]
-        post_process_req_types = [
-            self.communicated_stencil.type,
-            self.iter_arg.type,
+        done_exchange_req_types = [
+            self.field.type,
+            self.accumulator.type,
         ]
         for arg, expected_type in zip(
-            self.chunk_reduce.block.args, chunk_reduce_req_types
+            self.receive_chunk.block.args, chunk_region_req_types
        ):
             if arg.type != expected_type:
                 raise VerifyException(
-                    f"Unexpected block argument type of chunk_reduce, got {arg.type} != {expected_type} at index {arg.index}"
+                    f"Unexpected block argument type of receive_chunk, got {arg.type} != {expected_type} at index {arg.index}"
                 )
         for arg, expected_type in zip(
-            self.post_process.block.args, post_process_req_types
+            self.done_exchange.block.args, done_exchange_req_types
        ):
             if arg.type != expected_type:
                 raise VerifyException(
-                    f"Unexpected block argument type of post_process, got {arg.type} != {expected_type} at index {arg.index}"
+                    f"Unexpected block argument type of done_exchange, got {arg.type} != {expected_type} at index {arg.index}"
                 )

         if (len(self.res) == 0) == (len(self.dest) == 0):
@@ -393,7 +393,7 @@ def get_accesses(self) -> Iterable[stencil.AccessPattern]:
         field of the apply operation.
         """
         # iterate over the block arguments
-        for arg in self.chunk_reduce.block.args + self.post_process.block.args:
+        for arg in self.receive_chunk.block.args + self.done_exchange.block.args:
             accesses: list[tuple[int, ...]] = []
             # walk the uses of the argument
             for use in arg.uses:
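To make the renamed contract above concrete: the leading block-argument types that `verify_` now checks are fully determined by the accumulator type, the swap count, and `num_chunks`. A hedged sketch, not from the patch; the concrete shapes are assumptions chosen to match the test above:

    # Illustrative only -- mirrors the chunk_region_req_types logic in verify_.
    from xdsl.dialects.builtin import Float32Type, IndexType, TensorType

    accumulator_t = TensorType(Float32Type(), (510,))  # type of %accumulator
    num_swaps, num_chunks = 4, 2                       # assumed swap/chunk counts

    receive_chunk_req_types = [
        # arg 0: one chunk slice per neighbour, i.e. tensor<4x255xf32>
        TensorType(
            accumulator_t.get_element_type(),
            (num_swaps, accumulator_t.get_shape()[0] // num_chunks),
        ),
        IndexType(),    # arg 1: %offset
        accumulator_t,  # arg 2: %accumulator
    ]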
diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py
index 9f1d1af58f..e80c9d6b83 100644
--- a/xdsl/tools/command_line_tool.py
+++ b/xdsl/tools/command_line_tool.py
@@ -106,10 +106,10 @@ def get_csl_stencil_to_csl_wrapper():
 
         return csl_stencil_to_csl_wrapper.CslStencilToCslWrapperPass
 
-    def get_csl_wrapper_to_csl():
-        from xdsl.transforms import csl_wrapper_to_csl
+    def get_lower_csl_wrapper():
+        from xdsl.transforms import lower_csl_wrapper
 
-        return csl_wrapper_to_csl.CslWrapperToCslPass
+        return lower_csl_wrapper.LowerCslWrapperPass
 
     def get_csl_wrapper_hoist_buffers():
         from xdsl.transforms import csl_wrapper_hoist_buffers
@@ -390,10 +390,10 @@ def get_stencil_tensorize_z_dimension():
 
         return stencil_tensorize_z_dimension.StencilTensorizeZDimension
 
-    def get_stencil_to_csl_stencil():
-        from xdsl.transforms import stencil_to_csl_stencil
+    def get_convert_stencil_to_csl_stencil():
+        from xdsl.transforms import convert_stencil_to_csl_stencil
 
-        return stencil_to_csl_stencil.StencilToCslStencilPass
+        return convert_stencil_to_csl_stencil.ConvertStencilToCslStencilPass
 
     def get_stencil_unroll():
         from xdsl.transforms import stencil_unroll
@@ -436,12 +436,12 @@ def get_stencil_bufferize():
         "convert-scf-to-openmp": get_convert_scf_to_openmp,
         "convert-scf-to-riscv-scf": get_convert_scf_to_riscv_scf,
         "convert-snitch-stream-to-snitch": get_convert_snitch_stream_to_snitch,
+        "convert-stencil-to-csl-stencil": get_convert_stencil_to_csl_stencil,
         "inline-snrt": get_convert_snrt_to_riscv,
         "convert-stencil-to-ll-mlir": get_convert_stencil_to_ll_mlir,
         "cse": get_cse,
         "csl-stencil-bufferize": get_csl_stencil_bufferize,
         "csl-stencil-to-csl-wrapper": get_csl_stencil_to_csl_wrapper,
-        "csl-wrapper-to-csl": get_csl_wrapper_to_csl,
         "csl-wrapper-hoist-buffers": get_csl_wrapper_hoist_buffers,
         "csl-stencil-handle-async-flow": get_csl_stencil_handle_async_flow,
         "dce": get_dce,
@@ -457,6 +457,7 @@ def get_stencil_bufferize():
         "linalg-to-csl": get_linalg_to_csl,
         "lower-affine": get_lower_affine,
         "lower-csl-stencil": get_lower_csl_stencil,
+        "lower-csl-wrapper": get_lower_csl_wrapper,
         "lower-hls": get_lower_hls,
         "lower-mpi": get_lower_mpi,
         "lower-riscv-func": get_lower_riscv_func,
@@ -485,7 +486,6 @@ def get_stencil_bufferize():
         "shape-inference": get_shape_inference,
         "stencil-storage-materialization": get_stencil_storage_materialization,
         "stencil-tensorize-z-dimension": get_stencil_tensorize_z_dimension,
-        "stencil-to-csl-stencil": get_stencil_to_csl_stencil,
         "stencil-unroll": get_stencil_unroll,
         "stencil-bufferize": get_stencil_bufferize,
         "test-lower-linalg-to-snitch": get_test_lower_linalg_to_snitch,
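The table above is consumed by name lookup, so this rename is user-visible. A hedged usage sketch; it assumes the surrounding function is xDSL's `get_all_passes`, which returns this name-to-factory mapping:

    # Illustrative only: resolve the renamed pass by its new CLI name.
    from xdsl.tools.command_line_tool import get_all_passes

    pass_factory = get_all_passes()["convert-stencil-to-csl-stencil"]
    print(pass_factory().name)  # -> convert-stencil-to-csl-stencil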
diff --git a/xdsl/transforms/canonicalization_patterns/csl_stencil.py b/xdsl/transforms/canonicalization_patterns/csl_stencil.py
index 4fe5896d85..362b1fc928 100644
--- a/xdsl/transforms/canonicalization_patterns/csl_stencil.py
+++ b/xdsl/transforms/canonicalization_patterns/csl_stencil.py
@@ -8,7 +8,7 @@
 )
 
 
-class RedundantIterArgInitialisation(RewritePattern):
+class RedundantAccumulatorInitialisation(RewritePattern):
     """
     Removes redundant allocations of empty tensors with no uses other than
     passed as `iter_arg` to `csl_stencil.apply`. Prefer re-use where possible.
@@ -18,16 +18,16 @@ class RedundantIterArgInitialisation(RewritePattern):
     def match_and_rewrite(
         self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter
     ) -> None:
-        if len(op.iter_arg.uses) > 1:
+        if len(op.accumulator.uses) > 1:
             return
 
         next_apply = op
         while (next_apply := next_apply.next_op) is not None:
             if (
                 isinstance(next_apply, csl_stencil.ApplyOp)
-                and len(next_apply.iter_arg.uses) == 1
-                and isinstance(next_apply.iter_arg, OpResult)
-                and isinstance(next_apply.iter_arg.op, tensor.EmptyOp)
-                and op.iter_arg.type == next_apply.iter_arg.type
+                and len(next_apply.accumulator.uses) == 1
+                and isinstance(next_apply.accumulator, OpResult)
+                and isinstance(next_apply.accumulator.op, tensor.EmptyOp)
+                and op.accumulator.type == next_apply.accumulator.type
             ):
-                rewriter.replace_op(next_apply.iter_arg.op, [], [op.iter_arg])
+                rewriter.replace_op(next_apply.accumulator.op, [], [op.accumulator])
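Only the class name changes here; behaviour is untouched. For orientation, a hedged sketch of running just this pattern standalone. It assumes `module` is a `builtin.ModuleOp` holding consecutive `csl_stencil.apply` ops whose accumulators are separate, identically-typed, single-use `tensor.empty` ops:

    # Illustrative only: after the walk, the later tensor.empty is dropped
    # and both applies share one accumulator buffer.
    from xdsl.pattern_rewriter import PatternRewriteWalker
    from xdsl.transforms.canonicalization_patterns.csl_stencil import (
        RedundantAccumulatorInitialisation,
    )

    PatternRewriteWalker(RedundantAccumulatorInitialisation()).rewrite_module(module)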
""" @@ -376,18 +376,18 @@ def get_prefetch_overhead(o: OpResult): prefetch = max(candidate_prefetches)[1] prefetch_idx = op.operands.index(prefetch) assert isinstance(prefetch.op, csl_stencil.PrefetchOp) - communicated_stencil_idx = op.operands.index(prefetch.op.input_stencil) + field_idx = op.operands.index(prefetch.op.input_stencil) assert isinstance(prefetch.op, csl_stencil.PrefetchOp) assert isa(prefetch.type, TensorType[Attribute]) - communicated_stencil_op_arg = prefetch.op.input_stencil + field_op_arg = prefetch.op.input_stencil - # add empty tensor before op to be used as `iter_arg` + # add empty tensor before op to be used as `accumulator` # this could potentially be re-used if we have one of the same size lying around - iter_arg = tensor.EmptyOp( + accumulator = tensor.EmptyOp( (), TensorType(prefetch.type.get_element_type(), prefetch.type.get_shape()[1:]), ) - rewriter.insert_op(iter_arg, InsertPoint.before(op)) + rewriter.insert_op(accumulator, InsertPoint.before(op)) # run pass (on this apply's region only) to consume data from `prefetch` accesses first nested_rewriter = PatternRewriteWalker( @@ -397,39 +397,39 @@ def get_prefetch_overhead(o: OpResult): nested_rewriter.rewrite_op(op) # determine how ops should be split across the two regions - chunk_reduce_ops, post_process_ops = get_op_split( + chunk_region_ops, done_exchange_ops = get_op_split( list(op.region.block.ops), op.region.block.args[prefetch_idx] ) - # fetch what chunk_reduce is computing for - if isinstance(chunk_reduce_ops[-1], stencil.ReturnOp): - chunk_res = chunk_reduce_ops[-1].operands[0] + # fetch what receive_chunk is computing for + if isinstance(chunk_region_ops[-1], stencil.ReturnOp): + chunk_res = chunk_region_ops[-1].operands[0] else: - chunk_res = chunk_reduce_ops[-1].results[0] + chunk_res = chunk_region_ops[-1].results[0] # after region split, check which block args (from the old ops block) are being accessed in each of the new regions # ignore accesses block args which already are part of the region's required signature - chunk_reduce_used_block_args = sorted( + chunk_region_used_block_args = sorted( set( x - for o in chunk_reduce_ops + for o in chunk_region_ops for x in o.operands if isinstance(x, BlockArgument) and x.index != prefetch_idx ), key=lambda b: b.index, ) - post_process_used_block_args = sorted( + done_exchange_used_block_args = sorted( set( x - for o in post_process_ops + for o in done_exchange_ops for x in o.operands - if isinstance(x, BlockArgument) and x.index != communicated_stencil_idx + if isinstance(x, BlockArgument) and x.index != field_idx ), key=lambda b: b.index, ) # set up region signatures, comprising fixed and optional args - see docs on `csl_stencil.apply` for details - chunk_reduce_args = [ + chunk_region_args = [ # required arg 0: slice of type(%prefetch) TensorType( prefetch.type.get_element_type(), @@ -440,81 +440,81 @@ def get_prefetch_overhead(o: OpResult): ), # required arg 1: %offset IndexType(), - # required arg 2: %iter_arg - iter_arg.results[0].type, + # required arg 2: %accumulator + accumulator.tensor.type, # optional args: as needed by the ops - *[a.type for a in chunk_reduce_used_block_args], + *[a.type for a in chunk_region_used_block_args], ] - post_process_args = [ + done_exchange_args = [ # required arg 0: stencil.temp to access own data - communicated_stencil_op_arg.type, - # required arg 1: %iter_arg - iter_arg.results[0].type, + field_op_arg.type, + # required arg 1: %accumulator + accumulator.tensor.type, # optional args: as needed by the ops - 
*[a.type for a in post_process_used_block_args], + *[a.type for a in done_exchange_used_block_args], ] # set up two regions - chunk_reduce = Region(Block(arg_types=chunk_reduce_args)) - post_process = Region(Block(arg_types=post_process_args)) + receive_chunk = Region(Block(arg_types=chunk_region_args)) + done_exchange = Region(Block(arg_types=done_exchange_args)) # translate old to new block arg index for optional args - chunk_reduce_oprnd_table = dict[Operand, Operand]( - (old, chunk_reduce.block.args[idx]) - for idx, old in enumerate(chunk_reduce_used_block_args, start=3) + chunk_region_oprnd_table = dict[Operand, Operand]( + (old, receive_chunk.block.args[idx]) + for idx, old in enumerate(chunk_region_used_block_args, start=3) ) - post_process_oprnd_table = dict[Operand, Operand]( - (old, post_process.block.args[idx]) - for idx, old in enumerate(post_process_used_block_args, start=2) + done_exchange_oprnd_table = dict[Operand, Operand]( + (old, done_exchange.block.args[idx]) + for idx, old in enumerate(done_exchange_used_block_args, start=2) ) - # add translation from old to new arg index for non-optional args - note, access to iter_arg must be handled separately below - chunk_reduce_oprnd_table[op.region.block.args[prefetch_idx]] = ( - chunk_reduce.block.args[0] + # add translation from old to new arg index for non-optional args - note, access to accumulator must be handled separately below + chunk_region_oprnd_table[op.region.block.args[prefetch_idx]] = ( + receive_chunk.block.args[0] ) - post_process_oprnd_table[op.region.block.args[communicated_stencil_idx]] = ( - post_process.block.args[0] + done_exchange_oprnd_table[op.region.block.args[field_idx]] = ( + done_exchange.block.args[0] ) - post_process_oprnd_table[chunk_res] = post_process.block.args[1] + done_exchange_oprnd_table[chunk_res] = done_exchange.block.args[1] # detach ops from old region for o in op.region.block.ops: op.region.block.detach_op(o) - # add operations from list to chunk_reduce, use translation table to rebuild operands - for o in chunk_reduce_ops: + # add operations from list to receive_chunk, use translation table to rebuild operands + for o in chunk_region_ops: if isinstance(o, stencil.ReturnOp | csl_stencil.YieldOp): break - o.operands = [chunk_reduce_oprnd_table.get(x, x) for x in o.operands] - chunk_reduce.block.add_op(o) + o.operands = [chunk_region_oprnd_table.get(x, x) for x in o.operands] + receive_chunk.block.add_op(o) - # put `chunk_res` into `iter_arg` (using tensor.insert_slice) and yield the result - chunk_reduce.block.add_ops( + # put `chunk_res` into `accumulator` (using tensor.insert_slice) and yield the result + receive_chunk.block.add_ops( [ insert_slice_op := tensor.InsertSliceOp.get( source=chunk_res, - dest=chunk_reduce.block.args[2], - offsets=(chunk_reduce.block.args[1],), + dest=receive_chunk.block.args[2], + offsets=(receive_chunk.block.args[1],), static_sizes=(prefetch.type.get_shape()[1] // self.num_chunks,), ), csl_stencil.YieldOp(insert_slice_op.result), ] ) - # add operations from list to post_process, use translation table to rebuild operands - for o in post_process_ops: - o.operands = [post_process_oprnd_table.get(x, x) for x in o.operands] - post_process.block.add_op(o) + # add operations from list to done_exchange, use translation table to rebuild operands + for o in done_exchange_ops: + o.operands = [done_exchange_oprnd_table.get(x, x) for x in o.operands] + done_exchange.block.add_op(o) if isinstance(o, stencil.ReturnOp): rewriter.replace_op(o, 
csl_stencil.YieldOp(*o.operands)) rewriter.replace_matched_op( csl_stencil.ApplyOp( operands=[ - communicated_stencil_op_arg, - iter_arg, - [op.operands[a.index] for a in chunk_reduce_used_block_args] - + [op.operands[a.index] for a in post_process_used_block_args], + field_op_arg, + accumulator, + [op.operands[a.index] for a in chunk_region_used_block_args] + + [op.operands[a.index] for a in done_exchange_used_block_args], op.dest, ], properties={ @@ -524,8 +524,8 @@ def get_prefetch_overhead(o: OpResult): "bounds": op.bounds, }, regions=[ - chunk_reduce, - post_process, + receive_chunk, + done_exchange, ], result_types=[op.result_types], ) @@ -536,8 +536,8 @@ def get_prefetch_overhead(o: OpResult): @dataclass(frozen=True) -class StencilToCslStencilPass(ModulePass): - name = "stencil-to-csl-stencil" +class ConvertStencilToCslStencilPass(ModulePass): + name = "convert-stencil-to-csl-stencil" # chunks into which to slice communication num_chunks: int = 1 diff --git a/xdsl/transforms/csl_stencil_bufferize.py b/xdsl/transforms/csl_stencil_bufferize.py index 5024bdf29c..42f087be42 100644 --- a/xdsl/transforms/csl_stencil_bufferize.py +++ b/xdsl/transforms/csl_stencil_bufferize.py @@ -77,12 +77,12 @@ class ApplyOpBufferize(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /): - if isa(op.iter_arg.type, memref.MemRefType[Attribute]): + if isa(op.accumulator.type, memref.MemRefType[Attribute]): return # convert args buf_args: list[SSAValue] = [] - to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.iter_arg)] + to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.accumulator)] for arg in op.args: if isa(arg.type, TensorType[Attribute]): to_memrefs.append(new_arg := to_memref_op(arg)) @@ -92,64 +92,64 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, # create new op buf_apply_op = csl_stencil.ApplyOp( - operands=[op.communicated_stencil, buf_iter_arg.memref, op.args, op.dest], + operands=[op.field, buf_iter_arg.memref, op.args, op.dest], result_types=op.res.types or [[]], regions=[ - self._get_empty_bufferized_region(op.chunk_reduce.block.args), - self._get_empty_bufferized_region(op.post_process.block.args), + self._get_empty_bufferized_region(op.receive_chunk.block.args), + self._get_empty_bufferized_region(op.done_exchange.block.args), ], properties=op.properties, attributes=op.attributes, ) # insert to_tensor ops and create arg mappings for block inlining - chunk_reduce_arg_mapping: Sequence[SSAValue] = [] + chunk_region_arg_mapping: Sequence[SSAValue] = [] for idx, (old_arg, arg) in enumerate( - zip(op.chunk_reduce.block.args, buf_apply_op.chunk_reduce.block.args) + zip(op.receive_chunk.block.args, buf_apply_op.receive_chunk.block.args) ): # arg0 has special meaning and does not need a `to_tensor` op if isattr(old_arg.type, TensorType) and idx != 0: rewriter.insert_op( # ensure iter_arg is writable t := to_tensor_op(arg, writable=idx == 2), - InsertPoint.at_end(buf_apply_op.chunk_reduce.block), + InsertPoint.at_end(buf_apply_op.receive_chunk.block), ) - chunk_reduce_arg_mapping.append(t.tensor) + chunk_region_arg_mapping.append(t.tensor) else: - chunk_reduce_arg_mapping.append(arg) + chunk_region_arg_mapping.append(arg) - post_process_arg_mapping: Sequence[SSAValue] = [] + done_exchange_arg_mapping: Sequence[SSAValue] = [] for idx, (old_arg, arg) in enumerate( - zip(op.post_process.block.args, buf_apply_op.post_process.block.args) + zip(op.done_exchange.block.args, 
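The two operand-translation tables above reduce to one idiom: look each operand up in the old-to-new mapping and fall back to the operand itself. A toy sketch with strings standing in for `Operand` values (hypothetical names, not taken from the pass):

    # Illustrative only: dict.get(x, x) keeps operands that need no rebinding.
    old_to_new = {"old_block_arg": "receive_chunk_arg0"}
    operands = ["old_block_arg", "some_constant"]
    print([old_to_new.get(x, x) for x in operands])
    # -> ['receive_chunk_arg0', 'some_constant']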
diff --git a/xdsl/transforms/csl_stencil_bufferize.py b/xdsl/transforms/csl_stencil_bufferize.py
index 5024bdf29c..42f087be42 100644
--- a/xdsl/transforms/csl_stencil_bufferize.py
+++ b/xdsl/transforms/csl_stencil_bufferize.py
@@ -77,12 +77,12 @@ class ApplyOpBufferize(RewritePattern):
 
     @op_type_rewrite_pattern
     def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /):
-        if isa(op.iter_arg.type, memref.MemRefType[Attribute]):
+        if isa(op.accumulator.type, memref.MemRefType[Attribute]):
             return
 
         # convert args
         buf_args: list[SSAValue] = []
-        to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.iter_arg)]
+        to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.accumulator)]
         for arg in op.args:
             if isa(arg.type, TensorType[Attribute]):
                 to_memrefs.append(new_arg := to_memref_op(arg))
@@ -92,64 +92,64 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
 
         # create new op
         buf_apply_op = csl_stencil.ApplyOp(
-            operands=[op.communicated_stencil, buf_iter_arg.memref, op.args, op.dest],
+            operands=[op.field, buf_iter_arg.memref, op.args, op.dest],
             result_types=op.res.types or [[]],
             regions=[
-                self._get_empty_bufferized_region(op.chunk_reduce.block.args),
-                self._get_empty_bufferized_region(op.post_process.block.args),
+                self._get_empty_bufferized_region(op.receive_chunk.block.args),
+                self._get_empty_bufferized_region(op.done_exchange.block.args),
             ],
             properties=op.properties,
             attributes=op.attributes,
         )
 
         # insert to_tensor ops and create arg mappings for block inlining
-        chunk_reduce_arg_mapping: Sequence[SSAValue] = []
+        chunk_region_arg_mapping: Sequence[SSAValue] = []
         for idx, (old_arg, arg) in enumerate(
-            zip(op.chunk_reduce.block.args, buf_apply_op.chunk_reduce.block.args)
+            zip(op.receive_chunk.block.args, buf_apply_op.receive_chunk.block.args)
        ):
             # arg0 has special meaning and does not need a `to_tensor` op
             if isattr(old_arg.type, TensorType) and idx != 0:
                 rewriter.insert_op(
                     # ensure iter_arg is writable
                     t := to_tensor_op(arg, writable=idx == 2),
-                    InsertPoint.at_end(buf_apply_op.chunk_reduce.block),
+                    InsertPoint.at_end(buf_apply_op.receive_chunk.block),
                 )
-                chunk_reduce_arg_mapping.append(t.tensor)
+                chunk_region_arg_mapping.append(t.tensor)
             else:
-                chunk_reduce_arg_mapping.append(arg)
+                chunk_region_arg_mapping.append(arg)
 
-        post_process_arg_mapping: Sequence[SSAValue] = []
+        done_exchange_arg_mapping: Sequence[SSAValue] = []
         for idx, (old_arg, arg) in enumerate(
-            zip(op.post_process.block.args, buf_apply_op.post_process.block.args)
+            zip(op.done_exchange.block.args, buf_apply_op.done_exchange.block.args)
        ):
             if isattr(old_arg.type, TensorType):
                 rewriter.insert_op(
                     # ensure iter_arg is writable
                     t := to_tensor_op(arg, writable=idx == 1),
-                    InsertPoint.at_end(buf_apply_op.post_process.block),
+                    InsertPoint.at_end(buf_apply_op.done_exchange.block),
                 )
-                post_process_arg_mapping.append(t.tensor)
+                done_exchange_arg_mapping.append(t.tensor)
             else:
-                post_process_arg_mapping.append(arg)
+                done_exchange_arg_mapping.append(arg)
 
-        assert isa(typ := op.chunk_reduce.block.args[0].type, TensorType[Attribute])
+        assert isa(typ := op.receive_chunk.block.args[0].type, TensorType[Attribute])
         chunk_type = TensorType(typ.get_element_type(), typ.get_shape()[1:])
 
         # inline blocks from old into new regions
         rewriter.inline_block(
-            op.chunk_reduce.block,
-            InsertPoint.at_end(buf_apply_op.chunk_reduce.block),
-            chunk_reduce_arg_mapping,
+            op.receive_chunk.block,
+            InsertPoint.at_end(buf_apply_op.receive_chunk.block),
+            chunk_region_arg_mapping,
         )
         rewriter.inline_block(
-            op.post_process.block,
-            InsertPoint.at_end(buf_apply_op.post_process.block),
-            post_process_arg_mapping,
+            op.done_exchange.block,
+            InsertPoint.at_end(buf_apply_op.done_exchange.block),
+            done_exchange_arg_mapping,
        )
 
         self._inject_iter_arg_into_linalg_outs(
-            buf_apply_op, rewriter, chunk_type, chunk_reduce_arg_mapping[2]
+            buf_apply_op, rewriter, chunk_type, chunk_region_arg_mapping[2]
        )
 
         # insert new op
@@ -185,7 +185,7 @@ def _inject_iter_arg_into_linalg_outs(
         and avoiding having an extra alloc + memref.copy.
         """
         linalg_op: linalg.NamedOpBase | None = None
-        for curr_op in op.chunk_reduce.block.ops:
+        for curr_op in op.receive_chunk.block.ops:
             if (
                 isinstance(curr_op, linalg.NamedOpBase)
                 and len(curr_op.outputs) > 0
@@ -201,7 +201,7 @@ def _inject_iter_arg_into_linalg_outs(
             linalg_op,
             [
                 extract_slice_op := tensor.ExtractSliceOp(
-                    operands=[iter_arg, [op.chunk_reduce.block.args[1]], [], []],
+                    operands=[iter_arg, [op.receive_chunk.block.args[1]], [], []],
                     result_types=[chunk_type],
                     properties={
                         "static_offsets": DenseArrayBase.from_list(
@@ -228,11 +228,11 @@ def _build_extract_slice(
     op: csl_stencil.ApplyOp, to_tensor: bufferization.ToTensorOp, offset: SSAValue
 ) -> tensor.ExtractSliceOp:
     """
-    Helper function to create an early tensor.extract_slice in the apply.chunk_reduce region needed for bufferization.
+    Helper function to create an early tensor.extract_slice in the apply.receive_chunk region needed for bufferization.
     """
 
     # this is the unbufferized `tensor<(neighbours)x(ZDim)x(type)>` value
-    assert isa(typ := op.chunk_reduce.block.args[0].type, TensorType[Attribute])
+    assert isa(typ := op.receive_chunk.block.args[0].type, TensorType[Attribute])
 
     return tensor.ExtractSliceOp(
         operands=[to_tensor.tensor, [offset], [], []],
@@ -369,7 +369,7 @@ class CslStencilBufferize(ModulePass):
     """
     Bufferizes the csl_stencil dialect.
 
-    Attempts to inject `csl_stencil.apply.chunk_reduce.iter_arg` into linalg compute ops `outs` within that region
+    Attempts to inject `csl_stencil.apply.receive_chunk.accumulator` into linalg compute ops `outs` within that region
     for improved bufferization. Ideally run after `--lift-arith-to-linalg`.
     """
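A small sketch of the chunk-type derivation used twice above (`typ.get_shape()[1:]` drops the leading neighbour dimension); the concrete shape is an assumption, not taken from the patch:

    # Illustrative only: per-chunk tensor type as computed in ApplyOpBufferize.
    from xdsl.dialects.builtin import Float32Type, TensorType

    arg0_t = TensorType(Float32Type(), (4, 255))  # receive_chunk arg 0: neighbours x chunk
    chunk_t = TensorType(arg0_t.get_element_type(), arg0_t.get_shape()[1:])
    print(chunk_t)  # -> tensor<255xf32>, the slice injected into linalg outs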
""" diff --git a/xdsl/transforms/csl_stencil_handle_async_flow.py b/xdsl/transforms/csl_stencil_handle_async_flow.py index 4519dc6d44..1f77062313 100644 --- a/xdsl/transforms/csl_stencil_handle_async_flow.py +++ b/xdsl/transforms/csl_stencil_handle_async_flow.py @@ -51,7 +51,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, if isinstance(op.next_op, csl.ReturnOp): return - terminator = op.post_process.block.last_op + terminator = op.done_exchange.block.last_op assert isinstance(terminator, csl_stencil.YieldOp) # case 2: apply is followed by call_op and return - move call_op to second callback of apply diff --git a/xdsl/transforms/csl_stencil_to_csl_wrapper.py b/xdsl/transforms/csl_stencil_to_csl_wrapper.py index c6049dc13b..2c58d75014 100644 --- a/xdsl/transforms/csl_stencil_to_csl_wrapper.py +++ b/xdsl/transforms/csl_stencil_to_csl_wrapper.py @@ -75,16 +75,16 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): raise ValueError("Stencil accesses must be 2-dimensional at this stage") # find max z dimension - we could get this from func args, store ops, or apply ops - # to support both bufferized and unbufferized csl_stencils, retrieve this from iter_arg - if isinstance(apply_op.post_process.block.args[1].type, ShapedType): + # to support both bufferized and unbufferized csl_stencils, retrieve this from accumulator + if isinstance(apply_op.done_exchange.block.args[1].type, ShapedType): z_dim_no_ghost_cells = max( z_dim_no_ghost_cells, - apply_op.post_process.block.args[1].type.get_shape()[-1], + apply_op.done_exchange.block.args[1].type.get_shape()[-1], ) - # retrieve z_dim from post_process arg[0] + # retrieve z_dim from done_exchange arg[0] if isa( - field_t := apply_op.post_process.block.args[0].type, + field_t := apply_op.done_exchange.block.args[0].type, stencil.StencilType[ TensorType[Attribute] | memref.MemRefType[Attribute] ], @@ -97,7 +97,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): num_chunks = max(num_chunks, apply_op.num_chunks.value.data) if isa( - buf_t := apply_op.chunk_reduce.block.args[0].type, + buf_t := apply_op.receive_chunk.block.args[0].type, TensorType[Attribute] | MemRefType[Attribute], ): chunk_size = max(chunk_size, buf_t.get_shape()[-1]) diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 69f1b4e303..1d1d3f1fa4 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -121,68 +121,71 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, ), "Expected csl_stencil.apply to be inside a func.func or csl.func" # set up csl funcs - reduce_fn = csl.FuncOp( - "chunk_reduce_cb" + str(self.count), FunctionType.from_lists([i16], []) + chunk_fn = csl.FuncOp( + "receive_chunk_cb" + str(self.count), FunctionType.from_lists([i16], []) ) - post_fn = csl.FuncOp( - "post_process_cb" + str(self.count), + chunk_fn.body.block.args[0].name_hint = "offset" + done_fn = csl.FuncOp( + "done_exchange_cb" + str(self.count), FunctionType.from_lists([], []), Region(Block()), ) self.count += 1 # the offset arg was of type index and is now i16, so it's cast back to index to be used in the func body - reduce_fn.body.block.add_op( + chunk_fn.body.block.add_op( index_op := arith.IndexCastOp( - reduce_fn.body.block.args[0], + chunk_fn.body.block.args[0], IndexType(), ) ) # arg maps for the regions - reduce_arg_m = [ - op.communicated_stencil, # buffer - this is a placeholder and should not be used 
after lowering AccessOp + chunk_arg_m = [ + op.field, # buffer - this is a placeholder and should not be used after lowering AccessOp index_op.result, - op.iter_arg, - *op.args[: len(op.chunk_reduce.block.args) - 3], + op.accumulator, + *op.args[: len(op.receive_chunk.block.args) - 3], ] - post_arg_m = [ - op.communicated_stencil, - op.iter_arg, - *op.args[len(reduce_arg_m) - 3 :], + done_arg_m = [ + op.field, + op.accumulator, + *op.args[len(chunk_arg_m) - 3 :], ] + index_op.result.name_hint = "offset" + op.accumulator.name_hint = "accumulator" # inlining both regions rewriter.inline_block( - op.chunk_reduce.block, - InsertPoint.at_end(reduce_fn.body.block), - reduce_arg_m, + op.receive_chunk.block, + InsertPoint.at_end(chunk_fn.body.block), + chunk_arg_m, ) rewriter.inline_block( - op.post_process.block, InsertPoint.at_end(post_fn.body.block), post_arg_m + op.done_exchange.block, InsertPoint.at_end(done_fn.body.block), done_arg_m ) # place both func next to the enclosing parent func - rewriter.insert_op([reduce_fn, post_fn], InsertPoint.after(parent_func)) + rewriter.insert_op([chunk_fn, done_fn], InsertPoint.after(parent_func)) # add api call num_chunks = arith.Constant(IntegerAttr(op.num_chunks.value, i16)) - reduce_ref = csl.AddressOfFnOp(reduce_fn) - post_ref = csl.AddressOfFnOp(post_fn) + chunk_ref = csl.AddressOfFnOp(chunk_fn) + done_ref = csl.AddressOfFnOp(done_fn) api_call = csl.MemberCallOp( "communicate", None, module_wrapper_op.get_program_import("stencil_comms.csl"), [ - op.communicated_stencil, + op.field, num_chunks, - reduce_ref, - post_ref, + chunk_ref, + done_ref, ], ) # replace op with api call - rewriter.replace_matched_op([num_chunks, reduce_ref, post_ref, api_call], []) + rewriter.replace_matched_op([num_chunks, chunk_ref, done_ref, api_call], []) @dataclass(frozen=True) @@ -198,7 +201,7 @@ def match_and_rewrite(self, op: csl_stencil.YieldOp, rewriter: PatternRewriter, assert isinstance(apply := op.parent_op(), csl_stencil.ApplyOp) # the second callback stores yielded values to dest - if op.parent_region() == apply.post_process: + if op.parent_region() == apply.done_exchange: views: list[Operation] = [] for src, dst in zip(op.arguments, apply.dest): assert isa(src.type, memref.MemRefType[Attribute]) diff --git a/xdsl/transforms/csl_wrapper_to_csl.py b/xdsl/transforms/lower_csl_wrapper.py similarity index 99% rename from xdsl/transforms/csl_wrapper_to_csl.py rename to xdsl/transforms/lower_csl_wrapper.py index 7a9ae567cf..b8b0cf65f2 100644 --- a/xdsl/transforms/csl_wrapper_to_csl.py +++ b/xdsl/transforms/lower_csl_wrapper.py @@ -310,10 +310,10 @@ def match_and_rewrite(self, op: csl_wrapper.ImportOp, rewriter: PatternRewriter, @dataclass(frozen=True) -class CslWrapperToCslPass(ModulePass): +class LowerCslWrapperPass(ModulePass): """Unwraps the `csl_wrappermodule` into two `csl.module`s.""" - name = "csl-wrapper-to-csl" + name = "lower-csl-wrapper" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: PatternRewriteWalker(
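End-to-end, the renames surface both on the command line (`-p "convert-stencil-to-csl-stencil{num_chunks=2}"` and `-p lower-csl-wrapper`, as in the updated RUN lines) and in the Python API. A hedged sketch of driving the two renamed pass classes programmatically; `ctx` and `module` are assumed to be an existing `MLContext` and `builtin.ModuleOp`, and the ordering follows the tests above rather than anything prescribed by this patch:

    # Illustrative only: invoke the renamed pass classes directly.
    from xdsl.transforms.convert_stencil_to_csl_stencil import (
        ConvertStencilToCslStencilPass,
    )
    from xdsl.transforms.lower_csl_wrapper import LowerCslWrapperPass

    ConvertStencilToCslStencilPass(num_chunks=2).apply(ctx, module)  # was stencil-to-csl-stencil
    LowerCslWrapperPass().apply(ctx, module)                         # was csl-wrapper-to-csl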