Skip to content

Commit

Permalink
implement arith.bitcast operation
Browse files Browse the repository at this point in the history
  • Loading branch information
jakedves committed Jan 20, 2025
1 parent 29266f4 commit 51c18b9
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/dialects/test_arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AddiOp,
AddUIExtendedOp,
AndIOp,
BitcastOp,
CeilDivSIOp,
CeilDivUIOp,
CmpfOp,
Expand Down Expand Up @@ -55,6 +56,7 @@
IndexType,
IntegerAttr,
IntegerType,
Signedness,
TensorType,
VectorType,
f32,
Expand Down Expand Up @@ -249,6 +251,27 @@ def test_select_op():
assert select_f_op.result.type == f.result.type


def test_bitcast_op():
signedness = [Signedness.SIGNED, Signedness.UNSIGNED, Signedness.SIGNLESS]

for bitwidth in [1, 2, 4, 5, 8, 16, 32, 64, 128]:
for from_sign in signedness:
for to_sign in signedness:
if from_sign == to_sign:
continue

# here create variable of from type
value = 0 if bitwidth in [1, 2] else 5
a = ConstantOp(
IntegerAttr(value, IntegerType(bitwidth, signedness=from_sign))
)
cast = BitcastOp(a, IntegerType(bitwidth, signedness=to_sign))

assert cast.result.type == IntegerType(bitwidth, signedness=to_sign)
assert cast.input.type == IntegerType(bitwidth, signedness=from_sign)
assert cast.input.owner == a


def test_index_cast_op():
a = ConstantOp.from_int_and_width(0, 32)
cast = IndexCastOp(a, IndexType())
Expand Down
34 changes: 34 additions & 0 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,40 @@ class MinnumfOp(FloatingPointLikeBinaryOperation):
traits = traits_def(Pure())


@irdl_op_definition
class BitcastOp(IRDLOperation):
"""
Reinterpret a value as another type of equal bitwidth without changing
the underlying representation.
"""

name = "arith.bitcast"

input = operand_def(IntegerType | Float16Type | Float32Type | Float64Type)

result = result_def(IntegerType | Float16Type | Float32Type | Float64Type)

assembly_format = "$input attr-dict `:` type($input) `to` type($result)"

def __init__(self, input_arg: SSAValue | Operation, target_type: Attribute):
super().__init__(operands=[input_arg], result_types=[target_type])

def verify_(self) -> None:
assert isinstance(
self.input.type, IntegerType | Float16Type | Float32Type | Float64Type
)
assert isinstance(
self.result.type, IntegerType | Float16Type | Float32Type | Float64Type
)

ibw = self.input.type.width
obw = self.result.type.width
if ibw == obw:
raise VerifyException(
f"'arith.bitcast' can only be used on types with equal bitwidths, found {ibw} and {obw}"
)


@irdl_op_definition
class IndexCastOp(IRDLOperation):
name = "arith.index_cast"
Expand Down

0 comments on commit 51c18b9

Please sign in to comment.