Skip to content

Commit

Permalink
enhancement: Changing names and namehints in the csl pipeline (#3204)
Browse files Browse the repository at this point in the history
Passes:
* `stencil-to-csl-stencil` -> `convert-stencil-to-csl-stencil`
* `csl-wrapper-to-csl` -> `lower-csl-wrapper`

Dialects:
* csl_stencil.apply.chunk_reduce region -> `receive_chunk`
* csl_stencil.apply.post_process region -> `done_exchange`
* csl_stencil.apply.iter_arg -> `accumulator`
* csl_stencil.apply.communicated_stencil -> `field`

Namehints:
* offset parameter in the callback: `%offset`
* accumulator: `%accumulator`


The passes are renamed to unify naming across the pipeline, following these patterns:
* `convert-<dialect>-to-<dialect>`
* `lift-<dialect>-to-<dialect>`
* `lower-<dialect>`
* `<dialect>-do-transformation-within-the-dialect`

---------

Co-authored-by: n-io <n-io@users.noreply.github.com>
  • Loading branch information
n-io and n-io authored Sep 24, 2024
1 parent b9620a2 commit 1606a5c
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 217 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: xdsl-opt %s -p "stencil-to-csl-stencil{num_chunks=2}" | filecheck %s
// RUN: xdsl-opt %s -p "convert-stencil-to-csl-stencil{num_chunks=2}" | filecheck %s

builtin.module {
// CHECK-NEXT: builtin.module {
Expand Down
68 changes: 34 additions & 34 deletions tests/filecheck/transforms/lower-csl-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -95,46 +95,46 @@ builtin.module {
// CHECK-NEXT: "csl.export"(%36) <{"var_name" = "arg1", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> ()
// CHECK-NEXT: csl.func @gauss_seidel_func() {
// CHECK-NEXT: %arg4 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
// CHECK-NEXT: %accumulator = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
// CHECK-NEXT: %37 = arith.constant 2 : i16
// CHECK-NEXT: %38 = "csl.addressof_fn"() <{"fn_name" = @chunk_reduce_cb0}> : () -> !csl.ptr<(i16) -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @post_process_cb0}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %38 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb0}> : () -> !csl.ptr<(i16) -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb0}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: "csl.member_call"(%34, %arg0, %37, %38, %39) <{"field" = "communicate"}> : (!csl.imported_module, memref<512xf32>, i16, !csl.ptr<(i16) -> (), #csl<ptr_kind single>, #csl<ptr_const const>>, !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>) -> ()
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: csl.func @chunk_reduce_cb0(%40 : i16) {
// CHECK-NEXT: %arg3 = arith.index_cast %40 : i16 to index
// CHECK-NEXT: %41 = arith.constant 1 : i16
// CHECK-NEXT: %42 = "csl.get_dir"() <{"dir" = #csl<dir_kind west>}> : () -> !csl.direction
// CHECK-NEXT: %43 = "csl.member_call"(%34, %42, %41) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %44 = builtin.unrealized_conversion_cast %43 : !csl<dsd mem1d_dsd> to memref<255xf32>
// CHECK-NEXT: %45 = arith.constant 1 : i16
// CHECK-NEXT: %46 = "csl.get_dir"() <{"dir" = #csl<dir_kind east>}> : () -> !csl.direction
// CHECK-NEXT: %47 = "csl.member_call"(%34, %46, %45) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %48 = builtin.unrealized_conversion_cast %47 : !csl<dsd mem1d_dsd> to memref<255xf32>
// CHECK-NEXT: %49 = arith.constant 1 : i16
// CHECK-NEXT: %50 = "csl.get_dir"() <{"dir" = #csl<dir_kind south>}> : () -> !csl.direction
// CHECK-NEXT: %51 = "csl.member_call"(%34, %50, %49) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %52 = builtin.unrealized_conversion_cast %51 : !csl<dsd mem1d_dsd> to memref<255xf32>
// CHECK-NEXT: %53 = arith.constant 1 : i16
// CHECK-NEXT: %54 = "csl.get_dir"() <{"dir" = #csl<dir_kind north>}> : () -> !csl.direction
// CHECK-NEXT: %55 = "csl.member_call"(%34, %54, %53) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %56 = builtin.unrealized_conversion_cast %55 : !csl<dsd mem1d_dsd> to memref<255xf32>
// CHECK-NEXT: %57 = memref.subview %arg4[%arg3] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
// CHECK-NEXT: "csl.fadds"(%57, %56, %52) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%57, %57, %48) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%57, %57, %44) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
// CHECK-NEXT: csl.func @receive_chunk_cb0(%offset : i16) {
// CHECK-NEXT: %offset_1 = arith.index_cast %offset : i16 to index
// CHECK-NEXT: %40 = arith.constant 1 : i16
// CHECK-NEXT: %41 = "csl.get_dir"() <{"dir" = #csl<dir_kind west>}> : () -> !csl.direction
// CHECK-NEXT: %42 = "csl.member_call"(%34, %41, %40) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %43 = builtin.unrealized_conversion_cast %42 : !csl<dsd mem1d_dsd> to memref<255xf32>
// CHECK-NEXT: %44 = arith.constant 1 : i16
// CHECK-NEXT: %45 = "csl.get_dir"() <{"dir" = #csl<dir_kind east>}> : () -> !csl.direction
// CHECK-NEXT: %46 = "csl.member_call"(%34, %45, %44) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %47 = builtin.unrealized_conversion_cast %46 : !csl<dsd mem1d_dsd> to memref<255xf32>
// CHECK-NEXT: %48 = arith.constant 1 : i16
// CHECK-NEXT: %49 = "csl.get_dir"() <{"dir" = #csl<dir_kind south>}> : () -> !csl.direction
// CHECK-NEXT: %50 = "csl.member_call"(%34, %49, %48) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %51 = builtin.unrealized_conversion_cast %50 : !csl<dsd mem1d_dsd> to memref<255xf32>
// CHECK-NEXT: %52 = arith.constant 1 : i16
// CHECK-NEXT: %53 = "csl.get_dir"() <{"dir" = #csl<dir_kind north>}> : () -> !csl.direction
// CHECK-NEXT: %54 = "csl.member_call"(%34, %53, %52) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %55 = builtin.unrealized_conversion_cast %54 : !csl<dsd mem1d_dsd> to memref<255xf32>
// CHECK-NEXT: %56 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
// CHECK-NEXT: "csl.fadds"(%56, %55, %51) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%56, %56, %47) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%56, %56, %43) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: csl.func @post_process_cb0() {
// CHECK-NEXT: %58 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>>
// CHECK-NEXT: %59 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>>
// CHECK-NEXT: "csl.fadds"(%arg4, %arg4, %59) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> ()
// CHECK-NEXT: "csl.fadds"(%arg4, %arg4, %58) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> ()
// CHECK-NEXT: %60 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: "csl.fmuls"(%arg4, %arg4, %60) : (memref<510xf32>, memref<510xf32>, f32) -> ()
// CHECK-NEXT: %61 = memref.subview %arg1[1] [510] [1] : memref<512xf32> to memref<510xf32>
// CHECK-NEXT: "memref.copy"(%arg4, %61) : (memref<510xf32>, memref<510xf32>) -> ()
// CHECK-NEXT: csl.func @done_exchange_cb0() {
// CHECK-NEXT: %57 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>>
// CHECK-NEXT: %58 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>>
// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %58) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> ()
// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %57) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> ()
// CHECK-NEXT: %59 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: "csl.fmuls"(%accumulator, %accumulator, %59) : (memref<510xf32>, memref<510xf32>, f32) -> ()
// CHECK-NEXT: %60 = memref.subview %arg1[1] [510] [1] : memref<512xf32> to memref<510xf32>
// CHECK-NEXT: "memref.copy"(%accumulator, %60) : (memref<510xf32>, memref<510xf32>) -> ()
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> ()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: xdsl-opt -p csl-wrapper-to-csl %s | filecheck --match-full-lines %s
// RUN: xdsl-opt -p lower-csl-wrapper %s | filecheck --match-full-lines %s

builtin.module {
"csl_wrapper.module"() <{
Expand Down
88 changes: 44 additions & 44 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ class ApplyOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.csl_stencil import (
RedundantIterArgInitialisation,
RedundantAccumulatorInitialisation,
)

return (RedundantIterArgInitialisation(),)
return (RedundantAccumulatorInitialisation(),)


@irdl_op_definition
Expand All @@ -169,15 +169,15 @@ class ApplyOp(IRDLOperation):
with a `stencil.apply` (a stencil function plus parameters and applies the stencil function to the output temp).
As communication may be done in chunks, this operation provides two regions for computation:
- the `chunk_reduce` region to reduce a chunk of data received from several neighbours to one chunk of data.
- the `receive_chunk` region to reduce a chunk of data received from several neighbours to one chunk of data.
this region is invoked once per communicated chunks and effectively acts as a loop body.
It uses `iter_arg` to concatenate the chunks
- the `post_process` region (invoked once when communication has finished) that takes the concatenated
chunk of the `chunk_reduce` region and applies any further processing here - for instance, it may handle
It uses `accumulator` to concatenate the chunks
- the `done_exchange` region (invoked once when communication has finished) that takes the concatenated
chunk of the `receive_chunk` region and applies any further processing here - for instance, it may handle
the computation of 'own' (non-communicated) or otherwise prefetched data
Further fields:
- `communicated_stencil` - the stencil to communicate (send and receive)
- `field` - the stencil field to communicate (send and receive)
- `args` - arguments to the stencil computation, may include other prefetched buffers
- `topo` - as received from `csl_stencil.prefetch`/`dmp.swap`
- `num_chunks` - number of chunks into which to slice the communication
Expand All @@ -187,32 +187,32 @@ class ApplyOp(IRDLOperation):
Function signatures:
Before lowering (from `csl_stencil.prefetch` and `stencil.apply`):
%pref = csl_stencil.prefetch(%communicated_stencil : stencil.Temp)
stencil.apply( ..some args.. , %communicated_stencil, ..some more args.., %pref)
%pref = csl_stencil.prefetch(%field : stencil.Temp)
stencil.apply( ..some args.. , %field, ..some more args.., %pref)
After lowering:
op: csl_stencil.apply(%communicated_stencil, %iter_arg, chunk_reduce_args..., post_process_args...)
chunk_reduce: block_args(slice of type(%pref), %offset, %iter_arg, args...)
post_process: block_args(%communicated_stencil, %iter_arg, args...)
op: csl_stencil.apply(%field, %accumulator, receive_chunk_args..., done_exchange_args...)
receive_chunk: block_args(slice of type(%pref), %offset, %accumulator, args...)
done_exchange: block_args(%field, %accumulator, args...)
Note, that %pref can be dropped (as communication is done by the op rather than before the op),
and that a new %iter_arg is required, an empty tensor which is filled by `chunk_reduce` and
consumed by `post_process`
and that a new %accumulator is required, an empty tensor which is filled by `receive_chunk` and
consumed by `done_exchange`
"""

name = "csl_stencil.apply"

communicated_stencil = operand_def(
field = operand_def(
base(stencil.StencilType[Attribute]) | base(memref.MemRefType[Attribute])
)

iter_arg = operand_def(TensorType[Attribute] | memref.MemRefType[Attribute])
accumulator = operand_def(TensorType[Attribute] | memref.MemRefType[Attribute])

args = var_operand_def(Attribute)
dest = var_operand_def(stencil.FieldType | memref.MemRefType[Attribute])

chunk_reduce = region_def()
post_process = region_def()
receive_chunk = region_def()
done_exchange = region_def()

swaps = prop_def(builtin.ArrayAttr[ExchangeDeclarationAttr])

Expand Down Expand Up @@ -243,7 +243,7 @@ def print_arg(arg: SSAValue):
printer.print("(")

# args required by function signature, plus optional args for regions
args = [self.communicated_stencil, self.iter_arg, *self.args]
args = [self.field, self.accumulator, *self.args]

printer.print_list(args, print_arg)
if self.dest:
Expand All @@ -259,9 +259,9 @@ def print_arg(arg: SSAValue):
printer.print("> ")
printer.print_op_attributes(self.attributes, print_keyword=True)
printer.print("(")
printer.print_region(self.chunk_reduce, print_entry_block_args=True)
printer.print_region(self.receive_chunk, print_entry_block_args=True)
printer.print(", ")
printer.print_region(self.post_process, print_entry_block_args=True)
printer.print_region(self.done_exchange, print_entry_block_args=True)
printer.print(")")
if self.bounds is not None:
printer.print(" to ")
Expand Down Expand Up @@ -298,9 +298,9 @@ def parse_args():
if attrs is not None:
attrs = attrs.data
parser.parse_punctuation("(")
chunk_reduce = parser.parse_region()
receive_chunk = parser.parse_region()
parser.parse_punctuation(",")
post_process = parser.parse_region()
done_exchange = parser.parse_region()
parser.parse_punctuation(")")
if parser.parse_optional_keyword("to"):
props["bounds"] = stencil.StencilBoundsAttr.new(
Expand All @@ -309,23 +309,23 @@ def parse_args():
return cls(
operands=[operands[0], operands[1], operands[2:], destinations],
result_types=[result_types],
regions=[chunk_reduce, post_process],
regions=[receive_chunk, done_exchange],
properties=props,
attributes=attrs,
)

def verify_(self) -> None:
# typecheck op arguments
if (
len(self.chunk_reduce.block.args) < 3
or len(self.post_process.block.args) < 2
len(self.receive_chunk.block.args) < 3
or len(self.done_exchange.block.args) < 2
):
raise VerifyException("Missing required block args on region")
op_args = (
self.post_process.block.args[0],
self.chunk_reduce.block.args[2],
*self.chunk_reduce.block.args[3:],
*self.post_process.block.args[2:],
self.done_exchange.block.args[0],
self.receive_chunk.block.args[2],
*self.receive_chunk.block.args[3:],
*self.done_exchange.block.args[2:],
)
for operand, argument in zip(self.operands, op_args):
if operand.type != argument.type:
Expand All @@ -335,36 +335,36 @@ def verify_(self) -> None:

# typecheck required (only) block arguments
assert isa(
self.iter_arg.type, TensorType[Attribute] | memref.MemRefType[Attribute]
self.accumulator.type, TensorType[Attribute] | memref.MemRefType[Attribute]
)
chunk_reduce_req_types = [
type(self.iter_arg.type)(
self.iter_arg.type.get_element_type(),
chunk_region_req_types = [
type(self.accumulator.type)(
self.accumulator.type.get_element_type(),
(
len(self.swaps),
self.iter_arg.type.get_shape()[0] // self.num_chunks.value.data,
self.accumulator.type.get_shape()[0] // self.num_chunks.value.data,
),
),
IndexType(),
self.iter_arg.type,
self.accumulator.type,
]
post_process_req_types = [
self.communicated_stencil.type,
self.iter_arg.type,
done_exchange_req_types = [
self.field.type,
self.accumulator.type,
]
for arg, expected_type in zip(
self.chunk_reduce.block.args, chunk_reduce_req_types
self.receive_chunk.block.args, chunk_region_req_types
):
if arg.type != expected_type:
raise VerifyException(
f"Unexpected block argument type of chunk_reduce, got {arg.type} != {expected_type} at index {arg.index}"
f"Unexpected block argument type of receive_chunk, got {arg.type} != {expected_type} at index {arg.index}"
)
for arg, expected_type in zip(
self.post_process.block.args, post_process_req_types
self.done_exchange.block.args, done_exchange_req_types
):
if arg.type != expected_type:
raise VerifyException(
f"Unexpected block argument type of post_process, got {arg.type} != {expected_type} at index {arg.index}"
f"Unexpected block argument type of done_exchange, got {arg.type} != {expected_type} at index {arg.index}"
)

if (len(self.res) == 0) == (len(self.dest) == 0):
Expand Down Expand Up @@ -393,7 +393,7 @@ def get_accesses(self) -> Iterable[stencil.AccessPattern]:
field of the apply operation.
"""
# iterate over the block arguments
for arg in self.chunk_reduce.block.args + self.post_process.block.args:
for arg in self.receive_chunk.block.args + self.done_exchange.block.args:
accesses: list[tuple[int, ...]] = []
# walk the uses of the argument
for use in arg.uses:
Expand Down
Loading

0 comments on commit 1606a5c

Please sign in to comment.