Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (vector) Add for vector.transfer_read and vector.transfer_write operations #3650

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
383fc93
dialects: (vector) Added stubs for transfer_read and transfer_write
watermelonwolverine Dec 17, 2024
85a7c83
Started working on verification of "vector.transfer_read" and "vector…
watermelonwolverine Dec 18, 2024
c8ec10f
Continued working on verification of "vector.transfer_read" and "vect…
watermelonwolverine Dec 18, 2024
ff8b9f9
Did some follow-up changes that resulted from VectorType expecting a …
watermelonwolverine Dec 18, 2024
9c87737
Reverted changes done to VectorType.
watermelonwolverine Dec 19, 2024
77d8acb
Fixed bugs for vector.transfer_*
watermelonwolverine Dec 19, 2024
b7f554c
Did some refactoring in affine_map and affine_expr
watermelonwolverine Dec 19, 2024
0013c20
Added tests for newly added functions in affine_map.py and affine_exp…
watermelonwolverine Dec 19, 2024
9217d31
Added unused_dims_bit_vector function to AffineMap
watermelonwolverine Dec 19, 2024
b2abf0a
Merge remote-tracking branch 'origin/main' into vector_add_transfer_r…
watermelonwolverine Dec 19, 2024
a78a39b
Fixed formatting
watermelonwolverine Dec 19, 2024
7874eb6
Transferred a bunch of lit tests for vector.transfer_* ops from MLIR
watermelonwolverine Dec 19, 2024
47c55e9
Fixed bug in TransferWriteOp.__init__ and transfered more lit tests o…
watermelonwolverine Dec 19, 2024
f2c6f15
Fixed formatting
watermelonwolverine Dec 19, 2024
5b6c700
Fixed a few issues with vector.transfer_read and vector.transfer_write
watermelonwolverine Dec 28, 2024
a702250
Fixed vector.tranfer_* filechecks
watermelonwolverine Dec 28, 2024
0ff1569
Merge remote-tracking branch 'origin/main' into vector_add_transfer_r…
watermelonwolverine Jan 2, 2025
b331962
Perfomed some clean-up and undid all changes which aren't ready yet a…
watermelonwolverine Jan 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/filecheck/dialects/vector/vector_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: XDSL_ROUNDTRIP

#map = affine_map<(d0, d1) -> (d0)>
builtin.module {
func.func private @vector_test(%0 : memref<4x4xindex>, %1 : vector<1xi1>, %2 : index) {
%3 = "vector.load"(%0, %2, %2) : (memref<4x4xindex>, index, index) -> vector<2xindex>
Expand All @@ -10,6 +10,9 @@ builtin.module {
"vector.maskedstore"(%0, %2, %2, %1, %6) : (memref<4x4xindex>, index, index, vector<1xi1>, vector<1xindex>) -> ()
"vector.print"(%6) : (vector<1xindex>) -> ()
%7 = "vector.create_mask"(%2) : (index) -> vector<2xi1>
%8 = "vector.transfer_read"(%0, %2, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array<i32: 1, 2, 1, 0>, "permutation_map" = #map}> : (memref<4x4xindex>, index, index, index) -> vector<4xindex>
"vector.transfer_write"(%8, %0, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array<i32: 1, 1, 2, 0>, "permutation_map" = #map}> : (vector<4xindex>, memref<4x4xindex>, index, index) -> ()

func.return
}
}
Expand All @@ -25,6 +28,8 @@ builtin.module {
// CHECK-NEXT: "vector.maskedstore"(%0, %2, %2, %1, %6) : (memref<4x4xindex>, index, index, vector<1xi1>, vector<1xindex>) -> ()
// CHECK-NEXT: "vector.print"(%6) : (vector<1xindex>) -> ()
// CHECK-NEXT: %7 = "vector.create_mask"(%2) : (index) -> vector<2xi1>
// CHECK-NEXT: %8 = "vector.transfer_read"(%0, %2, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array<i32: 1, 2, 1, 0>, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (memref<4x4xindex>, index, index, index) -> vector<4xindex>
// CHECK-NEXT: "vector.transfer_write"(%8, %0, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array<i32: 1, 1, 2, 0>, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (vector<4xindex>, memref<4x4xindex>, index, index) -> ()
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
4 changes: 2 additions & 2 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,8 +1891,8 @@ def get_resolvers(
}

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if isinstance(attr, VectorType) or isinstance(attr, TensorType):
attr = cast(VectorType[Attribute] | TensorType[Attribute], attr)
if isinstance(attr, MemRefType) or isinstance(attr, TensorType):
attr = cast(MemRefType[Attribute] | TensorType[Attribute], attr)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind moving this to a separate PR with a test case?
We seem to be lacking any verification testing on this type.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seeing the changes, I'd still recommend doing this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do that in the next few days

self.elem_constr.verify(attr.element_type, constraint_context)
else:
raise VerifyException(f"Expected tensor or memref type, got {attr}")
Expand Down
84 changes: 84 additions & 0 deletions xdsl/dialects/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
from collections.abc import Sequence

from xdsl.dialects.builtin import (
I1,
AffineMapAttr,
ArrayAttr,
BoolAttr,
IndexType,
MemRefType,
TensorOrMemrefOf,
TensorType,
VectorBaseTypeAndRankConstraint,
VectorBaseTypeConstraint,
VectorRankConstraint,
Expand All @@ -13,9 +19,13 @@
)
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AttrSizedOperandSegments,
IRDLOperation,
irdl_op_definition,
operand_def,
opt_operand_def,
opt_prop_def,
prop_def,
result_def,
traits_def,
var_operand_def,
Expand Down Expand Up @@ -292,6 +302,78 @@ def get(mask_operands: list[Operation | SSAValue]) -> CreatemaskOp:
)


@irdl_op_definition
class TransferReadOp(IRDLOperation):
name = "vector.transfer_read"

source = operand_def(TensorOrMemrefOf(Attribute))
indices = var_operand_def(IndexType)
padding = operand_def(Attribute)
mask = opt_operand_def(VectorType[I1])

permutation_map = prop_def(AffineMapAttr)
in_bounds = opt_prop_def(ArrayAttr[BoolAttr])

result = result_def(VectorType)

irdl_options = [AttrSizedOperandSegments(as_property=True)]

def verify_(self):
assert isa(self.source.type, MemRefType[Attribute] | TensorType[Attribute])
assert isa(self.result.type, VectorType[Attribute])
# TODO verify.

compor marked this conversation as resolved.
Show resolved Hide resolved
@staticmethod
def get(
compor marked this conversation as resolved.
Show resolved Hide resolved
source: SSAValue | Operation,
indices: Sequence[SSAValue | Operation],
padding: SSAValue | Operation,
result_type: Attribute,
mask: Sequence[SSAValue | Operation] | None = None,
permutation_map: AffineMapAttr | None = None,
in_bounds: ArrayAttr[BoolAttr] | None = None,
):
return TransferReadOp.build(
operands=[source, indices, padding, mask],
result_types=[result_type],
properties={"permutation_map": permutation_map, "in_bounds": in_bounds},
)


@irdl_op_definition
class TransferWriteOp(IRDLOperation):
name = "vector.transfer_write"

vector = operand_def(VectorType[Attribute])
source = operand_def(TensorOrMemrefOf(Attribute))
indices = var_operand_def(IndexType)
mask = opt_operand_def(VectorType[I1])

in_bounds = prop_def(ArrayAttr[BoolAttr])
permutation_map = prop_def(AffineMapAttr)

irdl_options = [AttrSizedOperandSegments(as_property=True)]

def verify_(self):
assert isa(self.source.type, MemRefType[Attribute] | TensorType[Attribute])
assert isa(self.vector.type, VectorType[Attribute])
# TODO verify

@staticmethod
def get(
vector: SSAValue | Operation,
compor marked this conversation as resolved.
Show resolved Hide resolved
source: SSAValue | Operation,
indices: Sequence[SSAValue | Operation],
mask: Sequence[SSAValue | Operation] | None = None,
permutation_map: AffineMapAttr | None = None,
in_bounds: ArrayAttr[BoolAttr] | None = None,
):
return TransferWriteOp.build(
operands=[vector, source, indices, mask],
properties={"permutation_map": permutation_map, "in_bounds": in_bounds},
)


Vector = Dialect(
"vector",
[
Expand All @@ -303,6 +385,8 @@ def get(mask_operands: list[Operation | SSAValue]) -> CreatemaskOp:
MaskedstoreOp,
PrintOp,
CreatemaskOp,
TransferReadOp,
TransferWriteOp,
],
[],
)
Loading