From 1d30f3bf5b569989ff80d16077d1d2113cad277e Mon Sep 17 00:00:00 2001 From: SCM Date: Mon, 5 Aug 2024 19:12:51 +0100 Subject: [PATCH 1/5] Lift resctiorn on Unions with bytes and int --- opshin/compiler.py | 9 ++++++++- opshin/fun_impls.py | 42 ++++++++++++++++++++++++++++++++++++++++ opshin/type_inference.py | 17 ++++++++++++---- 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/opshin/compiler.py b/opshin/compiler.py index e3e00222..ac8b0e27 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -422,7 +422,14 @@ def visit_Call(self, node: TypedCall) -> plt.AST: bind_self = node.func.typ.typ.bind_self bound_vs = sorted(list(node.func.typ.typ.bound_vars.keys())) args = [] - for a, t in zip(node.args, node.func.typ.typ.argtyps): + for i, (a, t) in enumerate(zip(node.args, node.func.typ.typ.argtyps)): + # now impl_from_args has been chosen, skip type arg + if ( + hasattr(node.func, "orig_id") + and node.func.orig_id == "isinstance" + and i == 1 + ): + continue assert isinstance(t, InstanceType) # pass in all arguments evaluated with the statemonad a_int = self.visit(a) diff --git a/opshin/fun_impls.py b/opshin/fun_impls.py index b53b8a84..0c50ddc7 100644 --- a/opshin/fun_impls.py +++ b/opshin/fun_impls.py @@ -87,6 +87,46 @@ def impl_from_args(self, args: typing.List[Type]) -> plt.AST: return print +class IsinstanceImpl(PolymorphicFunction): + def type_from_args(self, args: typing.List[Type]) -> FunctionType: + assert ( + len(args) == 2 + ), f"isinstance takes two arguments [object, type], but {len(args)} were given" + # Plutus dataclasses isinstance is replaced by checking CONSTR_IDs + assert isinstance(args[1], (IntegerType, ByteStringType)) + return FunctionType(args, BoolInstanceType) + + def impl_from_args(self, args: typing.List[Type]) -> plt.AST: + if isinstance(args[1], IntegerType): + return OLambda( + ["x"], + plt.ChooseData( + OVar("x"), + plt.Bool(False), + plt.Bool(False), + plt.Bool(False), + plt.Bool(True), + plt.Bool(False), + ), + ) + elif isinstance(args[1], ByteStringType): + return OLambda( + ["x"], + plt.ChooseData( + OVar("x"), + plt.Bool(False), + plt.Bool(False), + plt.Bool(False), + plt.Bool(False), + plt.Bool(True), + ), + ) + else: + raise NotImplementedError( + f"Only isinstance for byte, int, Plutus Dataclass types are supported" + ) + + class PythonBuiltIn(Enum): all = OLambda( ["xs"], @@ -437,6 +477,7 @@ class PythonBuiltIn(Enum): OVar("xs"), plt.BuiltIn(uplc.BuiltInFun.AddInteger), plt.Integer(0) ), ) + isinstance = "isinstance" PythonBuiltInTypes = { @@ -510,4 +551,5 @@ class PythonBuiltIn(Enum): IntegerInstanceType, ) ), + PythonBuiltIn.isinstance: InstanceType(PolymorphicFunctionType(IsinstanceImpl())), } diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 9edd745c..e3f79bb1 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -135,13 +135,17 @@ def union_types(*ts: Type): for e in ts: for e2 in e.typs: assert isinstance( - e2, RecordType + e2, (RecordType, IntegerType, ByteStringType) ), f"Union must combine multiple PlutusData classes but found {e2.__class__.__name__}" union_set = OrderedSet() for t in ts: union_set.update(t.typs) assert distinct( - [e.record.constructor for e in union_set] + [ + e.record.constructor + for e in union_set + if not isinstance(e, (ByteStringType, IntegerType)) + ] ), "Union must combine PlutusData classes with unique constructors" return UnionType(frozenlist(union_set)) @@ -192,7 +196,7 @@ def visit_Call(self, node: Call) -> TypeMapPair: assert isinstance( inst_class, InstanceType ), "Can only cast instances, not classes" - assert isinstance(target_class, RecordType), "Can only cast to PlutusData" + # assert isinstance(target_class, RecordType), "Can only cast to PlutusData" if isinstance(inst_class.typ, UnionType): assert ( target_class in inst_class.typ.typs @@ -932,7 +936,11 @@ def visit_Call(self, node: Call) -> TypedCall: tc.args = [self.visit(a) for a in node.args] # might be isinstance - if isinstance(tc.func, Name) and tc.func.orig_id == "isinstance": + if ( + isinstance(tc.func, Name) + and tc.func.orig_id == "isinstance" + and not isinstance(tc.args[1].typ, (ByteStringType, IntegerType)) + ): target_class = tc.args[1].typ if ( isinstance(tc.args[0].typ, InstanceType) @@ -943,6 +951,7 @@ def visit_Call(self, node: Call) -> TypedCall: "OpShin does not permit checking the instance of raw Anything/Datum objects as this only checks the equality of the constructor id and nothing more. " "If you are certain of what you are doing, please use the flag '--allow-isinstance-anything'." ) + ntc = Compare( left=Attribute(tc.args[0], "CONSTR_ID"), ops=[Eq()], From e0e5d25b7702c2f930f5719a5cf5826af4d01d46 Mon Sep 17 00:00:00 2001 From: SCM Date: Mon, 5 Aug 2024 23:53:41 +0100 Subject: [PATCH 2/5] support lists and dicts in unions --- opshin/fun_impls.py | 38 ++++++++++++++++++++++++++++++++++++-- opshin/type_inference.py | 33 +++++++++++++++++++++++++-------- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/opshin/fun_impls.py b/opshin/fun_impls.py index 0c50ddc7..7a65dee8 100644 --- a/opshin/fun_impls.py +++ b/opshin/fun_impls.py @@ -92,8 +92,6 @@ def type_from_args(self, args: typing.List[Type]) -> FunctionType: assert ( len(args) == 2 ), f"isinstance takes two arguments [object, type], but {len(args)} were given" - # Plutus dataclasses isinstance is replaced by checking CONSTR_IDs - assert isinstance(args[1], (IntegerType, ByteStringType)) return FunctionType(args, BoolInstanceType) def impl_from_args(self, args: typing.List[Type]) -> plt.AST: @@ -121,6 +119,42 @@ def impl_from_args(self, args: typing.List[Type]) -> plt.AST: plt.Bool(True), ), ) + elif isinstance(args[1], RecordType): + return OLambda( + ["x"], + plt.ChooseData( + OVar("x"), + plt.Bool(True), + plt.Bool(False), + plt.Bool(False), + plt.Bool(False), + plt.Bool(False), + ), + ) + elif isinstance(args[1], ListType): + return OLambda( + ["x"], + plt.ChooseData( + OVar("x"), + plt.Bool(False), + plt.Bool(False), + plt.Bool(True), + plt.Bool(False), + plt.Bool(False), + ), + ) + elif isinstance(args[1], DictType): + return OLambda( + ["x"], + plt.ChooseData( + OVar("x"), + plt.Bool(False), + plt.Bool(True), + plt.Bool(False), + plt.Bool(False), + plt.Bool(False), + ), + ) else: raise NotImplementedError( f"Only isinstance for byte, int, Plutus Dataclass types are supported" diff --git a/opshin/type_inference.py b/opshin/type_inference.py index e3f79bb1..6a52cda6 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -135,7 +135,7 @@ def union_types(*ts: Type): for e in ts: for e2 in e.typs: assert isinstance( - e2, (RecordType, IntegerType, ByteStringType) + e2, (RecordType, IntegerType, ByteStringType, ListType, DictType) ), f"Union must combine multiple PlutusData classes but found {e2.__class__.__name__}" union_set = OrderedSet() for t in ts: @@ -144,7 +144,7 @@ def union_types(*ts: Type): [ e.record.constructor for e in union_set - if not isinstance(e, (ByteStringType, IntegerType)) + if not isinstance(e, (ByteStringType, IntegerType, ListType, DictType)) ] ), "Union must combine PlutusData classes with unique constructors" return UnionType(frozenlist(union_set)) @@ -187,8 +187,8 @@ def visit_Call(self, node: Call) -> TypeMapPair: "Target 0 of an isinstance cast must be a variable name for type casting to work. You can still proceed, but the inferred type of the isinstance cast will not be accurate." ) return ({}, {}) - assert isinstance( - node.args[1], Name + assert isinstance(node.args[1], Name) or isinstance( + node.args[1].typ, (ListType, DictType) ), "Target 1 of an isinstance cast must be a class name" target_class: RecordType = node.args[1].typ inst = node.args[0] @@ -709,10 +709,15 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: self.enter_scope() tfd.args = self.visit(node.args) + functyp = FunctionType( frozenlist([t.typ for t in tfd.args.args]), InstanceType(self.type_from_annotation(tfd.returns)), - bound_vars={v: self.variable_type(v) for v in externally_bound_vars(node)}, + bound_vars={ + v: self.variable_type(v) + for v in externally_bound_vars(node) + if not v in ["List", "Dict"] + }, bind_self=node.name if node.name in read_vars(node) else None, ) tfd.typ = InstanceType(functyp) @@ -939,7 +944,10 @@ def visit_Call(self, node: Call) -> TypedCall: if ( isinstance(tc.func, Name) and tc.func.orig_id == "isinstance" - and not isinstance(tc.args[1].typ, (ByteStringType, IntegerType)) + and not isinstance( + tc.args[1].typ, (ByteStringType, IntegerType, ListType, DictType) + ) + and not hasattr(node, "skip_next") ): target_class = tc.args[1].typ if ( @@ -951,7 +959,6 @@ def visit_Call(self, node: Call) -> TypedCall: "OpShin does not permit checking the instance of raw Anything/Datum objects as this only checks the equality of the constructor id and nothing more. " "If you are certain of what you are doing, please use the flag '--allow-isinstance-anything'." ) - ntc = Compare( left=Attribute(tc.args[0], "CONSTR_ID"), ops=[Eq()], @@ -961,7 +968,17 @@ def visit_Call(self, node: Call) -> TypedCall: ntc = self.visit(ntc) ntc.typ = BoolInstanceType ntc.typechecks = TypeCheckVisitor(self.allow_isinstance_anything).visit(tc) - return ntc + if isinstance(tc.args[0].typ.typ, UnionType) and any( + [ + isinstance(a, (IntegerType, ByteStringType, ListType, DictType)) + for a in tc.args[0].typ.typ.typs + ] + ): + n = copy(node) + n.skip_next = True + return self.visit(BoolOp(And(), [n, ntc])) + else: + return ntc try: tc.func = self.visit(node.func) except Exception as e: From e12dd63af1acd990cf4932ea876babe6584c0e09 Mon Sep 17 00:00:00 2001 From: SCM Date: Wed, 7 Aug 2024 20:21:18 +0100 Subject: [PATCH 3/5] Restrict to only List[Anything] and Dict[Anything, Anything] allowed in Union --- opshin/type_inference.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 6a52cda6..5ab87cfe 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -380,6 +380,17 @@ def type_from_annotation(self, ann: expr): ann.value, Name ), "Only Union, Dict and List are allowed as Generic types" if ann.value.orig_id == "Union": + for elt in ann.slice.elts: + if isinstance(elt, Subscript) and elt.value.id == "List": + assert ( + isinstance(elt.slice, Name) + and elt.slice.orig_id == "Anything" + ), f"Only List[Anything] is supported in Unions. Received List[{elt.slice.orig_id}]." + if isinstance(elt, Subscript) and elt.value.id == "Dict": + assert all( + isinstance(e, Name) and e.orig_id == "Anything" + for e in elt.slice.elts + ), f"Only Dict[Anything, Anything] or Dict is supported in Unions. Received Dict[{elt.slice.elts[0].orig_id}, {elt.slice.elts[1].orig_id}]." ann_types = frozenlist( [self.type_from_annotation(e) for e in ann.slice.elts] ) From 812cfb1e5f60ec7cae73a8bba2b73b8b1142b07b Mon Sep 17 00:00:00 2001 From: SCM Date: Sun, 18 Aug 2024 15:30:44 +0100 Subject: [PATCH 4/5] Tests and Dict/List only in isinstance --- opshin/tests/test_Unions.py | 242 ++++++++++++++++++++++++++++++++++++ opshin/type_inference.py | 31 ++++- 2 files changed, 269 insertions(+), 4 deletions(-) create mode 100644 opshin/tests/test_Unions.py diff --git a/opshin/tests/test_Unions.py b/opshin/tests/test_Unions.py new file mode 100644 index 00000000..37d64333 --- /dev/null +++ b/opshin/tests/test_Unions.py @@ -0,0 +1,242 @@ +import unittest + +import hypothesis +from hypothesis import given +from hypothesis import strategies as st +from .utils import eval_uplc_value +from . import PLUTUS_VM_PROFILE +from opshin.util import CompilerError + +hypothesis.settings.load_profile(PLUTUS_VM_PROFILE) + +from .test_misc import A + + +def to_int(x): + if isinstance(x, A): + return 5 + elif isinstance(x, int): + return 6 + elif isinstance(x, bytes): + return 7 + elif isinstance(x, list): + return 8 + elif isinstance(x, dict): + return 9 + return False + + +union_types = st.sampled_from([A(0), 10, b"foo", [1, 2, 3, 4, 5], {1: 2, 2: 3}]) + + +class Union_tests(unittest.TestCase): + @hypothesis.given(union_types) + def test_Union_types(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int: + k: int = 0 + if isinstance(x, A): + k = 5 + elif isinstance(x, bytes): + k = 7 + elif isinstance(x, int): + k = 6 + elif isinstance(x, List): + k = 8 + elif isinstance(x, Dict): + k = 9 + return k +""" + res = eval_uplc_value(source_code, x) + self.assertEqual(res, to_int(x)) + + @hypothesis.given(union_types) + def test_Union_types_different_order(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int: + k: int = 1 + if isinstance(x, int): + k = 6 + elif isinstance(x, Dict): + k = 9 + elif isinstance(x, bytes): + k = 7 + elif isinstance(x, A): + k = 5 + elif isinstance(x, List): + k = 8 + return k +""" + res = eval_uplc_value(source_code, x) + self.assertEqual(res, to_int(x)) + + @unittest.expectedFailure + def test_incorrect_Union_types( + self, + ): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +def validator(x: Union[A, bytes,]) -> int: + k: int = 0 + if isinstance(x, A): + k = 5 + elif isinstance(x, bytes): + k = 7 + elif isinstance(x, int): + k = 6 + return k +""" + with self.AssertRaises(CompilerError): + res = eval_uplc_value(source_code, 2) + + def test_isinstance_Dict_subscript_fail( + self, + ): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int: + k: int = 1 + if isinstance(x, int): + k = 6 + elif isinstance(x, Dict[Anything, Anything]): + k = 9 + elif isinstance(x, bytes): + k = 7 + elif isinstance(x, A): + k = 5 + elif isinstance(x, List): + k = 8 + return k +""" + with self.assertRaises(CompilerError) as ce: + res = eval_uplc_value(source_code, [1, 2, 3]) + self.assertIsInstance(ce.exception.orig_err, TypeError) + + def test_isinstance_List_subscript_fail( + self, + ): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int: + k: int = 1 + if isinstance(x, int): + k = 6 + elif isinstance(x, Dict): + k = 9 + elif isinstance(x, bytes): + k = 7 + elif isinstance(x, A): + k = 5 + elif isinstance(x, List[Anything]): + k = 8 + return k +""" + with self.assertRaises(CompilerError) as ce: + res = eval_uplc_value(source_code, [1, 2, 3]) + self.assertIsInstance(ce.exception.orig_err, TypeError) + + def test_Union_list_is_anything( + self, + ): + """Test fails if List in union is anything other than List[Anything]""" + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +def validator(x: Union[A, int, bytes, List[int], Dict[Anything, Anything]]) -> int: + k: int = 1 + if isinstance(x, int): + k = 6 + elif isinstance(x, Dict): + k = 9 + elif isinstance(x, bytes): + k = 7 + elif isinstance(x, A): + k = 5 + elif isinstance(x, List): + k = 8 + return k +""" + with self.assertRaises(CompilerError) as ce: + res = eval_uplc_value(source_code, [1, 2, 3]) + self.assertIsInstance(ce.exception.orig_err, AssertionError) + + def test_Union_dict_is_anything( + self, + ): + """Test fails if Dict in union is anything other than Dict[Anything, Anything]""" + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +def validator(x: Union[A, int, bytes, List[Anything], Dict[int, bytes]]) -> int: + k: int = 1 + if isinstance(x, int): + k = 6 + elif isinstance(x, Dict): + k = 9 + elif isinstance(x, bytes): + k = 7 + elif isinstance(x, A): + k = 5 + elif isinstance(x, List): + k = 8 + return k +""" + with self.assertRaises(CompilerError) as ce: + res = eval_uplc_value(source_code, [1, 2, 3]) + self.assertIsInstance(ce.exception.orig_err, AssertionError) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 5ab87cfe..7f05be4a 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -456,7 +456,10 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: assert ( func.returns is None or func.returns.id != n.name ), "Invalid Python, class name is undefined at this stage" - func.args.args[0].annotation = ast.Name(id=n.name, ctx=ast.Load()) + ann = ast.Name(id=n.name, ctx=ast.Load()) + custom_fix_missing_locations(ann, attribute.args.args[0]) + ann.orig_id = attribute.args.args[0].orig_arg + func.args.args[0].annotation = ann additional_functions.append(func) n.body = non_method_attributes if additional_functions: @@ -663,8 +666,14 @@ def visit_For(self, node: For) -> TypedFor: def visit_Name(self, node: Name) -> TypedName: tn = copy(node) - # Make sure that the rhs of an assign is evaluated first - tn.typ = self.variable_type(node.id) + # typing List and Dict are not present in scope we don't want to call variable_type + if node.orig_id == "List": + tn.typ = ListType(InstanceType(AnyType())) + elif node.orig_id == "Dict": + tn.typ = DictType(InstanceType(AnyType()), InstanceType(AnyType())) + else: + # Make sure that the rhs of an assign is evaluated first + tn.typ = self.variable_type(node.id) return tn def visit_keyword(self, node: keyword) -> Typedkeyword: @@ -952,6 +961,17 @@ def visit_Call(self, node: Call) -> TypedCall: tc.args = [self.visit(a) for a in node.args] # might be isinstance + # Subscripts are not allowed in isinstance calls + if ( + isinstance(tc.func, Name) + and tc.func.orig_id == "isinstance" + and isinstance(tc.args[1], Subscript) + ): + raise TypeError( + "Subscripted generics cannot be used with class and instance checks" + ) + + # Need to handle the presence of PlutusData classes if ( isinstance(tc.func, Name) and tc.func.orig_id == "isinstance" @@ -1003,9 +1023,12 @@ def visit_Call(self, node: Call) -> TypedCall: except Exception: # if this fails raise original error raise e - tc.func = self.visit(ast.Name(id=method_name, ctx=ast.Load())) + n = ast.Name(id=method_name, ctx=ast.Load()) + n.orig_id = node.func.attr + tc.func = self.visit(n) tc.func.orig_id = node.func.attr c_self = ast.Name(id=node.func.value.id, ctx=ast.Load()) + c_self.orig_id = None tc.args.insert(0, self.visit(c_self)) # might be a class From 22bec8fec6405b2abb57753dd0b0883200b21e57 Mon Sep 17 00:00:00 2001 From: SCM Date: Fri, 23 Aug 2024 11:46:59 +0100 Subject: [PATCH 5/5] Union requires unique constr_id and more tests --- opshin/tests/test_Unions.py | 71 +++++++++++++++++++++++++++++++++++-- opshin/type_inference.py | 16 +++++++++ 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/opshin/tests/test_Unions.py b/opshin/tests/test_Unions.py index 37d64333..b9632e2e 100644 --- a/opshin/tests/test_Unions.py +++ b/opshin/tests/test_Unions.py @@ -1,5 +1,4 @@ import unittest - import hypothesis from hypothesis import given from hypothesis import strategies as st @@ -10,6 +9,9 @@ hypothesis.settings.load_profile(PLUTUS_VM_PROFILE) from .test_misc import A +from typing import List, Dict + +from ..ledger.api_v2 import * def to_int(x): @@ -19,9 +21,9 @@ def to_int(x): return 6 elif isinstance(x, bytes): return 7 - elif isinstance(x, list): + elif isinstance(x, List): return 8 - elif isinstance(x, dict): + elif isinstance(x, Dict): return 9 return False @@ -240,3 +242,66 @@ def validator(x: Union[A, int, bytes, List[Anything], Dict[int, bytes]]) -> int: with self.assertRaises(CompilerError) as ce: res = eval_uplc_value(source_code, [1, 2, 3]) self.assertIsInstance(ce.exception.orig_err, AssertionError) + + def test_same_constructor_fail(self): + @dataclass() + class B(PlutusData): + CONSTR_ID = 0 + foo: int + + @dataclass() + class C(PlutusData): + CONSTR_ID = 0 + foo: int + + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class B(PlutusData): + CONSTR_ID = 0 + foo: int + +@dataclass() +class C(PlutusData): + CONSTR_ID = 0 + foo: int + +def validator(x: Union[B, C]) -> int: + return 100 +""" + with self.assertRaises(CompilerError) as ce: + res = eval_uplc_value(source_code, B(0)) + self.assertIsInstance(ce.exception.orig_err, AssertionError) + + def test_str_fail(self): + source_code = """ +def validator(x: Union[int, bytes, str]) -> int: + if isinstance(x, int): + return 5 + elif isinstance(x, bytes): + return 6 + elif isinstance(x, str): + return 7 + return 100 +""" + with self.assertRaises(CompilerError) as ce: + res = eval_uplc_value(source_code, "test") + self.assertIsInstance(ce.exception.orig_err, AssertionError) + + def test_bool_fail(self): + source_code = """ +def validator(x: Union[int, bytes, bool]) -> int: + if isinstance(x, int): + return 5 + elif isinstance(x, bytes): + return 6 + elif isinstance(x, bool): + return 7 + return 100 +""" + with self.assertRaises(CompilerError) as ce: + res = eval_uplc_value(source_code, True) + self.assertIsInstance(ce.exception.orig_err, AssertionError) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 7f05be4a..4c7c3d3d 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -394,6 +394,21 @@ def type_from_annotation(self, ann: expr): ann_types = frozenlist( [self.type_from_annotation(e) for e in ann.slice.elts] ) + # check for unique constr_ids + constr_ids = [ + record.record.constructor + for record in ann_types + if isinstance(record, RecordType) + ] + assert len(constr_ids) == len( + set(constr_ids) + ), f"Duplicate constr_ids for records in Union: " + str( + { + t.record.orig_name: t.record.constructor + for t in ann_types + if isinstance(t, RecordType) + } + ) return union_types(*ann_types) if ann.value.orig_id == "List": ann_type = self.type_from_annotation(ann.slice) @@ -836,6 +851,7 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript: "Dict", "List", ]: + ts.value = ts.typ = self.type_from_annotation(ts) return ts