transforms: (memref-to-dsd) Support 1d subview of nd memref #3653

Merged · 11 commits · Jan 6, 2025
12 changes: 12 additions & 0 deletions tests/filecheck/transforms/memref-to-dsd.mlir
@@ -120,6 +120,18 @@ builtin.module {
// CHECK-NEXT: %31 = memref.load %b[%13] : memref<510xf32>
// CHECK-NEXT: "test.op"(%31) : (f32) -> ()

%39 = memref.alloc() {"alignment" = 64 : i64} : memref<3x64xf32>
%40 = "memref.subview"(%39, %0) <{"operandSegmentSizes" = array<i32: 1, 1, 0, 0>, "static_offsets" = array<i64: 2, -9223372036854775808>, "static_sizes" = array<i64: 1, 32>, "static_strides" = array<i64: 1, 1>}> : (memref<3x64xf32>, index) -> memref<32xf32, strided<[1], offset: ?>>

// CHECK-NEXT: %32 = "csl.zeros"() : () -> memref<3x64xf32>
// CHECK-NEXT: %33 = arith.constant 3 : i16
// CHECK-NEXT: %34 = arith.constant 64 : i16
// CHECK-NEXT: %35 = "csl.get_mem_dsd"(%32, %33, %34) : (memref<3x64xf32>, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %36 = arith.constant 32 : i16
// CHECK-NEXT: %37 = "csl.get_mem_dsd"(%32, %36) <{"tensor_access" = affine_map<(d0) -> (2, d0)>}> : (memref<3x64xf32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %38 = arith.index_cast %0 : index to si16
// CHECK-NEXT: %39 = "csl.increment_dsd_offset"(%37, %38) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>

}) {sym_name = "program"} : () -> ()
}
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
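The new CHECK lines show the lowering this PR adds: a 1d subview of a 2d memref becomes a csl.get_mem_dsd whose tensor_access affine map pins the static offset dimension, while the dynamic offset is applied afterwards via csl.increment_dsd_offset. A minimal sketch of that map, in plain Python against xdsl's affine classes (which the pass imports below); treating AffineMap.eval as an evaluation helper is an assumption about that API:

from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap

# (d0) -> (2, d0): the static row offset 2 becomes a constant result,
# the single non-unit size dimension becomes the DSD loop variable d0.
tensor_access = AffineMap(1, 0, (AffineConstantExpr(2), AffineDimExpr(0)))

# Each 1d DSD index maps to a 2d coordinate into memref<3x64xf32>.
for i in (0, 1, 31):
    print(tensor_access.eval([i], []))  # [2, 0], [2, 1], [2, 31]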
15 changes: 8 additions & 7 deletions xdsl/transforms/lower_csl_stencil.py
@@ -202,24 +202,25 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
# ensure we send only core data
assert isa(op.accumulator.type, memref.MemRefType[Attribute])
assert isa(op.field.type, memref.MemRefType[Attribute])
# the accumulator might have additional dims when used for holding prefetched data
send_buf_shape = op.accumulator.type.get_shape()[
-len(op.field.type.get_shape()) :
]
send_buf = memref.SubviewOp.get(
op.field,
[
(d - s) // 2 # symmetric offset
- for s, d in zip(
- op.accumulator.type.get_shape(), op.field.type.get_shape()
- )
for s, d in zip(send_buf_shape, op.field.type.get_shape(), strict=True)
],
- op.accumulator.type.get_shape(),
- len(op.accumulator.type.get_shape()) * [1],
- op.accumulator.type,
send_buf_shape,
len(send_buf_shape) * [1],
memref.MemRefType(op.field.type.get_element_type(), send_buf_shape),
)

# add api call
num_chunks = arith.ConstantOp(IntegerAttr(op.num_chunks.value, i16))
chunk_ref = csl.AddressOfFnOp(chunk_fn)
done_ref = csl.AddressOfFnOp(done_fn)
- # send_buf = memref.Subview.get(op.field, [], op.accumulator.type.get_shape(), )
api_call = csl.MemberCallOp(
"communicate",
None,
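The zip now pairs the field shape with send_buf_shape (the trailing dims of the accumulator), so an accumulator that holds prefetched data in extra leading dims no longer skews the offset computation. A minimal sketch of the symmetric-offset arithmetic, in plain Python with hypothetical shapes:

# Hypothetical shapes: the field carries a one-element halo per side,
# the send buffer covers only the core region.
field_shape = (512, 512)
send_buf_shape = (510, 510)  # trailing dims of the accumulator type

offsets = [(d - s) // 2 for s, d in zip(send_buf_shape, field_shape, strict=True)]
print(offsets)  # [1, 1]: the subview starts one element in on each dim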
48 changes: 46 additions & 2 deletions xdsl/transforms/memref_to_dsd.py
@@ -1,10 +1,13 @@
import collections
from collections.abc import Sequence
from dataclasses import dataclass
from typing import cast

from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, csl, memref
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyMemRefType,
ArrayAttr,
Float16Type,
Float32Type,
@@ -19,6 +22,7 @@
UnrealizedConversionCastOp,
)
from xdsl.ir import Attribute, Operation, OpResult, SSAValue
from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineExpr, AffineMap
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
@@ -119,7 +123,47 @@ class LowerSubviewOpPass(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
- assert isa(op.source.type, MemRefType[Attribute])
assert isa(op.source.type, AnyMemRefType)
assert isa(op.result.type, AnyMemRefType)

if len(op.result.type.get_shape()) == 1 and len(op.source.type.get_shape()) > 1:
# 1d subview onto an nd memref
sizes = op.static_sizes.get_values()
counter_sizes = collections.Counter(sizes)
counter_sizes.pop(1, None)
assert (
len(counter_sizes) == 1
), "1d access into nd memref must specify one size > 1"
size, size_count = counter_sizes.most_common()[0]
size = cast(int, size)

assert (
size_count == 1
), "1d access into nd memref can only specify one size > 1, which can occur only once"
assert all(
stride == 1 for stride in op.static_strides.get_values()
), "All strides must equal 1"

amap: list[AffineExpr] = [
AffineConstantExpr(
cast(int, o) if o != memref.SubviewOp.DYNAMIC_INDEX else 0
)
for o in op.static_offsets.get_values()
]
amap[sizes.index(size)] += AffineDimExpr(0)

size_op = arith.ConstantOp.from_int_and_width(size, 16)
dsd_op = csl.GetMemDsdOp(
operands=[op.source, [size_op]],
properties={
"tensor_access": AffineMapAttr(AffineMap(1, 0, tuple(amap)))
},
result_types=[csl.DsdType(csl.DsdKind.mem1d_dsd)],
)
offset_ops = self._update_offsets(op, dsd_op) if op.offsets else []
rewriter.replace_matched_op([size_op, dsd_op, *offset_ops])
return

assert len(op.static_sizes) == 1, "not implemented"
assert len(op.static_offsets) == 1, "not implemented"
assert len(op.static_strides) == 1, "not implemented"
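A minimal sketch (plain Python) of the Counter-based validation above, using the sizes from the new test case:

import collections

# For memref<3x64xf32> -> memref<32xf32>, static_sizes is (1, 32):
# unit dims are dropped, and exactly one non-unit size may remain.
sizes = (1, 32)
counter_sizes = collections.Counter(sizes)
counter_sizes.pop(1, None)                         # discard unit dims
assert len(counter_sizes) == 1                     # one distinct size > 1
size, size_count = counter_sizes.most_common()[0]
assert size_count == 1                             # occurring exactly once
print(size)  # 32: the extent of the resulting 1d DSD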
@@ -219,7 +263,7 @@ def _update_offsets(

static_offsets = cast(Sequence[int], subview.static_offsets.get_values())

- if static_offsets[0] == memref.SubviewOp.DYNAMIC_INDEX:
if subview.offsets:
ops.append(cast_op := arith.IndexCastOp(subview.offsets[0], csl.i16_value))
ops.append(
csl.IncrementDsdOffsetOp.build(
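The last hunk keys the offset update on the presence of SSA offset operands rather than on the first static offset being DYNAMIC_INDEX, which also covers the nd case where the dynamic offset is not in position 0 (here, static_offsets is [2, DYNAMIC]). A minimal sketch of that path, as a hypothetical helper mirroring _update_offsets and assuming the xdsl builder calls used in the hunk above:

from xdsl.dialects import arith, csl
from xdsl.dialects.builtin import f32

def lower_dynamic_offset(subview, dsd_op):
    # Cast the dynamic subview offset from index to si16, then bump the
    # DSD's base offset by it, as in the %38/%39 CHECK lines of the test.
    cast_op = arith.IndexCastOp(subview.offsets[0], csl.i16_value)
    incr_op = csl.IncrementDsdOffsetOp.build(
        operands=[dsd_op.results[0], cast_op.results[0]],
        properties={"elem_type": f32},  # element type of the viewed memref
        result_types=[dsd_op.results[0].type],
    )
    return [cast_op, incr_op]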