bug: (csl-lowering) Make multi-apply lowering work #3614

Merged: 11 commits, Dec 19, 2024
@@ -238,12 +238,12 @@ builtin.module {
// CHECK-NEXT: %0 = tensor.empty() : tensor<1x64xf32>
// CHECK-NEXT: csl_stencil.apply(%arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) -> () <{"swaps" = [#csl_stencil.exchange<to [-1, 0]>], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>}> ({
// CHECK-NEXT: ^0(%1 : tensor<1x32xf32>, %2 : index, %3 : tensor<1x64xf32>):
// CHECK-NEXT: %4 = csl_stencil.access %3[-1, 0] : tensor<1x64xf32>
// CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array<i64: 0, -9223372036854775808>, "static_sizes" = array<i64: 32>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32>
// CHECK-NEXT: %4 = csl_stencil.access %1[-1, 0] : tensor<1x32xf32>
// CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array<i64: 0, -9223372036854775808>, "static_sizes" = array<i64: 1, 32>, "static_strides" = array<i64: 1, 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32>
// CHECK-NEXT: csl_stencil.yield %5 : tensor<1x64xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%6 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %7 : tensor<1x64xf32>):
// CHECK-NEXT: csl_stencil.yield %7 : tensor<1x64xf32>
// CHECK-NEXT: csl_stencil.yield
// CHECK-NEXT: })
// CHECK-NEXT: %1 = tensor.empty() : tensor<64xf32>
// CHECK-NEXT: csl_stencil.apply(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %1 : tensor<64xf32>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) outs (%arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) <{"swaps" = [#csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 2, 1>}> ({
10 changes: 6 additions & 4 deletions xdsl/transforms/convert_stencil_to_csl_stencil.py
@@ -606,19 +606,21 @@ def match_and_rewrite(

block = Block(arg_types=[chunk_buf_t, builtin.IndexType(), op.result.type])
block2 = Block(arg_types=[op.input_stencil.type, op.result.type])
block2.add_op(csl_stencil.YieldOp(block2.args[1]))
block2.add_op(csl_stencil.YieldOp())

with ImplicitBuilder(block) as (_, offset, acc):
with ImplicitBuilder(block) as (buf, offset, acc):
dest = acc
for i, acc_offset in enumerate(offsets):
ac_op = csl_stencil.AccessOp(
dest, stencil.IndexAttr.get(*acc_offset), chunk_t
buf, stencil.IndexAttr.get(*acc_offset), chunk_t
Collaborator Author: Read from the buf, not from the dest (which is the accumulator).

)
assert isa(ac_op.result.type, AnyTensorType)
# inserts one 1d slice (see static_sizes) into a 2d tensor at offset (i, `offset`) (see static_offsets),
# where the latter offset is provided dynamically (see offsets)
dest = tensor.InsertSliceOp.get(
source=ac_op.result,
dest=dest,
static_sizes=ac_op.result.type.get_shape(),
static_sizes=[1, *ac_op.result.type.get_shape()],
Collaborator Author: This needs a 1 in the first dimension to insert one slice of z-values (ac_op.result.type.get_shape() is the shape of the z-values).

static_offsets=[i, memref.SubviewOp.DYNAMIC_INDEX],
offsets=[offset],
).result
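
To make the loop's insert parameters concrete, here is a standalone Python sketch (not xDSL code) using the shapes from the filecheck test above: each chunk read from `buf` is a 1d slice of z-values, and prepending a 1 to its shape lets it be inserted as one row of the 2d accumulator, at row `i` and a dynamically provided column offset.

# Standalone illustration (plain Python, not the pass itself): the
# tensor.insert_slice parameters for writing a 1d chunk into the 2d accumulator.
DYNAMIC = -(2**63)  # dynamic-index sentinel, visible in the CHECK line above
                    # (memref.SubviewOp.DYNAMIC_INDEX in the pass)

def insert_slice_params(chunk_shape: tuple[int, ...], row: int):
    static_sizes = [1, *chunk_shape]       # one row of z-values
    static_offsets = [row, DYNAMIC]        # column offset supplied at runtime
    static_strides = [1] * len(static_sizes)
    return static_offsets, static_sizes, static_strides

# One swap direction (row 0), chunks of 32 z-values, as in the test above:
print(insert_slice_params((32,), row=0))
# -> ([0, -9223372036854775808], [1, 32], [1, 1])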
9 changes: 9 additions & 0 deletions xdsl/transforms/csl_stencil_bufferize.py
@@ -94,6 +94,10 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
# convert args
buf_args: list[SSAValue] = []
to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.accumulator)]
# in case of subsequent apply ops accessing this accumulator, replace uses with `bufferization.to_memref`
op.accumulator.replace_by_if(
buf_iter_arg.memref, lambda use: use.operation != buf_iter_arg
)
Collaborator Author: This allows the bufferization.to_memref (buf_iter_arg) to be used by further apply ops, in case there is more than one.

for arg in [*op.args_rchunk, *op.args_dexchng]:
if isa(arg.type, TensorType[Attribute]):
to_memrefs.append(new_arg := to_memref_op(arg))
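
What the replace_by_if call achieves can be shown with a minimal stand-in sketch (plain Python with hypothetical Use/Value classes, not the real xDSL SSAValue API): every use of the accumulator is redirected to the bufferized value except the to_memref op that produces it, so a second apply op reading the same accumulator now reads the memref.

# Minimal stand-in for predicate-based use replacement; Use and Value are
# hypothetical classes, not xDSL's.
from dataclasses import dataclass, field
from typing import Callable

@dataclass
class Use:
    operation: str  # stand-in for the using operation

@dataclass
class Value:
    uses: list[Use] = field(default_factory=list)

def replace_by_if(old: Value, new: Value, pred: Callable[[Use], bool]) -> None:
    # move only the uses for which `pred` holds from `old` to `new`
    moved = [u for u in old.uses if pred(u)]
    old.uses = [u for u in old.uses if not pred(u)]
    new.uses.extend(moved)

# accumulator used by the to_memref op itself and by a second apply op
acc = Value([Use("to_memref"), Use("apply_2")])
memref_val = Value()
replace_by_if(acc, memref_val, lambda use: use.operation != "to_memref")
assert [u.operation for u in acc.uses] == ["to_memref"]
assert [u.operation for u in memref_val.uses] == ["apply_2"]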
@@ -385,6 +389,11 @@ def match_and_rewrite(self, op: arith.ConstantOp, rewriter: PatternRewriter, /):
class InjectApplyOutsIntoLinalgOuts(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /):
# require bufferized apply (with op.dest specified)
# zero-output apply ops may be used for communicate-only, to which this pattern does not apply
if not op.dest:
return

Collaborator Author: This adds support for zero-output apply ops.

yld = op.done_exchange.block.last_op
assert isinstance(yld, csl_stencil.YieldOp)
new_dest: list[SSAValue] = []
@@ -84,9 +84,13 @@ def get_required_result_type(op: Operation) -> TensorType[Attribute] | None:
tuple[int, ...],
)
):
assert is_tensor(use.operation.source.type)
# inserting an (n-1)d tensor into an (n)d tensor should not require the input tensor to also be (n)d
# instead, drop the first `dimdiff` dimensions
dimdiff = len(static_sizes) - len(use.operation.source.type.shape)
return TensorType(
use.operation.result.type.get_element_type(),
static_sizes,
static_sizes[dimdiff:],
Collaborator Author: Small fix for inserting a 1d slice into a 2d tensor.

)
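
A quick plain-Python check of the dimension arithmetic, using the shapes from the filecheck test in this PR (a tensor<32xf32> slice inserted into a tensor<1x64xf32> accumulator):

static_sizes = (1, 32)   # sizes on the tensor.insert_slice op
source_shape = (32,)     # shape of the 1d slice being inserted

dimdiff = len(static_sizes) - len(source_shape)   # 2 - 1 = 1
required_shape = static_sizes[dimdiff:]           # drop the leading 1

assert dimdiff == 1
assert required_shape == (32,)  # require tensor<32xf32>, not tensor<1x32xf32>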
for ret in use.operation.results:
if isa(r_type := ret.type, TensorType[Attribute]):
7 changes: 6 additions & 1 deletion xdsl/transforms/memref_to_dsd.py
@@ -133,7 +133,12 @@ def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
last_op = stride_ops[-1] if len(stride_ops) > 0 else last_op
offset_ops = self._update_offsets(op, last_op)

rewriter.replace_matched_op([*size_ops, *stride_ops, *offset_ops])
new_ops = [*size_ops, *stride_ops, *offset_ops]
if new_ops:
rewriter.replace_matched_op([*size_ops, *stride_ops, *offset_ops])
else:
# subview has no effect (todo: this could be canonicalized away)
rewriter.replace_matched_op([], new_results=[op.source])
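
The guard can be summarised by a standalone sketch (plain Python; lower_subview is a hypothetical stand-in for this pattern, not the pass's API): only emit replacement ops when the size, stride, and offset updates actually produced any; otherwise the subview is a no-op and its result simply forwards the source value.

# Hypothetical stand-in illustrating the guard above, not the real pattern.
def lower_subview(size_ops: list, stride_ops: list, offset_ops: list):
    new_ops = [*size_ops, *stride_ops, *offset_ops]
    if new_ops:
        # emit the computed ops and let them define the replacement results
        return ("replace_with_ops", new_ops)
    # nothing to compute: forward the source value as the result
    return ("forward_source", new_ops)

assert lower_subview([], [], []) == ("forward_source", [])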

@staticmethod
def _update_sizes(