From b8611e04fefa80ac2238f20311d28ab9b75aa4d6 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Tue, 17 Dec 2024 13:21:38 +0000 Subject: [PATCH] dialects: (arith) add SignlessIntegerBinaryOperation canonicalization (#3583) Generically implements various arith canonicalizations on SignlessIntegerBinaryOperation --- .../dialects/arith/canonicalize.mlir | 10 + tests/interactive/test_app.py | 10 +- xdsl/dialects/arith.py | 220 +++++++++++++++--- xdsl/traits.py | 6 + .../canonicalization_patterns/arith.py | 62 ++--- .../canonicalization_patterns/cf.py | 4 +- .../canonicalization_patterns/scf.py | 4 +- .../canonicalization_patterns/utils.py | 14 +- 8 files changed, 249 insertions(+), 81 deletions(-) diff --git a/tests/filecheck/dialects/arith/canonicalize.mlir b/tests/filecheck/dialects/arith/canonicalize.mlir index 2b49a6a847..8cc3984468 100644 --- a/tests/filecheck/dialects/arith/canonicalize.mlir +++ b/tests/filecheck/dialects/arith/canonicalize.mlir @@ -149,3 +149,13 @@ func.func @test_const_var_const() { %9 = arith.cmpi uge, %int, %int : i32 "test.op"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %int) : (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i32) -> () + +// Subtraction is not commutative so should not have the constant swapped to the right +// CHECK: arith.subi %c2, %a : i32 +%10 = arith.subi %c2, %a : i32 +"test.op"(%10) : (i32) -> () + +// CHECK: %{{.*}} = arith.constant false +%11 = arith.constant true +%12 = arith.addi %11, %11 : i1 +"test.op"(%12) : (i1) -> () diff --git a/tests/interactive/test_app.py b/tests/interactive/test_app.py index ac129f1c74..3e116aed3a 100644 --- a/tests/interactive/test_app.py +++ b/tests/interactive/test_app.py @@ -329,11 +329,11 @@ async def test_rewrites(): await pilot.click("#condense_button") addi_pass = AvailablePass( - display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:AddiIdentityRight", + display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:SignlessIntegerBinaryOperationZeroOrUnitRight", module_pass=individual_rewrite.ApplyIndividualRewritePass, pass_spec=list( parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}' + 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}' ) )[0], ) @@ -354,7 +354,7 @@ async def test_rewrites(): individual_rewrite.ApplyIndividualRewritePass, list( parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}' + 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}' ) )[0], ), @@ -563,7 +563,7 @@ async def test_apply_individual_rewrite(): n.data is not None and n.data[1] is not None and str(n.data[1]) - == 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiConstantProp"}' + == 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationConstantProp"}' ): node = n @@ -593,7 +593,7 @@ async def test_apply_individual_rewrite(): n.data is not None and n.data[1] is not None and str(n.data[1]) - == 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}' + == 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}' ): node = n diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 768c051863..6850f33ffc 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -45,6 +45,7 @@ from xdsl.pattern_rewriter import RewritePattern from xdsl.printer import Printer from xdsl.traits import ( + Commutative, ConditionallySpeculatable, ConstantLike, HasCanonicalizationPatternsTrait, @@ -195,6 +196,36 @@ class SignlessIntegerBinaryOperation(IRDLOperation, abc.ABC): assembly_format = "$lhs `,` $rhs attr-dict `:` type($result)" + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + """ + Performs a python function corresponding to this operation. + + If `i := py_operation(lhs, rhs)` is an int, then this operation can be + canonicalized to a constant with value `i` when the inputs are constants + with values `lhs` and `rhs`. + """ + return None + + @staticmethod + def is_right_zero(attr: AnyIntegerAttr) -> bool: + """ + Returns True only when 'attr' is a right zero for the operation + https://en.wikipedia.org/wiki/Absorbing_element + + Note that this depends on the operation and does *not* imply that + attr.value.data == 0 + """ + return False + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + """ + Return True only when 'attr' is a right unit/identity for the operation + https://en.wikipedia.org/wiki/Identity_element + """ + return False + def __init__( self, operand1: Operation | SSAValue, @@ -209,6 +240,22 @@ def __hash__(self) -> int: return id(self) +class SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait( + HasCanonicalizationPatternsTrait +): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.arith import ( + SignlessIntegerBinaryOperationConstantProp, + SignlessIntegerBinaryOperationZeroOrUnitRight, + ) + + return ( + SignlessIntegerBinaryOperationConstantProp(), + SignlessIntegerBinaryOperationZeroOrUnitRight(), + ) + + class SignlessIntegerBinaryOperationWithOverflow( SignlessIntegerBinaryOperation, abc.ABC ): @@ -318,22 +365,23 @@ def print(self, printer: Printer): printer.print_attribute(self.result.type) -class AddiOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): - @classmethod - def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: - from xdsl.transforms.canonicalization_patterns.arith import ( - AddiConstantProp, - AddiIdentityRight, - ) - - return (AddiIdentityRight(), AddiConstantProp()) - - @irdl_op_definition class AddiOp(SignlessIntegerBinaryOperationWithOverflow): name = "arith.addi" - traits = traits_def(Pure(), AddiOpHasCanonicalizationPatternsTrait()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs + rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition @@ -400,19 +448,27 @@ def infer_overflow_type(input_type: Attribute) -> Attribute: ) -class MuliHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait): - @classmethod - def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: - from xdsl.transforms.canonicalization_patterns import arith - - return (arith.MuliIdentityRight(), arith.MuliConstantProp()) - - @irdl_op_definition class MuliOp(SignlessIntegerBinaryOperationWithOverflow): name = "arith.muli" - traits = traits_def(Pure(), MuliHasCanonicalizationPatterns()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs * rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) + + @staticmethod + def is_right_zero(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 class MulExtendedBase(IRDLOperation): @@ -460,7 +516,17 @@ class MulSIExtendedOp(MulExtendedBase): class SubiOp(SignlessIntegerBinaryOperationWithOverflow): name = "arith.subi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs - rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 class DivUISpeculatable(ConditionallySpeculatable): @@ -483,7 +549,15 @@ class DivUIOp(SignlessIntegerBinaryOperation): name = "arith.divui" - traits = traits_def(NoMemoryEffect(), DivUISpeculatable()) + traits = traits_def( + NoMemoryEffect(), + DivUISpeculatable(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition @@ -495,7 +569,14 @@ class DivSIOp(SignlessIntegerBinaryOperation): name = "arith.divsi" - traits = traits_def(NoMemoryEffect()) + traits = traits_def( + NoMemoryEffect(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition @@ -506,21 +587,40 @@ class FloorDivSIOp(SignlessIntegerBinaryOperation): name = "arith.floordivsi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition class CeilDivSIOp(SignlessIntegerBinaryOperation): name = "arith.ceildivsi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition class CeilDivUIOp(SignlessIntegerBinaryOperation): name = "arith.ceildivui" - traits = traits_def(NoMemoryEffect()) + traits = traits_def( + NoMemoryEffect(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition @@ -567,21 +667,57 @@ class MaxSIOp(SignlessIntegerBinaryOperation): class AndIOp(SignlessIntegerBinaryOperation): name = "arith.andi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs & rhs + + @staticmethod + def is_right_zero(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition class OrIOp(SignlessIntegerBinaryOperation): name = "arith.ori" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs | rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition class XOrIOp(SignlessIntegerBinaryOperation): name = "arith.xori" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs ^ rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition @@ -593,7 +729,13 @@ class ShLIOp(SignlessIntegerBinaryOperationWithOverflow): name = "arith.shli" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition @@ -606,7 +748,13 @@ class ShRUIOp(SignlessIntegerBinaryOperation): name = "arith.shrui" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition @@ -620,7 +768,13 @@ class ShRSIOp(SignlessIntegerBinaryOperation): name = "arith.shrsi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 class ComparisonOperation(IRDLOperation): diff --git a/xdsl/traits.py b/xdsl/traits.py index 0b0e5383dd..9fd65519c9 100644 --- a/xdsl/traits.py +++ b/xdsl/traits.py @@ -687,6 +687,12 @@ class Pure(NoMemoryEffect, AlwaysSpeculatable): """ +class Commutative(OpTrait): + """ + A trait that signals that an operation is commutative. + """ + + class HasInsnRepresentation(OpTrait, abc.ABC): """ A trait providing information on how to encode an operation using a .insn assember directive. diff --git a/xdsl/transforms/canonicalization_patterns/arith.py b/xdsl/transforms/canonicalization_patterns/arith.py index d8005b3bf5..caa569c529 100644 --- a/xdsl/transforms/canonicalization_patterns/arith.py +++ b/xdsl/transforms/canonicalization_patterns/arith.py @@ -5,33 +5,46 @@ RewritePattern, op_type_rewrite_pattern, ) -from xdsl.transforms.canonicalization_patterns.utils import const_evaluate_operand +from xdsl.traits import Commutative +from xdsl.transforms.canonicalization_patterns.utils import ( + const_evaluate_operand, + const_evaluate_operand_attribute, +) from xdsl.utils.hints import isa -class AddiIdentityRight(RewritePattern): +class SignlessIntegerBinaryOperationZeroOrUnitRight(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: arith.AddiOp, rewriter: PatternRewriter) -> None: - if (rhs := const_evaluate_operand(op.rhs)) is None: - return - if rhs != 0: + def match_and_rewrite( + self, op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter, / + ): + if (rhs := const_evaluate_operand_attribute(op.rhs)) is None: return - rewriter.replace_matched_op((), (op.lhs,)) + if op.is_right_zero(rhs): + rewriter.replace_matched_op((), (op.rhs,)) + elif op.is_right_unit(rhs): + rewriter.replace_matched_op((), (op.lhs,)) -class AddiConstantProp(RewritePattern): +class SignlessIntegerBinaryOperationConstantProp(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: arith.AddiOp, rewriter: PatternRewriter): + def match_and_rewrite( + self, op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter, / + ): if (lhs := const_evaluate_operand(op.lhs)) is None: return if (rhs := const_evaluate_operand(op.rhs)) is None: # Swap inputs if lhs is constant and rhs is not - rewriter.replace_matched_op(arith.AddiOp(op.rhs, op.lhs)) + if op.has_trait(Commutative): + rewriter.replace_matched_op(op.__class__(op.rhs, op.lhs)) return + if (res := op.py_operation(lhs, rhs)) is None: + return assert isinstance(op.result.type, IntegerType | IndexType) + rewriter.replace_matched_op( - arith.ConstantOp.from_int_and_width(lhs + rhs, op.result.type) + arith.ConstantOp.from_int_and_width(res, op.result.type, truncate_bits=True) ) @@ -176,33 +189,6 @@ def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter): rewriter.replace_matched_op((), (op.lhs,)) -class MuliIdentityRight(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: arith.MuliOp, rewriter: PatternRewriter): - if (rhs := const_evaluate_operand(op.rhs)) is None: - return - if rhs != 1: - return - - rewriter.replace_matched_op((), (op.lhs,)) - - -class MuliConstantProp(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: arith.MuliOp, rewriter: PatternRewriter): - if (lhs := const_evaluate_operand(op.lhs)) is None: - return - if (rhs := const_evaluate_operand(op.rhs)) is None: - # Swap inputs if rhs is constant and lhs is not - rewriter.replace_matched_op(arith.MuliOp(op.rhs, op.lhs)) - return - - assert isinstance(op.result.type, IntegerType | IndexType) - rewriter.replace_matched_op( - arith.ConstantOp.from_int_and_width(lhs * rhs, op.result.type) - ) - - class ApplyCmpiPredicateToEqualOperands(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: arith.CmpiOp, rewriter: PatternRewriter): diff --git a/xdsl/transforms/canonicalization_patterns/cf.py b/xdsl/transforms/canonicalization_patterns/cf.py index eea4fb276e..0af86049d0 100644 --- a/xdsl/transforms/canonicalization_patterns/cf.py +++ b/xdsl/transforms/canonicalization_patterns/cf.py @@ -15,7 +15,9 @@ op_type_rewrite_pattern, ) from xdsl.rewriter import InsertPoint -from xdsl.transforms.canonicalization_patterns.utils import const_evaluate_operand +from xdsl.transforms.canonicalization_patterns.utils import ( + const_evaluate_operand, +) class AssertTrue(RewritePattern): diff --git a/xdsl/transforms/canonicalization_patterns/scf.py b/xdsl/transforms/canonicalization_patterns/scf.py index 2285070cc0..2e38bf15f4 100644 --- a/xdsl/transforms/canonicalization_patterns/scf.py +++ b/xdsl/transforms/canonicalization_patterns/scf.py @@ -9,7 +9,9 @@ ) from xdsl.rewriter import InsertPoint from xdsl.traits import ConstantLike -from xdsl.transforms.canonicalization_patterns.utils import const_evaluate_operand +from xdsl.transforms.canonicalization_patterns.utils import ( + const_evaluate_operand, +) class RehoistConstInLoops(RewritePattern): diff --git a/xdsl/transforms/canonicalization_patterns/utils.py b/xdsl/transforms/canonicalization_patterns/utils.py index 273de0ec26..bab4fdb59a 100644 --- a/xdsl/transforms/canonicalization_patterns/utils.py +++ b/xdsl/transforms/canonicalization_patterns/utils.py @@ -1,13 +1,21 @@ from xdsl.dialects import arith -from xdsl.dialects.builtin import IntegerAttr +from xdsl.dialects.builtin import AnyIntegerAttr, IntegerAttr from xdsl.ir import SSAValue -def const_evaluate_operand(operand: SSAValue) -> int | None: +def const_evaluate_operand_attribute(operand: SSAValue) -> AnyIntegerAttr | None: """ Try to constant evaluate an SSA value, returning None on failure. """ if isinstance(op := operand.owner, arith.ConstantOp) and isinstance( val := op.value, IntegerAttr ): - return val.value.data + return val + + +def const_evaluate_operand(operand: SSAValue) -> int | None: + """ + Try to constant evaluate an SSA value, returning None on failure. + """ + if (attr := const_evaluate_operand_attribute(operand)) is not None: + return attr.value.data