diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index 8031f867e9..e550a04561 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -238,12 +238,12 @@ builtin.module { // CHECK-NEXT: %0 = tensor.empty() : tensor<1x64xf32> // CHECK-NEXT: csl_stencil.apply(%arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) -> () <{"swaps" = [#csl_stencil.exchange], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%1 : tensor<1x32xf32>, %2 : index, %3 : tensor<1x64xf32>): -// CHECK-NEXT: %4 = csl_stencil.access %3[-1, 0] : tensor<1x64xf32> -// CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32> +// CHECK-NEXT: %4 = csl_stencil.access %1[-1, 0] : tensor<1x32xf32> +// CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32> // CHECK-NEXT: csl_stencil.yield %5 : tensor<1x64xf32> // CHECK-NEXT: }, { // CHECK-NEXT: ^1(%6 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %7 : tensor<1x64xf32>): -// CHECK-NEXT: csl_stencil.yield %7 : tensor<1x64xf32> +// CHECK-NEXT: csl_stencil.yield // CHECK-NEXT: }) // CHECK-NEXT: %1 = tensor.empty() : tensor<64xf32> // CHECK-NEXT: csl_stencil.apply(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %1 : tensor<64xf32>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) outs (%arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) <{"swaps" = [#csl_stencil.exchange], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 76981530d4..8551846aee 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -606,19 +606,21 @@ def match_and_rewrite( block = Block(arg_types=[chunk_buf_t, builtin.IndexType(), op.result.type]) block2 = Block(arg_types=[op.input_stencil.type, op.result.type]) - block2.add_op(csl_stencil.YieldOp(block2.args[1])) + block2.add_op(csl_stencil.YieldOp()) - with ImplicitBuilder(block) as (_, offset, acc): + with ImplicitBuilder(block) as (buf, offset, acc): dest = acc for i, acc_offset in enumerate(offsets): ac_op = csl_stencil.AccessOp( - dest, stencil.IndexAttr.get(*acc_offset), chunk_t + buf, stencil.IndexAttr.get(*acc_offset), chunk_t ) assert isa(ac_op.result.type, AnyTensorType) + # inserts 1 (see static_sizes) 1d slice into a 2d tensor at offset (i, `offset`) (see static_offsets) + # where the latter offset is provided dynamically (see offsets) dest = tensor.InsertSliceOp.get( source=ac_op.result, dest=dest, - static_sizes=ac_op.result.type.get_shape(), + static_sizes=[1, *ac_op.result.type.get_shape()], static_offsets=[i, memref.SubviewOp.DYNAMIC_INDEX], offsets=[offset], ).result diff --git a/xdsl/transforms/csl_stencil_bufferize.py b/xdsl/transforms/csl_stencil_bufferize.py index 92d2ea9708..bd4f46d9fe 100644 --- a/xdsl/transforms/csl_stencil_bufferize.py +++ b/xdsl/transforms/csl_stencil_bufferize.py @@ -94,6 +94,10 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, # convert args buf_args: list[SSAValue] = [] to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.accumulator)] + # in case of subsequent apply ops accessing this accumulator, replace uses with `bufferization.to_memref` + op.accumulator.replace_by_if( + buf_iter_arg.memref, lambda use: use.operation != buf_iter_arg + ) for arg in [*op.args_rchunk, *op.args_dexchng]: if isa(arg.type, TensorType[Attribute]): to_memrefs.append(new_arg := to_memref_op(arg)) @@ -385,6 +389,11 @@ def match_and_rewrite(self, op: arith.ConstantOp, rewriter: PatternRewriter, /): class InjectApplyOutsIntoLinalgOuts(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /): + # require bufferized apply (with op.dest specified) + # zero-output apply ops may be used for communicate-only, to which this pattern does not apply + if not op.dest: + return + yld = op.done_exchange.block.last_op assert isinstance(yld, csl_stencil.YieldOp) new_dest: list[SSAValue] = [] diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 473d6b3628..ddccef2eec 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -84,9 +84,13 @@ def get_required_result_type(op: Operation) -> TensorType[Attribute] | None: tuple[int, ...], ) ): + assert is_tensor(use.operation.source.type) + # inserting an (n-1)d tensor into an (n)d tensor should not require the input tensor to also be (n)d + # instead, drop the first `dimdiff` dimensions + dimdiff = len(static_sizes) - len(use.operation.source.type.shape) return TensorType( use.operation.result.type.get_element_type(), - static_sizes, + static_sizes[dimdiff:], ) for ret in use.operation.results: if isa(r_type := ret.type, TensorType[Attribute]): diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py index 98f4a59a7d..a14b06e2b3 100644 --- a/xdsl/transforms/memref_to_dsd.py +++ b/xdsl/transforms/memref_to_dsd.py @@ -133,7 +133,12 @@ def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /): last_op = stride_ops[-1] if len(stride_ops) > 0 else last_op offset_ops = self._update_offsets(op, last_op) - rewriter.replace_matched_op([*size_ops, *stride_ops, *offset_ops]) + new_ops = [*size_ops, *stride_ops, *offset_ops] + if new_ops: + rewriter.replace_matched_op([*size_ops, *stride_ops, *offset_ops]) + else: + # subview has no effect (todo: this could be canonicalized away) + rewriter.replace_matched_op([], new_results=[op.source]) @staticmethod def _update_sizes(