diff --git a/tests/test_traits.py b/tests/test_traits.py index 39191add09..b17381f3b3 100644 --- a/tests/test_traits.py +++ b/tests/test_traits.py @@ -12,16 +12,26 @@ from xdsl.dialects import test from xdsl.dialects.builtin import ( + DYNAMIC_INDEX, AnyIntegerAttr, + AnyMemRefTypeConstr, + AnyTensorTypeConstr, + AnyUnrankedMemrefTypeConstr, + AnyUnrankedTensorTypeConstr, IntegerAttr, IntegerType, + MemRefType, + NoneAttr, StringAttr, SymbolRefAttr, + TensorType, + UnrankedTensorType, i1, i32, i64, ) -from xdsl.ir import Operation, OpTrait, OpTraits +from xdsl.ir import Attribute, Operation, OpTrait, OpTraits +from xdsl.ir.core import SSAValue from xdsl.irdl import ( Block, IRDLOperation, @@ -37,6 +47,7 @@ result_def, traits_def, ) +from xdsl.irdl.operations import var_operand_def, var_result_def from xdsl.traits import ( AlwaysSpeculatable, ConditionallySpeculatable, @@ -44,6 +55,7 @@ HasParent, OptionalSymbolOpInterface, RecursivelySpeculatable, + SameOperandsAndResultType, SymbolOpInterface, SymbolTable, is_speculatable, @@ -549,6 +561,440 @@ class SupeculatabilityTestOp(IRDLOperation): assert is_speculatable(op) is speculatability +@pytest.mark.parametrize( + ("operands", "result_types"), + [ + ([()], [()]), + ([()], (test.TestType("foo"),)), + ((TestSSAValue(test.TestType("foo")),), [()]), + ], +) +def test_same_operands_and_result_type_trait_for_scalar_types( + operands: tuple[SSAValue] | tuple[()], + result_types: tuple[test.TestType] | tuple[()], +): + @irdl_op_definition + class SameOperandsAndResultTypeOp(IRDLOperation): + name = "test.same_operand_and_result_type" + + ops = var_operand_def(test.TestType("foo")) + res = var_result_def(test.TestType("foo")) + + traits = traits_def(SameOperandsAndResultType()) + + op = SameOperandsAndResultTypeOp(operands=operands, result_types=result_types) + + with pytest.raises( + VerifyException, match="requires at least one result or operand" + ): + op.verify() + + +@irdl_op_definition +class SameOperandsAndResultTypeOp(IRDLOperation): + name = "test.same_operand_and_result_type" + + ops = var_operand_def( + AnyMemRefTypeConstr + | AnyUnrankedMemrefTypeConstr + | AnyUnrankedTensorTypeConstr + | AnyTensorTypeConstr + ) + + res = var_result_def( + AnyMemRefTypeConstr + | AnyUnrankedMemrefTypeConstr + | AnyUnrankedTensorTypeConstr + | AnyTensorTypeConstr + ) + + traits = traits_def(SameOperandsAndResultType()) + + +@pytest.mark.parametrize( + ( + "operand1_and_result_element_type", + "operand_and_result_shape1", + "result_element_type2", + "result_shape2", + ), + [ + ( + test.TestType("foo"), + [2, 3], + test.TestType("qux"), + [2, 3], + ), + ( + test.TestType("foo"), + [2, 3], + test.TestType("foo"), + [2, 4], + ), + ( + test.TestType("qux"), + [2, 3], + test.TestType("foo"), + [2, 3], + ), + ( + test.TestType("foo"), + [2, 4], + test.TestType("foo"), + [2, 3], + ), + ], +) +def test_same_operands_and_result_type_trait_for_result_element_type_of_shaped_types( + operand1_and_result_element_type: Attribute, + operand_and_result_shape1: tuple[int], + result_element_type2: Attribute, + result_shape2: tuple[int], +): + op = SameOperandsAndResultTypeOp( + operands=[ + TestSSAValue( + TensorType(operand1_and_result_element_type, operand_and_result_shape1) + ) + ], + result_types=[ + [ + TensorType(operand1_and_result_element_type, operand_and_result_shape1), + TensorType(result_element_type2, result_shape2), + ], + ], + ) + + with pytest.raises( + VerifyException, + match="requires the same type for all operands and results", + ): + op.verify() + + +@pytest.mark.parametrize( + "operands_num", + [1, 2, 3], +) +@pytest.mark.parametrize( + "results_num", + [1, 2, 3], +) +@pytest.mark.parametrize( + ( + "operand_element_type", + "operand_shape", + "result_element_type", + "result_shape", + ), + [ + ( + test.TestType("foo"), + [2, 3], + test.TestType("qux"), + [2, 3], + ), + ( + test.TestType("foo"), + [2, 3], + test.TestType("foo"), + [2, 4], + ), + ( + test.TestType("qux"), + [2, 3], + test.TestType("foo"), + [2, 3], + ), + ( + test.TestType("foo"), + [2, 4], + test.TestType("foo"), + [2, 3], + ), + ], +) +def test_same_operands_and_result_type_trait_for_element_type_of_shaped_types( + operand_element_type: Attribute, + operand_shape: tuple[int], + result_element_type: Attribute, + result_shape: tuple[int], + operands_num: int, + results_num: int, +): + op = SameOperandsAndResultTypeOp( + operands=[ + [ + TestSSAValue(TensorType(operand_element_type, operand_shape)), + ] + * operands_num, + ], + result_types=[[TensorType(result_element_type, result_shape)] * results_num], + ) + + with pytest.raises( + VerifyException, + match="requires the same type for all operands and results", + ): + op.verify() + + op = SameOperandsAndResultTypeOp( + operands=[ + [ + TestSSAValue(MemRefType(operand_element_type, operand_shape)), + ] + * operands_num, + ], + result_types=[[MemRefType(result_element_type, result_shape)] * results_num], + ) + + with pytest.raises( + VerifyException, + match="requires the same type for all operands and results", + ): + op.verify() + + +@pytest.mark.parametrize( + ( + "element_type", + "shape", + "operand1_and_result_encoding", + "result_encoding2", + ), + [ + ( + test.TestType("foo"), + [2, 3], + StringAttr("bar"), + StringAttr("baz"), + ), + ( + test.TestType("foo"), + [2, 3], + StringAttr("baz"), + StringAttr("bar"), + ), + ( + test.TestType("foo"), + [2, 3], + StringAttr("bar"), + NoneAttr(), + ), + ( + test.TestType("foo"), + [2, 3], + NoneAttr(), + StringAttr("bar"), + ), + ], +) +def test_same_operands_and_result_type_trait_for_result_encoding_of_shaped_types( + element_type: Attribute, + shape: tuple[int], + operand1_and_result_encoding: Attribute, + result_encoding2: Attribute, +): + op = SameOperandsAndResultTypeOp( + operands=[ + [ + TestSSAValue( + TensorType( + element_type, + shape, + operand1_and_result_encoding, + ) + ), + ], + ], + result_types=[ + [ + TensorType(element_type, shape, operand1_and_result_encoding), + TensorType(element_type, shape, result_encoding2), + ] + ], + ) + + with pytest.raises( + VerifyException, + match="requires the same encoding for all operands and results", + ): + op.verify() + + +@pytest.mark.parametrize( + "operands_num", + [1, 2, 3], +) +@pytest.mark.parametrize( + "results_num", + [1, 2, 3], +) +@pytest.mark.parametrize( + ( + "element_type", + "shape", + "operand_encoding", + "result_encoding", + ), + [ + ( + test.TestType("foo"), + [2, 3], + StringAttr("bar"), + StringAttr("baz"), + ), + ( + test.TestType("foo"), + [2, 3], + StringAttr("baz"), + StringAttr("bar"), + ), + ( + test.TestType("foo"), + [2, 3], + StringAttr("bar"), + NoneAttr(), + ), + ( + test.TestType("foo"), + [2, 3], + NoneAttr(), + StringAttr("bar"), + ), + ], +) +def test_same_operands_and_result_type_trait_for_encoding_of_shaped_types( + element_type: Attribute, + shape: tuple[int], + operand_encoding: Attribute, + result_encoding: Attribute, + operands_num: int, + results_num: int, +): + op = SameOperandsAndResultTypeOp( + operands=[ + [ + TestSSAValue( + TensorType( + element_type, + shape, + operand_encoding, + ) + ), + ] + * operands_num, + ], + result_types=[[TensorType(element_type, shape, result_encoding)] * results_num], + ) + + with pytest.raises( + VerifyException, + match="requires the same encoding for all operands and results", + ): + op.verify() + + +@pytest.mark.parametrize( + "operands_num", + [1, 2, 3], +) +@pytest.mark.parametrize( + "results_num", + [1, 2, 3], +) +@pytest.mark.parametrize( + ( + "operand1_and_result_shape", + "operand2_shape", + ), + [ + ( + [1], + [1], + ), + ( + [2, 3], + [2, 3], + ), + ( + [2, 3], + [2, DYNAMIC_INDEX], + ), + ( + [2, 4], + [2, DYNAMIC_INDEX], + ), + ( + [2, DYNAMIC_INDEX], + [2, DYNAMIC_INDEX], + ), + ], +) +def test_same_operands_and_result_type_trait_for_ranked_mixed_shapes( + operand1_and_result_shape: tuple[int], + operand2_shape: tuple[int], + operands_num: int, + results_num: int, +): + op = SameOperandsAndResultTypeOp( + operands=[ + [ + TestSSAValue( + TensorType(test.TestType("foo"), operand1_and_result_shape) + ), + TestSSAValue(TensorType(test.TestType("foo"), operand2_shape)), + ] + * operands_num, + ], + result_types=[ + [TensorType(test.TestType("foo"), operand1_and_result_shape)] * results_num + ], + ) + + op.verify() + + +@pytest.mark.parametrize( + "operands_num", + [1, 2, 3], +) +@pytest.mark.parametrize( + "results_num", + [1, 2, 3], +) +@pytest.mark.parametrize( + ("operand1_and_result_shape",), + [ + ([1],), + ([2, 3],), + ([2, DYNAMIC_INDEX],), + ([DYNAMIC_INDEX, DYNAMIC_INDEX],), + ], +) +def test_same_operands_and_result_type_trait_for_mixed_rank_and_mixed_shapes( + operand1_and_result_shape: tuple[int], + operands_num: int, + results_num: int, +): + op = SameOperandsAndResultTypeOp( + operands=[ + [ + TestSSAValue( + TensorType(test.TestType("foo"), operand1_and_result_shape) + ), + TestSSAValue(UnrankedTensorType(test.TestType("foo"))), + ] + * operands_num, + ], + result_types=[ + [TensorType(test.TestType("foo"), operand1_and_result_shape)] * results_num + ], + ) + + op.verify() + + @irdl_op_definition class TestModifyTraitsOp(IRDLOperation): name = "test.test_modify_traits" diff --git a/tests/utils/test_type.py b/tests/utils/test_type.py new file mode 100644 index 0000000000..bbd0359837 --- /dev/null +++ b/tests/utils/test_type.py @@ -0,0 +1,61 @@ +from xdsl.dialects import test +from xdsl.dialects.builtin import DYNAMIC_INDEX, TensorType, UnrankedTensorType +from xdsl.utils.type import get_element_type_or_self, have_compatible_shape + + +def test_get_element_type_or_self(): + scalar_type1 = test.TestType("foo") + assert scalar_type1 == get_element_type_or_self(scalar_type1) + + shaped_type1 = TensorType(scalar_type1, [4]) + assert scalar_type1 == get_element_type_or_self(shaped_type1) + + unranked_shaped_type1 = UnrankedTensorType(scalar_type1) + assert scalar_type1 == get_element_type_or_self(unranked_shaped_type1) + + +def test_have_compatible_shape(): + scalar_type1 = test.TestType("foo") + scalar_type2 = test.TestType("foo") + + assert have_compatible_shape(scalar_type1, scalar_type2) + + shaped_type1 = TensorType(scalar_type1, [4]) + + assert not have_compatible_shape(scalar_type1, shaped_type1) + assert not have_compatible_shape(shaped_type1, scalar_type1) + + unranked_shaped_type1 = UnrankedTensorType(scalar_type1) + unranked_shaped_type2 = UnrankedTensorType(scalar_type2) + + assert have_compatible_shape(shaped_type1, unranked_shaped_type1) + assert have_compatible_shape(unranked_shaped_type1, shaped_type1) + assert have_compatible_shape(unranked_shaped_type1, unranked_shaped_type2) + + shaped_type2 = TensorType(scalar_type2, [1, 2]) + + assert not have_compatible_shape(shaped_type1, shaped_type2) + + shaped_type3 = TensorType(scalar_type2, [5]) + + assert not have_compatible_shape(shaped_type1, shaped_type3) + + shaped_type4 = TensorType(scalar_type2, [1, 3]) + + assert not have_compatible_shape(shaped_type2, shaped_type4) + + shaped_type5 = TensorType(scalar_type2, [DYNAMIC_INDEX, 3]) + shaped_type6 = TensorType(scalar_type2, [1, DYNAMIC_INDEX]) + shaped_type7 = TensorType(scalar_type2, [DYNAMIC_INDEX, DYNAMIC_INDEX]) + + assert have_compatible_shape(shaped_type4, shaped_type5) + assert have_compatible_shape(shaped_type4, shaped_type6) + + assert have_compatible_shape(shaped_type5, shaped_type6) + assert have_compatible_shape(shaped_type5, shaped_type7) + + assert have_compatible_shape(shaped_type6, shaped_type7) + + shaped_type8 = TensorType(scalar_type2, [2, DYNAMIC_INDEX]) + + assert not have_compatible_shape(shaped_type4, shaped_type8) diff --git a/xdsl/traits.py b/xdsl/traits.py index 23eca97fcd..823f18220a 100644 --- a/xdsl/traits.py +++ b/xdsl/traits.py @@ -729,3 +729,60 @@ def get_insn(self, op: Operation) -> str: Return the insn representation of the operation for printing. """ raise NotImplementedError() + + +@dataclass(frozen=True) +class SameOperandsAndResultType(OpTrait): + """Constrain the operation to have the same operands and result type.""" + + def verify(self, op: Operation) -> None: + from xdsl.dialects.builtin import NoneAttr, TensorType + from xdsl.utils.type import get_element_type_or_self, have_compatible_shape + + if len(op.results) < 1 or len(op.operands) < 1: + raise VerifyException( + f"'{op.name}' requires at least one result or operand" + ) + + res_type0 = get_element_type_or_self(op.result_types[0]) + + def get_encoding(maybe_shaped_type: Attribute) -> Attribute: + if isinstance(maybe_shaped_type, TensorType): + return maybe_shaped_type.encoding + return NoneAttr() + + encoding = get_encoding(op.result_types[0]) + + for res_type in op.result_types[1:]: + res_type_elem = get_element_type_or_self(res_type) + if res_type0 != res_type_elem or not have_compatible_shape( + op.result_types[0], res_type + ): + raise VerifyException( + f"'{op.name} requires the same type for all operands and results" + ) + + elem_encoding = get_encoding(res_type) + + if encoding != elem_encoding: + raise VerifyException( + f"'{op.name} requires the same encoding for all operands and results" + ) + + for oprnd_type in op.operand_types: + oprnd_type_elem = get_element_type_or_self(oprnd_type) + if res_type0 != oprnd_type_elem or not have_compatible_shape( + op.result_types[0], oprnd_type + ): + raise VerifyException( + f"'{op.name} requires the same type for all operands and results" + ) + + elem_encoding = NoneAttr() + if isinstance(oprnd_type, TensorType): + elem_encoding = oprnd_type.encoding + + if encoding != elem_encoding: + raise VerifyException( + f"'{op.name} requires the same encoding for all operands and results" + ) diff --git a/xdsl/utils/type.py b/xdsl/utils/type.py new file mode 100644 index 0000000000..2c796ea911 --- /dev/null +++ b/xdsl/utils/type.py @@ -0,0 +1,39 @@ +"""""" + +from typing import Any, cast + +from xdsl.dialects.builtin import DYNAMIC_INDEX, ContainerType, ShapedType +from xdsl.ir import Attribute + + +def get_element_type_or_self(maybe_shaped_type: Attribute) -> Attribute: + if isinstance(maybe_shaped_type, ContainerType): + container_type = cast(ContainerType[Any], maybe_shaped_type) + return container_type.get_element_type() + return maybe_shaped_type + + +def have_compatible_shape(lhs_type: Attribute, rhs_type: Attribute) -> bool: + is_lhs_shaped = isinstance(lhs_type, ContainerType) + is_rhs_shaped = isinstance(rhs_type, ContainerType) + + # both are scalars + if not is_lhs_shaped and not is_rhs_shaped: + return True + + # one is scalar and the other shaped + if (is_lhs_shaped and not is_rhs_shaped) or (not is_lhs_shaped and is_rhs_shaped): + return False + + # at least one is unranked + if not isinstance(lhs_type, ShapedType) or not isinstance(rhs_type, ShapedType): + return True + + # both ranked, so check ranks + if lhs_type.get_num_dims() != rhs_type.get_num_dims(): + return False + + return all( + dim1 == DYNAMIC_INDEX or dim2 == DYNAMIC_INDEX or dim1 == dim2 + for dim1, dim2 in zip(lhs_type.get_shape(), rhs_type.get_shape()) + )