From 58b762592b46e7d6984342471da5efd46e4e5d06 Mon Sep 17 00:00:00 2001 From: emmau678 <77412390+emmau678@users.noreply.github.com> Date: Thu, 19 Sep 2024 14:29:07 +0100 Subject: [PATCH] dialects: (riscv) Add rewrite pattern to optimize bitwise xor by zero (#3197) Add a rewrite pattern to the risv dialect for to optimise bitwise xor by zero (x^0 -> x) --------- Co-authored-by: emmau678 --- GETTING_STARTED.md | 1 - tests/filecheck/backend/riscv/canonicalize.mlir | 12 ++++++++++++ xdsl/dialects/riscv.py | 7 +++++-- .../transforms/canonicalization_patterns/riscv.py | 15 +++++++++++++++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index b78a86d32c..ec2f63f1fb 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -47,7 +47,6 @@ You're welcome to come up with your own, or do one of the following: - `x * 2ⁱ -> x << i` - `x & 0 -> 0` - `x | 0 -> x` -- `x ^ 0 -> x` The patterns are defined in [xdsl/transforms/canonicalization_patterns/riscv.py](xdsl/transforms/canonicalization_patterns/riscv.py). diff --git a/tests/filecheck/backend/riscv/canonicalize.mlir b/tests/filecheck/backend/riscv/canonicalize.mlir index 5911baafce..042ec574aa 100644 --- a/tests/filecheck/backend/riscv/canonicalize.mlir +++ b/tests/filecheck/backend/riscv/canonicalize.mlir @@ -111,6 +111,12 @@ builtin.module { %xor_lhs_rhs = riscv.xor %i1, %i1 : (!riscv.reg, !riscv.reg) -> !riscv.reg "test.op"(%xor_lhs_rhs) : (!riscv.reg) -> () + %xor_bitwise_zero_l0 = riscv.xor %c1, %c0 : (!riscv.reg, !riscv.reg) -> !riscv.reg + "test.op"(%xor_bitwise_zero_l0) : (!riscv.reg) -> () + + %xor_bitwise_zero_r0 = riscv.xor %c0, %c1 : (!riscv.reg, !riscv.reg) -> !riscv.reg + "test.op"(%xor_bitwise_zero_r0) : (!riscv.reg) -> () + // scfgw immediates riscv_snitch.scfgw %i1, %c1 : (!riscv.reg, !riscv.reg) -> () } @@ -221,6 +227,12 @@ builtin.module { // CHECK-NEXT: %xor_lhs_rhs_1 = riscv.mv %xor_lhs_rhs : (!riscv.reg) -> !riscv.reg // CHECK-NEXT: "test.op"(%xor_lhs_rhs_1) : (!riscv.reg) -> () +// CHECK-NEXT: %xor_bitwise_zero_l0 = riscv.mv %c1 : (!riscv.reg) -> !riscv.reg +// CHECK-NEXT: "test.op"(%xor_bitwise_zero_l0) : (!riscv.reg) -> () + +// CHECK-NEXT: %xor_bitwise_zero_r0 = riscv.mv %c1 : (!riscv.reg) -> !riscv.reg +// CHECK-NEXT: "test.op"(%xor_bitwise_zero_r0) : (!riscv.reg) -> () + // CHECK-NEXT: riscv_snitch.scfgwi %i1, 1 : (!riscv.reg) -> () // CHECK-NEXT: } diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py index 094f34b8bf..d36755f942 100644 --- a/xdsl/dialects/riscv.py +++ b/xdsl/dialects/riscv.py @@ -1777,9 +1777,12 @@ class OrOp(RdRsRsOperation[IntRegisterType, IntRegisterType, IntRegisterType]): class BitwiseXorHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): @classmethod def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: - from xdsl.transforms.canonicalization_patterns.riscv import XorBySelf + from xdsl.transforms.canonicalization_patterns.riscv import ( + BitwiseXorByZero, + XorBySelf, + ) - return (XorBySelf(),) + return (XorBySelf(), BitwiseXorByZero()) @irdl_op_definition diff --git a/xdsl/transforms/canonicalization_patterns/riscv.py b/xdsl/transforms/canonicalization_patterns/riscv.py index 575342ff6f..497ce6fa23 100644 --- a/xdsl/transforms/canonicalization_patterns/riscv.py +++ b/xdsl/transforms/canonicalization_patterns/riscv.py @@ -419,6 +419,21 @@ def match_and_rewrite(self, op: riscv.XorOp, rewriter: PatternRewriter): ) +class BitwiseXorByZero(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: riscv.XorOp, rewriter: PatternRewriter): + """ + x ^ 0 = x + """ + if (rs1 := get_constant_value(op.rs1)) is not None and rs1.value.data == 0: + rd = cast(riscv.IntRegisterType, op.rd.type) + rewriter.replace_matched_op(riscv.MVOp(op.rs2, rd=rd)) + + if (rs2 := get_constant_value(op.rs2)) is not None and rs2.value.data == 0: + rd = cast(riscv.IntRegisterType, op.rd.type) + rewriter.replace_matched_op(riscv.MVOp(op.rs1, rd=rd)) + + class ScfgwOpUsingImmediate(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(