Skip to content

Commit

Permalink
accumulator fix
Browse files Browse the repository at this point in the history
  • Loading branch information
n-io committed Dec 20, 2024
1 parent e902129 commit 1ffe566
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions xdsl/transforms/lower_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,24 +202,25 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
# ensure we send only core data
assert isa(op.accumulator.type, memref.MemRefType[Attribute])
assert isa(op.field.type, memref.MemRefType[Attribute])
# the accumulator might have additional dims when used for holding prefetched data
send_buf_shape = op.accumulator.type.get_shape()[
-len(op.field.type.get_shape()) :
]
send_buf = memref.SubviewOp.get(
op.field,
[
(d - s) // 2 # symmetric offset
for s, d in zip(
op.accumulator.type.get_shape(), op.field.type.get_shape()
)
for s, d in zip(send_buf_shape, op.field.type.get_shape(), strict=True)
],
op.accumulator.type.get_shape(),
len(op.accumulator.type.get_shape()) * [1],
op.accumulator.type,
send_buf_shape,
len(send_buf_shape) * [1],
memref.MemRefType(op.field.type.get_element_type(), send_buf_shape),
)

# add api call
num_chunks = arith.ConstantOp(IntegerAttr(op.num_chunks.value, i16))
chunk_ref = csl.AddressOfFnOp(chunk_fn)
done_ref = csl.AddressOfFnOp(done_fn)
# send_buf = memref.Subview.get(op.field, [], op.accumulator.type.get_shape(), )
api_call = csl.MemberCallOp(
"communicate",
None,
Expand Down

0 comments on commit 1ffe566

Please sign in to comment.