From 51c18b95b6de274684c5d2095adf7e31cdd77eb5 Mon Sep 17 00:00:00 2001 From: Jake Date: Mon, 20 Jan 2025 18:12:42 +0000 Subject: [PATCH] implement `arith.bitcast` operation --- tests/dialects/test_arith.py | 23 +++++++++++++++++++++++ xdsl/dialects/arith.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/tests/dialects/test_arith.py b/tests/dialects/test_arith.py index 68ee9db90c..66c5f177b2 100644 --- a/tests/dialects/test_arith.py +++ b/tests/dialects/test_arith.py @@ -7,6 +7,7 @@ AddiOp, AddUIExtendedOp, AndIOp, + BitcastOp, CeilDivSIOp, CeilDivUIOp, CmpfOp, @@ -55,6 +56,7 @@ IndexType, IntegerAttr, IntegerType, + Signedness, TensorType, VectorType, f32, @@ -249,6 +251,27 @@ def test_select_op(): assert select_f_op.result.type == f.result.type +def test_bitcast_op(): + signedness = [Signedness.SIGNED, Signedness.UNSIGNED, Signedness.SIGNLESS] + + for bitwidth in [1, 2, 4, 5, 8, 16, 32, 64, 128]: + for from_sign in signedness: + for to_sign in signedness: + if from_sign == to_sign: + continue + + # here create variable of from type + value = 0 if bitwidth in [1, 2] else 5 + a = ConstantOp( + IntegerAttr(value, IntegerType(bitwidth, signedness=from_sign)) + ) + cast = BitcastOp(a, IntegerType(bitwidth, signedness=to_sign)) + + assert cast.result.type == IntegerType(bitwidth, signedness=to_sign) + assert cast.input.type == IntegerType(bitwidth, signedness=from_sign) + assert cast.input.owner == a + + def test_index_cast_op(): a = ConstantOp.from_int_and_width(0, 32) cast = IndexCastOp(a, IndexType()) diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 6850f33ffc..b73929d4f1 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -1223,6 +1223,40 @@ class MinnumfOp(FloatingPointLikeBinaryOperation): traits = traits_def(Pure()) +@irdl_op_definition +class BitcastOp(IRDLOperation): + """ + Reinterpret a value as another type of equal bitwidth without changing + the underlying representation. + """ + + name = "arith.bitcast" + + input = operand_def(IntegerType | Float16Type | Float32Type | Float64Type) + + result = result_def(IntegerType | Float16Type | Float32Type | Float64Type) + + assembly_format = "$input attr-dict `:` type($input) `to` type($result)" + + def __init__(self, input_arg: SSAValue | Operation, target_type: Attribute): + super().__init__(operands=[input_arg], result_types=[target_type]) + + def verify_(self) -> None: + assert isinstance( + self.input.type, IntegerType | Float16Type | Float32Type | Float64Type + ) + assert isinstance( + self.result.type, IntegerType | Float16Type | Float32Type | Float64Type + ) + + ibw = self.input.type.width + obw = self.result.type.width + if ibw == obw: + raise VerifyException( + f"'arith.bitcast' can only be used on types with equal bitwidths, found {ibw} and {obw}" + ) + + @irdl_op_definition class IndexCastOp(IRDLOperation): name = "arith.index_cast"