
bug: (csl-lowering) Make multi-apply lowering work #3614

Merged: 11 commits, Dec 19, 2024
@@ -238,8 +238,8 @@ builtin.module {
// CHECK-NEXT: %0 = tensor.empty() : tensor<1x64xf32>
// CHECK-NEXT: csl_stencil.apply(%arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) -> () <{"swaps" = [#csl_stencil.exchange<to [-1, 0]>], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>}> ({
// CHECK-NEXT: ^0(%1 : tensor<1x32xf32>, %2 : index, %3 : tensor<1x64xf32>):
- // CHECK-NEXT: %4 = csl_stencil.access %3[-1, 0] : tensor<1x64xf32>
- // CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array<i64: 0, -9223372036854775808>, "static_sizes" = array<i64: 32>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32>
+ // CHECK-NEXT: %4 = csl_stencil.access %1[-1, 0] : tensor<1x32xf32>
+ // CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array<i64: 0, -9223372036854775808>, "static_sizes" = array<i64: 1, 32>, "static_strides" = array<i64: 1, 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32>
// CHECK-NEXT: csl_stencil.yield %5 : tensor<1x64xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%6 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %7 : tensor<1x64xf32>):
6 changes: 3 additions & 3 deletions xdsl/transforms/convert_stencil_to_csl_stencil.py
@@ -608,17 +608,17 @@ def match_and_rewrite(
block2 = Block(arg_types=[op.input_stencil.type, op.result.type])
block2.add_op(csl_stencil.YieldOp(block2.args[1]))

- with ImplicitBuilder(block) as (_, offset, acc):
+ with ImplicitBuilder(block) as (buf, offset, acc):
dest = acc
for i, acc_offset in enumerate(offsets):
ac_op = csl_stencil.AccessOp(
- dest, stencil.IndexAttr.get(*acc_offset), chunk_t
+ buf, stencil.IndexAttr.get(*acc_offset), chunk_t
Collaborator Author: Read from the buf, not from the dest (which is the accumulator). See the sketch after this hunk.

)
assert isa(ac_op.result.type, AnyTensorType)
dest = tensor.InsertSliceOp.get(
source=ac_op.result,
dest=dest,
- static_sizes=ac_op.result.type.get_shape(),
+ static_sizes=[1, *ac_op.result.type.get_shape()],
Collaborator Author: This needs a 1 in the first dimension to insert one slice of z-values (ac_op.result.type.get_shape() is the shape of the z-values). See the sketch after this hunk.

static_offsets=[i, memref.SubviewOp.DYNAMIC_INDEX],
offsets=[offset],
).result
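
The following is a rough NumPy analogue of the data movement this loop sets up, a sketch only and not the xdsl transform itself: the chunk size of 32 matches the test above, while the neighbour count and the column offset are assumptions made for illustration. Each iteration reads one neighbour's chunk of z-values from the receive buffer buf and inserts it as a [1, chunk_size] slice at row i of the 2-D accumulator, which is why the access must read buf rather than dest and why static_sizes needs the leading 1.

import numpy as np

# Illustrative sizes: 32 z-values per chunk (as in the test above); the neighbour
# count and the column offset are assumptions made for this sketch.
num_offsets = 2
chunk_size = 32
offset = 32  # stands in for the dynamic column offset passed into the block

buf = np.arange(num_offsets * chunk_size, dtype=np.float32).reshape(num_offsets, chunk_size)
acc = np.zeros((num_offsets, 2 * chunk_size), dtype=np.float32)

for i in range(num_offsets):
    z_values = buf[i]  # read from the receive buffer, not from the accumulator
    # insert a single [1, chunk_size] slice at (row i, column offset)
    acc[i : i + 1, offset : offset + chunk_size] = z_values
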
6 changes: 6 additions & 0 deletions xdsl/transforms/csl_stencil_bufferize.py
@@ -94,6 +94,9 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /):
# convert args
buf_args: list[SSAValue] = []
to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.accumulator)]
+ op.accumulator.replace_by_if(
+ buf_iter_arg.memref, lambda use: use.operation != buf_iter_arg
+ )
Collaborator Author: This allows the bufferization.to_memref (buf_iter_arg) to be used by further apply ops, in case there is more than one. See the sketch after this hunk.

for arg in [*op.args_rchunk, *op.args_dexchng]:
if isa(arg.type, TensorType[Attribute]):
to_memrefs.append(new_arg := to_memref_op(arg))
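
Below is a minimal sketch in plain Python, not the xdsl API, of why the use-filter passed to replace_by_if matters: rewriting every use of the accumulator, including the operand of the freshly created to_memref op, would make that op consume its own result, so only the remaining users are switched over to the bufferized value. The op names, operand indices, and the helper split_uses are made up for illustration.

from typing import Callable, NamedTuple


class Use(NamedTuple):
    operation: str  # which op owns the operand
    index: int      # operand position inside that op


def split_uses(uses: list[Use], pred: Callable[[Use], bool]) -> tuple[list[Use], list[Use]]:
    """Toy model of the use-filtering: (uses rewritten to the new value, uses kept)."""
    return [u for u in uses if pred(u)], [u for u in uses if not pred(u)]


accumulator_uses = [
    Use("bufferization.to_memref", 0),  # buf_iter_arg itself; must keep the tensor operand
    Use("csl_stencil.apply_0", 3),      # the apply being rewritten here
    Use("csl_stencil.apply_1", 3),      # a second apply sharing the same accumulator
]

rewritten, kept = split_uses(
    accumulator_uses, lambda use: use.operation != "bufferization.to_memref"
)
print(rewritten)  # both apply ops now see the bufferized memref
print(kept)       # only the to_memref op still reads the original tensor
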
@@ -385,6 +388,9 @@ def match_and_rewrite(self, op: arith.ConstantOp, rewriter: PatternRewriter, /):
class InjectApplyOutsIntoLinalgOuts(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /):
+ if not op.dest:
+ return

Collaborator Author: Adds support for zero-output apply ops.

yld = op.done_exchange.block.last_op
assert isinstance(yld, csl_stencil.YieldOp)
new_dest: list[SSAValue] = []
@@ -84,9 +84,13 @@ def get_required_result_type(op: Operation) -> TensorType[Attribute] | None:
tuple[int, ...],
)
):
+ assert is_tensor(use.operation.source.type)
+ # inserting a 1d tensor into a 2d array should not require the input tensor to also be 2d
+ # instead, drop the first `dimdiff` dimensions
+ dimdiff = len(static_sizes) - len(use.operation.source.type.shape)
return TensorType(
use.operation.result.type.get_element_type(),
- static_sizes,
+ static_sizes[dimdiff:],
Collaborator Author: Small fix for inserting a 1d slice into a 2d tensor. See the worked example after this hunk.

)
for ret in use.operation.results:
if isa(r_type := ret.type, TensorType[Attribute]):
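
A worked example of the dimension arithmetic above, in plain Python and using the shapes from the test at the top of this diff (the concrete values are illustrative): the insert_slice writes a [1, 32] region into the 2-D accumulator, but its source is a rank-1 tensor of 32 z-values, so the required result type keeps only the trailing dimensions.

# Shapes taken from the test at the top of this diff; values are illustrative.
static_sizes = (1, 32)  # region written by tensor.insert_slice into the 2-D accumulator
source_shape = (32,)    # shape of the 1-D slice produced by csl_stencil.access

dimdiff = len(static_sizes) - len(source_shape)  # 2 - 1 = 1
required_shape = static_sizes[dimdiff:]          # (32,), matching the 1-D source

assert required_shape == source_shape
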