Skip to content

Commit

Permalink
dialects: (vector) Added stubs for transfer_read and transfer_write
Browse files Browse the repository at this point in the history
Fixed bug in TensorOrMemrefOf.verify (?)
  • Loading branch information
watermelonwolverine committed Dec 17, 2024
1 parent b8611e0 commit 383fc93
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 3 deletions.
7 changes: 6 additions & 1 deletion tests/filecheck/dialects/vector/vector_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: XDSL_ROUNDTRIP

#map = affine_map<(d0, d1) -> (d0)>
builtin.module {
func.func private @vector_test(%0 : memref<4x4xindex>, %1 : vector<1xi1>, %2 : index) {
%3 = "vector.load"(%0, %2, %2) : (memref<4x4xindex>, index, index) -> vector<2xindex>
Expand All @@ -10,6 +10,9 @@ builtin.module {
"vector.maskedstore"(%0, %2, %2, %1, %6) : (memref<4x4xindex>, index, index, vector<1xi1>, vector<1xindex>) -> ()
"vector.print"(%6) : (vector<1xindex>) -> ()
%7 = "vector.create_mask"(%2) : (index) -> vector<2xi1>
%8 = "vector.transfer_read"(%0, %2, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array<i32: 1, 2, 1, 0>, "permutation_map" = #map}> : (memref<4x4xindex>, index, index, index) -> vector<4xindex>
"vector.transfer_write"(%8, %0, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array<i32: 1, 1, 2, 0>, "permutation_map" = #map}> : (vector<4xindex>, memref<4x4xindex>, index, index) -> ()

func.return
}
}
Expand All @@ -25,6 +28,8 @@ builtin.module {
// CHECK-NEXT: "vector.maskedstore"(%0, %2, %2, %1, %6) : (memref<4x4xindex>, index, index, vector<1xi1>, vector<1xindex>) -> ()
// CHECK-NEXT: "vector.print"(%6) : (vector<1xindex>) -> ()
// CHECK-NEXT: %7 = "vector.create_mask"(%2) : (index) -> vector<2xi1>
// CHECK-NEXT: %8 = "vector.transfer_read"(%0, %2, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array<i32: 1, 2, 1, 0>, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (memref<4x4xindex>, index, index, index) -> vector<4xindex>
// CHECK-NEXT: "vector.transfer_write"(%8, %0, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array<i32: 1, 1, 2, 0>, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (vector<4xindex>, memref<4x4xindex>, index, index) -> ()
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
4 changes: 2 additions & 2 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,8 +1891,8 @@ def get_resolvers(
}

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if isinstance(attr, VectorType) or isinstance(attr, TensorType):
attr = cast(VectorType[Attribute] | TensorType[Attribute], attr)
if isinstance(attr, MemRefType) or isinstance(attr, TensorType):
attr = cast(MemRefType[Attribute] | TensorType[Attribute], attr)
self.elem_constr.verify(attr.element_type, constraint_context)
else:
raise VerifyException(f"Expected tensor or memref type, got {attr}")
Expand Down
84 changes: 84 additions & 0 deletions xdsl/dialects/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
from collections.abc import Sequence

from xdsl.dialects.builtin import (
I1,
AffineMapAttr,
ArrayAttr,
BoolAttr,
IndexType,
MemRefType,
TensorOrMemrefOf,
TensorType,
VectorBaseTypeAndRankConstraint,
VectorBaseTypeConstraint,
VectorRankConstraint,
Expand All @@ -13,9 +19,13 @@
)
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AttrSizedOperandSegments,
IRDLOperation,
irdl_op_definition,
operand_def,
opt_operand_def,
opt_prop_def,
prop_def,
result_def,
traits_def,
var_operand_def,
Expand Down Expand Up @@ -292,6 +302,78 @@ def get(mask_operands: list[Operation | SSAValue]) -> CreatemaskOp:
)


@irdl_op_definition
class TransferReadOp(IRDLOperation):
name = "vector.transfer_read"

source = operand_def(TensorOrMemrefOf(Attribute))
indices = var_operand_def(IndexType)
padding = operand_def(Attribute)
mask = opt_operand_def(VectorType[I1])

permutation_map = prop_def(AffineMapAttr)
in_bounds = opt_prop_def(ArrayAttr[BoolAttr])

result = result_def(VectorType)

irdl_options = [AttrSizedOperandSegments(as_property=True)]

def verify_(self):
assert isa(self.source.type, MemRefType[Attribute] | TensorType[Attribute])
assert isa(self.result.type, VectorType[Attribute])
# TODO verify.

@staticmethod
def get(
source: SSAValue | Operation,
indices: Sequence[SSAValue | Operation],
padding: SSAValue | Operation,
result_type: Attribute,
mask: Sequence[SSAValue | Operation] | None = None,
permutation_map: AffineMapAttr | None = None,
in_bounds: ArrayAttr[BoolAttr] | None = None,
):
return TransferReadOp.build(
operands=[source, indices, padding, mask],
result_types=[result_type],
properties={"permutation_map": permutation_map, "in_bounds": in_bounds},
)


@irdl_op_definition
class TransferWriteOp(IRDLOperation):
name = "vector.transfer_write"

vector = operand_def(VectorType[Attribute])
source = operand_def(TensorOrMemrefOf(Attribute))
indices = var_operand_def(IndexType)
mask = opt_operand_def(VectorType[I1])

in_bounds = prop_def(ArrayAttr[BoolAttr])
permutation_map = prop_def(AffineMapAttr)

irdl_options = [AttrSizedOperandSegments(as_property=True)]

def verify_(self):
assert isa(self.source.type, MemRefType[Attribute] | TensorType[Attribute])
assert isa(self.vector.type, VectorType[Attribute])
# TODO verify

@staticmethod
def get(
vector: SSAValue | Operation,
source: SSAValue | Operation,
indices: Sequence[SSAValue | Operation],
mask: Sequence[SSAValue | Operation] | None = None,
permutation_map: AffineMapAttr | None = None,
in_bounds: ArrayAttr[BoolAttr] | None = None,
):
return TransferWriteOp.build(
operands=[vector, source, indices, mask],
properties={"permutation_map": permutation_map, "in_bounds": in_bounds},
)


Vector = Dialect(
"vector",
[
Expand All @@ -303,6 +385,8 @@ def get(mask_operands: list[Operation | SSAValue]) -> CreatemaskOp:
MaskedstoreOp,
PrintOp,
CreatemaskOp,
TransferReadOp,
TransferWriteOp,
],
[],
)

0 comments on commit 383fc93

Please sign in to comment.