-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bug: (csl-lowering) Make multi-apply lowering work #3614
Changes from 4 commits
cd122d9
1bd408f
8078960
2790e75
8aff3ec
43a1709
e662fb4
3f94a4a
071e042
e166626
2374a1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -608,17 +608,17 @@ def match_and_rewrite( | |
block2 = Block(arg_types=[op.input_stencil.type, op.result.type]) | ||
block2.add_op(csl_stencil.YieldOp(block2.args[1])) | ||
|
||
with ImplicitBuilder(block) as (_, offset, acc): | ||
with ImplicitBuilder(block) as (buf, offset, acc): | ||
dest = acc | ||
for i, acc_offset in enumerate(offsets): | ||
ac_op = csl_stencil.AccessOp( | ||
dest, stencil.IndexAttr.get(*acc_offset), chunk_t | ||
buf, stencil.IndexAttr.get(*acc_offset), chunk_t | ||
) | ||
assert isa(ac_op.result.type, AnyTensorType) | ||
dest = tensor.InsertSliceOp.get( | ||
source=ac_op.result, | ||
dest=dest, | ||
static_sizes=ac_op.result.type.get_shape(), | ||
static_sizes=[1, *ac_op.result.type.get_shape()], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs 1 in the 1st dimension to insert 1 slice of z-values ( |
||
static_offsets=[i, memref.SubviewOp.DYNAMIC_INDEX], | ||
offsets=[offset], | ||
).result | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,6 +94,9 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, | |
# convert args | ||
buf_args: list[SSAValue] = [] | ||
to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.accumulator)] | ||
op.accumulator.replace_by_if( | ||
buf_iter_arg.memref, lambda use: use.operation != buf_iter_arg | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Allowing the |
||
for arg in [*op.args_rchunk, *op.args_dexchng]: | ||
if isa(arg.type, TensorType[Attribute]): | ||
to_memrefs.append(new_arg := to_memref_op(arg)) | ||
|
@@ -385,6 +388,9 @@ def match_and_rewrite(self, op: arith.ConstantOp, rewriter: PatternRewriter, /): | |
class InjectApplyOutsIntoLinalgOuts(RewritePattern): | ||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /): | ||
if not op.dest: | ||
return | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adding support for zero output apply ops. |
||
yld = op.done_exchange.block.last_op | ||
assert isinstance(yld, csl_stencil.YieldOp) | ||
new_dest: list[SSAValue] = [] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,9 +84,13 @@ def get_required_result_type(op: Operation) -> TensorType[Attribute] | None: | |
tuple[int, ...], | ||
) | ||
): | ||
assert is_tensor(use.operation.source.type) | ||
# inserting a 1d tensor into a 2d array should not require the input tensor to also be 2d | ||
# instead, drop the first `dimdiff` dimensions | ||
dimdiff = len(static_sizes) - len(use.operation.source.type.shape) | ||
return TensorType( | ||
use.operation.result.type.get_element_type(), | ||
static_sizes, | ||
static_sizes[dimdiff:], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Small fix for inserting a 1d slice into a 2d tensor. |
||
) | ||
for ret in use.operation.results: | ||
if isa(r_type := ret.type, TensorType[Attribute]): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Read from the buf, not from the dest (which is the accumulator).