Skip to content

Commit

Permalink
dialects: (math) Use SameOperandsAndResultType trait (#3761)
Browse files Browse the repository at this point in the history
This PR:

- Uses the `SameOperandsAndResultType` trait in the `math` dialect
- Simplifies the relevant test cases that were not actually using the
`return_type` parameter
  • Loading branch information
compor authored Jan 20, 2025
1 parent 7b2c906 commit 9b81223
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 92 deletions.
40 changes: 13 additions & 27 deletions tests/dialects/test_math.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import TypeVar

import pytest

from xdsl.dialects.arith import ConstantOp, FloatingPointLikeBinaryOperation
Expand Down Expand Up @@ -45,12 +43,9 @@
TruncOp,
)
from xdsl.dialects.test import TestOp
from xdsl.ir import Attribute
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.test_value import TestSSAValue

_BinOpArgT = TypeVar("_BinOpArgT", bound=Attribute)


class Test_float_math_binary_construction:
operand_type = f32
Expand Down Expand Up @@ -81,11 +76,9 @@ class Test_float_math_binary_construction:
PowFOp,
],
)
@pytest.mark.parametrize("return_type", [None, operand_type])
def test_float_binary_ops_constant_math_init(
self,
OpClass: type[FloatingPointLikeBinaryOperation],
return_type: Attribute,
):
op = OpClass(self.a, self.b)
assert isinstance(op, OpClass)
Expand All @@ -102,9 +95,9 @@ def test_float_binary_ops_constant_math_init(
PowFOp,
],
)
@pytest.mark.parametrize("return_type", [None, f32_vector_type])
def test_flaot_binary_vector_ops_init(
self, OpClass: type[FloatingPointLikeBinaryOperation], return_type: Attribute
def test_float_binary_vector_ops_init(
self,
OpClass: type[FloatingPointLikeBinaryOperation],
):
op = OpClass(self.lhs_vector, self.rhs_vector)
assert isinstance(op, OpClass)
Expand All @@ -121,9 +114,9 @@ def test_flaot_binary_vector_ops_init(
PowFOp,
],
)
@pytest.mark.parametrize("return_type", [None, f32_tensor_type])
def test_float_binary_ops_tensor_math_init(
self, OpClass: type[FloatingPointLikeBinaryOperation], return_type: Attribute
self,
OpClass: type[FloatingPointLikeBinaryOperation],
):
op = OpClass(self.lhs_tensor, self.rhs_tensor)
assert isinstance(op, OpClass)
Expand Down Expand Up @@ -172,11 +165,9 @@ class Test_float_math_unary_constructions:
TruncOp,
],
)
@pytest.mark.parametrize("return_type", [operand_type])
def test_float_math_constant_ops_init(
self,
OpClass: type,
return_type: Attribute, # FIXME
):
op = OpClass(self.a)
assert op.result.type == f32
Expand Down Expand Up @@ -211,8 +202,7 @@ def test_float_math_constant_ops_init(
TruncOp,
],
)
@pytest.mark.parametrize("return_type", [f32_tensor_type])
def test_float_math_ops_vector_init(self, OpClass: type, return_type: Attribute):
def test_float_math_ops_vector_init(self, OpClass: type):
op = OpClass(self.test_vec)
assert op.result.type == self.f32_vector_type
assert op.operand.type == self.f32_vector_type
Expand Down Expand Up @@ -246,8 +236,10 @@ def test_float_math_ops_vector_init(self, OpClass: type, return_type: Attribute)
TruncOp,
],
)
@pytest.mark.parametrize("return_type", [f32_tensor_type])
def test_float_math_ops_tensor_init(self, OpClass: type, return_type: Attribute):
def test_float_math_ops_tensor_init(
self,
OpClass: type,
):
op = OpClass(self.test_tensor)
assert op.result.type == self.f32_tensor_type
assert op.operand.type == self.f32_tensor_type
Expand Down Expand Up @@ -351,11 +343,9 @@ class Test_int_math_unary_constructions:
CtPopOp,
],
)
@pytest.mark.parametrize("return_type", [operand_type])
def test_int_math_ops_init(
self,
OpClass: type,
return_type: Attribute, # FIXME use something other than `type`
):
op = OpClass(self.a)
assert op.result.type == i32
Expand All @@ -372,11 +362,9 @@ def test_int_math_ops_init(
CtPopOp,
],
)
@pytest.mark.parametrize("return_type", [i32_vector_type])
def test_int_math_ops_vec_init(
self,
OpClass: type,
return_type: Attribute, # FIXME use something other than `type`
):
op = OpClass(self.test_vec)
assert op.result.type == self.i32_vector_type
Expand All @@ -393,11 +381,9 @@ def test_int_math_ops_vec_init(
CtPopOp,
],
)
@pytest.mark.parametrize("return_type", [i32_vector_type])
def test_int_math_ops_tensor_init(
self,
OpClass: type,
return_type: Attribute,
):
op = OpClass(self.test_tensor)
assert op.result.type == self.i32_tensor_type
Expand All @@ -420,8 +406,8 @@ class Test_Trunci:

def test_trunci_incorrect_bitwidth(self):
with pytest.raises(VerifyException):
_trunci_op = TruncOp(self.a).verify()
TruncOp(self.a).verify()
with pytest.raises(VerifyException):
_trunci_op_vec = TruncOp(self.test_vec).verify()
TruncOp(self.test_vec).verify()
with pytest.raises(VerifyException):
_trunci_op_tensor = TruncOp(self.test_tensor).verify()
TruncOp(self.test_tensor).verify()
Loading

0 comments on commit 9b81223

Please sign in to comment.