transforms: (memref-to-dsd) Support 1d subview of nd memref (#3653)
Co-authored-by: n-io <n-io@users.noreply.github.com>
n-io authored Jan 6, 2025
1 parent 2ce4059 commit ad04c3f
Showing 3 changed files with 66 additions and 9 deletions.
12 changes: 12 additions & 0 deletions tests/filecheck/transforms/memref-to-dsd.mlir
@@ -120,6 +120,18 @@ builtin.module {
// CHECK-NEXT: %31 = memref.load %b[%13] : memref<510xf32>
// CHECK-NEXT: "test.op"(%31) : (f32) -> ()

%39 = memref.alloc() {"alignment" = 64 : i64} : memref<3x64xf32>
%40 = "memref.subview"(%39, %0) <{"operandSegmentSizes" = array<i32: 1, 1, 0, 0>, "static_offsets" = array<i64: 2, -9223372036854775808>, "static_sizes" = array<i64: 1, 32>, "static_strides" = array<i64: 1, 1>}> : (memref<3x64xf32>, index) -> memref<32xf32, strided<[1], offset: ?>>

// CHECK-NEXT: %32 = "csl.zeros"() : () -> memref<3x64xf32>
// CHECK-NEXT: %33 = arith.constant 3 : i16
// CHECK-NEXT: %34 = arith.constant 64 : i16
// CHECK-NEXT: %35 = "csl.get_mem_dsd"(%32, %33, %34) : (memref<3x64xf32>, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %36 = arith.constant 32 : i16
// CHECK-NEXT: %37 = "csl.get_mem_dsd"(%32, %36) <{"tensor_access" = affine_map<(d0) -> (2, d0)>}> : (memref<3x64xf32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %38 = arith.index_cast %0 : index to si16
// CHECK-NEXT: %39 = "csl.increment_dsd_offset"(%37, %38) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>

}) {sym_name = "program"} : () -> ()
}
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
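
For reference, the tensor_access map checked above can be reproduced with the affine helpers the pass imports. This is a sketch only: DYNAMIC stands in for the dynamic-offset sentinel printed in the subview above, every dimension contributes its static offset, a dynamic offset contributes 0 (it is applied afterwards via csl.increment_dsd_offset), and the single dimension with size > 1 receives the loop variable d0.

# Sketch reproducing the tensor_access map for the test case above.
from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap

DYNAMIC = -9223372036854775808  # dynamic-offset sentinel printed in the subview above
static_offsets = [2, DYNAMIC]
static_sizes = [1, 32]

# Dynamic offsets become 0 here; they are added later by csl.increment_dsd_offset.
exprs = [AffineConstantExpr(o if o != DYNAMIC else 0) for o in static_offsets]
# The one dimension with size > 1 is the one the 1d DSD iterates over.
exprs[static_sizes.index(32)] += AffineDimExpr(0)

print(AffineMap(1, 0, tuple(exprs)))  # corresponds to (d0) -> (2, d0)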
15 changes: 8 additions & 7 deletions xdsl/transforms/lower_csl_stencil.py
@@ -202,24 +202,25 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
# ensure we send only core data
assert isa(op.accumulator.type, memref.MemRefType[Attribute])
assert isa(op.field.type, memref.MemRefType[Attribute])
# the accumulator might have additional dims when used for holding prefetched data
send_buf_shape = op.accumulator.type.get_shape()[
-len(op.field.type.get_shape()) :
]
send_buf = memref.SubviewOp.get(
op.field,
[
(d - s) // 2 # symmetric offset
for s, d in zip(
op.accumulator.type.get_shape(), op.field.type.get_shape()
)
for s, d in zip(send_buf_shape, op.field.type.get_shape(), strict=True)
],
op.accumulator.type.get_shape(),
len(op.accumulator.type.get_shape()) * [1],
op.accumulator.type,
send_buf_shape,
len(send_buf_shape) * [1],
memref.MemRefType(op.field.type.get_element_type(), send_buf_shape),
)

# add api call
num_chunks = arith.ConstantOp(IntegerAttr(op.num_chunks.value, i16))
chunk_ref = csl.AddressOfFnOp(chunk_fn)
done_ref = csl.AddressOfFnOp(done_fn)
# send_buf = memref.Subview.get(op.field, [], op.accumulator.type.get_shape(), )
api_call = csl.MemberCallOp(
"communicate",
None,
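To make the intent of the changed hunk above concrete, here is a standalone sketch of the shape and offset arithmetic with made-up example shapes: the accumulator may carry an extra leading dim for prefetched data, so only its trailing dims describe the buffer that gets sent, and the offset into the field is symmetric.

# Standalone sketch with hypothetical shapes, not part of the pass.
accumulator_shape = (4, 510)  # hypothetical: leading dim holds prefetched chunks
field_shape = (512,)          # hypothetical: core data plus halo

send_buf_shape = accumulator_shape[-len(field_shape):]  # -> (510,)

# Symmetric offset: skip half of the surplus on each side of every dim.
offsets = [(d - s) // 2 for s, d in zip(send_buf_shape, field_shape, strict=True)]

print(send_buf_shape, offsets)  # (510,) [1]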
48 changes: 46 additions & 2 deletions xdsl/transforms/memref_to_dsd.py
@@ -1,10 +1,13 @@
import collections
from collections.abc import Sequence
from dataclasses import dataclass
from typing import cast

from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, csl, memref
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyMemRefType,
ArrayAttr,
Float16Type,
Float32Type,
@@ -19,6 +22,7 @@
UnrealizedConversionCastOp,
)
from xdsl.ir import Attribute, Operation, OpResult, SSAValue
from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineExpr, AffineMap
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
@@ -119,7 +123,47 @@ class LowerSubviewOpPass(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
assert isa(op.source.type, MemRefType[Attribute])
assert isa(op.source.type, AnyMemRefType)
assert isa(op.result.type, AnyMemRefType)

if len(op.result.type.get_shape()) == 1 and len(op.source.type.get_shape()) > 1:
# 1d subview onto a nd memref
sizes = op.static_sizes.get_values()
counter_sizes = collections.Counter(sizes)
counter_sizes.pop(1, None)
assert (
len(counter_sizes) == 1
), "1d access into nd memref must specify one size > 1"
size, size_count = counter_sizes.most_common()[0]
size = cast(int, size)

assert (
size_count == 1
), "1d access into nd memref can only specify one size > 1, which can occur only once"
assert all(
stride == 1 for stride in op.static_strides.get_values()
), "All strides must equal 1"

amap: list[AffineExpr] = [
AffineConstantExpr(
cast(int, o) if o != memref.SubviewOp.DYNAMIC_INDEX else 0
)
for o in op.static_offsets.get_values()
]
amap[sizes.index(size)] += AffineDimExpr(0)

size_op = arith.ConstantOp.from_int_and_width(size, 16)
dsd_op = csl.GetMemDsdOp(
operands=[op.source, [size_op]],
properties={
"tensor_access": AffineMapAttr(AffineMap(1, 0, tuple(amap)))
},
result_types=[csl.DsdType(csl.DsdKind.mem1d_dsd)],
)
offset_ops = self._update_offsets(op, dsd_op) if op.offsets else []
rewriter.replace_matched_op([size_op, dsd_op, *offset_ops])
return

assert len(op.static_sizes) == 1, "not implemented"
assert len(op.static_offsets) == 1, "not implemented"
assert len(op.static_strides) == 1, "not implemented"
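
The Counter-based size check introduced above can be read as follows (sketch with a hypothetical helper name): drop all unit sizes, then require that exactly one distinct size > 1 remains and that it occurs exactly once.

import collections

def single_non_unit_size(sizes: list[int]) -> int:
    # Hypothetical helper mirroring the assertions in the pass above.
    counter = collections.Counter(sizes)
    counter.pop(1, None)  # unit dims do not contribute to the 1d view
    assert len(counter) == 1, "1d access into nd memref must specify one size > 1"
    size, count = counter.most_common()[0]
    assert count == 1, "the size > 1 may occur only once"
    return size

print(single_non_unit_size([1, 32]))  # -> 32, matching the test case above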
@@ -219,7 +263,7 @@ def _update_offsets(

static_offsets = cast(Sequence[int], subview.static_offsets.get_values())

if static_offsets[0] == memref.SubviewOp.DYNAMIC_INDEX:
if subview.offsets:
ops.append(cast_op := arith.IndexCastOp(subview.offsets[0], csl.i16_value))
ops.append(
csl.IncrementDsdOffsetOp.build(
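The _update_offsets change above swaps the dynamic-offset test from inspecting static_offsets[0] to inspecting the dynamic-offset operands. A small sketch of why, using the offsets from the nd test case (sketch only, with the sentinel value copied from the subview above):

# Sketch: a dynamic offset need not sit in position 0 for an nd subview.
DYNAMIC = -9223372036854775808   # memref.SubviewOp.DYNAMIC_INDEX sentinel

static_offsets = [2, DYNAMIC]     # offsets of the nd test case above
dynamic_offset_operands = ["%0"]  # the SSA operands backing the DYNAMIC entries

print(static_offsets[0] == DYNAMIC)      # False: a position-0 check would miss this
print(len(dynamic_offset_operands) > 0)  # True: the new check keys off the operands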
