diff --git a/opshin/compiler.py b/opshin/compiler.py
index e3e00222..ac8b0e27 100644
--- a/opshin/compiler.py
+++ b/opshin/compiler.py
@@ -422,7 +422,14 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
         bind_self = node.func.typ.typ.bind_self
         bound_vs = sorted(list(node.func.typ.typ.bound_vars.keys()))
         args = []
-        for a, t in zip(node.args, node.func.typ.typ.argtyps):
+        for i, (a, t) in enumerate(zip(node.args, node.func.typ.typ.argtyps)):
+            # now impl_from_args has been chosen, skip type arg
+            if (
+                hasattr(node.func, "orig_id")
+                and node.func.orig_id == "isinstance"
+                and i == 1
+            ):
+                continue
             assert isinstance(t, InstanceType)
             # pass in all arguments evaluated with the statemonad
             a_int = self.visit(a)
diff --git a/opshin/fun_impls.py b/opshin/fun_impls.py
index b53b8a84..7a65dee8 100644
--- a/opshin/fun_impls.py
+++ b/opshin/fun_impls.py
@@ -87,6 +87,88 @@ def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
         return print
 
 
+class IsinstanceImpl(PolymorphicFunction):
+    """isinstance(object, type) builtin, specialised on the checked type."""
+
+    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
+        """Type the call as (object, type) -> bool; exactly two arguments."""
+        assert (
+            len(args) == 2
+        ), f"isinstance takes two arguments [object, type], but {len(args)} were given"
+        return FunctionType(args, BoolInstanceType)
+
+    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
+        """Build the check as a lambda over the object only.
+
+        The five ``plt.ChooseData`` results are, in order, the values for
+        record, dict, list, int and bytes data; exactly one is True.
+        """
+        if isinstance(args[1], IntegerType):
+            return OLambda(
+                ["x"],
+                plt.ChooseData(
+                    OVar("x"),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                    plt.Bool(True),
+                    plt.Bool(False),
+                ),
+            )
+        elif isinstance(args[1], ByteStringType):
+            return OLambda(
+                ["x"],
+                plt.ChooseData(
+                    OVar("x"),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                    plt.Bool(True),
+                ),
+            )
+        elif isinstance(args[1], RecordType):
+            return OLambda(
+                ["x"],
+                plt.ChooseData(
+                    OVar("x"),
+                    plt.Bool(True),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                ),
+            )
+        elif isinstance(args[1], ListType):
+            return OLambda(
+                ["x"],
+                plt.ChooseData(
+                    OVar("x"),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                    plt.Bool(True),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                ),
+            )
+        elif isinstance(args[1], DictType):
+            return OLambda(
+                ["x"],
+                plt.ChooseData(
+                    OVar("x"),
+                    plt.Bool(False),
+                    plt.Bool(True),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                    plt.Bool(False),
+                ),
+            )
+        else:
+            raise NotImplementedError(
+                "Only isinstance for int, bytes, List, Dict and Plutus Dataclass types are supported"
+            )
+
+
 class PythonBuiltIn(Enum):
     all = OLambda(
         ["xs"],
@@ -437,6 +519,7 @@ class PythonBuiltIn(Enum):
             OVar("xs"), plt.BuiltIn(uplc.BuiltInFun.AddInteger), plt.Integer(0)
         ),
     )
+    isinstance = "isinstance"
 
 
 PythonBuiltInTypes = {
@@ -510,4 +593,5 @@ class PythonBuiltIn(Enum):
             IntegerInstanceType,
         )
     ),
+    PythonBuiltIn.isinstance: InstanceType(PolymorphicFunctionType(IsinstanceImpl())),
 }
diff --git a/opshin/tests/test_Unions.py b/opshin/tests/test_Unions.py
new file mode 100644
index 00000000..b9632e2e
--- /dev/null
+++ b/opshin/tests/test_Unions.py
@@ -0,0 +1,307 @@
+import unittest
+import hypothesis
+from hypothesis import given
+from hypothesis import strategies as st
+from .utils import eval_uplc_value
+from . import PLUTUS_VM_PROFILE
+from opshin.util import CompilerError
+
+hypothesis.settings.load_profile(PLUTUS_VM_PROFILE)
+
+from .test_misc import A
+from typing import List, Dict
+
+from ..ledger.api_v2 import *
+
+
+def to_int(x):
+    if isinstance(x, A):
+        return 5
+    elif isinstance(x, int):
+        return 6
+    elif isinstance(x, bytes):
+        return 7
+    elif isinstance(x, List):
+        return 8
+    elif isinstance(x, Dict):
+        return 9
+    return False
+
+
+union_types = st.sampled_from([A(0), 10, b"foo", [1, 2, 3, 4, 5], {1: 2, 2: 3}])
+
+
+class Union_tests(unittest.TestCase):
+    @hypothesis.given(union_types)
+    def test_Union_types(self, x):
+        source_code = """
+from dataclasses import dataclass
+from typing import Dict, List, Union
+from pycardano import Datum as Anything, PlutusData
+
+@dataclass()
+class A(PlutusData):
+    CONSTR_ID = 0
+    foo: int
+
+def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int:
+    k: int = 0
+    if isinstance(x, A):
+        k = 5
+    elif isinstance(x, bytes):
+        k = 7
+    elif isinstance(x, int):
+        k = 6
+    elif isinstance(x, List):
+        k = 8
+    elif isinstance(x, Dict):
+        k = 9
+    return k
+"""
+        res = eval_uplc_value(source_code, x)
+        self.assertEqual(res, to_int(x))
+
+    @hypothesis.given(union_types)
+    def test_Union_types_different_order(self, x):
+        source_code = """
+from dataclasses import dataclass
+from typing import Dict, List, Union
+from pycardano import Datum as Anything, PlutusData
+
+@dataclass()
+class A(PlutusData):
+    CONSTR_ID = 0
+    foo: int
+
+def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int:
+    k: int = 1
+    if isinstance(x, int):
+        k = 6
+    elif isinstance(x, Dict):
+        k = 9
+    elif isinstance(x, bytes):
+        k = 7
+    elif isinstance(x, A):
+        k = 5
+    elif isinstance(x, List):
+        k = 8
+    return k
+"""
+        res = eval_uplc_value(source_code, x)
+        self.assertEqual(res, to_int(x))
+
+    # isinstance(x, int) on a Union[A, bytes] must be rejected by the compiler
+    def test_incorrect_Union_types(
+        self,
+    ):
+        source_code = """
+from dataclasses import dataclass
+from typing import Dict, List, Union
+from pycardano import Datum as Anything, PlutusData
+
+@dataclass()
+class A(PlutusData):
+    CONSTR_ID = 0
+    foo: int
+
+def validator(x: Union[A, bytes,]) -> int:
+    k: int = 0
+    if isinstance(x, A):
+        k = 5
+    elif isinstance(x, bytes):
+        k = 7
+    elif isinstance(x, int):
+        k = 6
+    return k
+"""
+        with self.assertRaises(CompilerError):
+            res = eval_uplc_value(source_code, 2)
+
+    def test_isinstance_Dict_subscript_fail(
+        self,
+    ):
+        source_code = """
+from dataclasses import dataclass
+from typing import Dict, List, Union
+from pycardano import Datum as Anything, PlutusData
+
+@dataclass()
+class A(PlutusData):
+    CONSTR_ID = 0
+    foo: int
+
+def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int:
+    k: int = 1
+    if isinstance(x, int):
+        k = 6
+    elif isinstance(x, Dict[Anything, Anything]):
+        k = 9
+    elif isinstance(x, bytes):
+        k = 7
+    elif isinstance(x, A):
+        k = 5
+    elif isinstance(x, List):
+        k = 8
+    return k
+"""
+        with self.assertRaises(CompilerError) as ce:
+            res = eval_uplc_value(source_code, [1, 2, 3])
+        self.assertIsInstance(ce.exception.orig_err, TypeError)
+
+    def test_isinstance_List_subscript_fail(
+        self,
+    ):
+        source_code = """
+from dataclasses import dataclass
+from typing import Dict, List, Union
+from pycardano import Datum as Anything, PlutusData
+
+@dataclass()
+class A(PlutusData):
+    CONSTR_ID = 0
+    foo: int
+
+def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int:
+    k: int = 1
+    if isinstance(x, int):
+        k = 6
+    elif isinstance(x, Dict):
+        k = 9
+    elif isinstance(x, bytes):
+        k = 7
+    elif isinstance(x, A):
+        k = 5
+    elif isinstance(x, List[Anything]):
+        k = 8
+    return k
+"""
+        with self.assertRaises(CompilerError) as ce:
+            res = eval_uplc_value(source_code, [1, 2, 3])
+        self.assertIsInstance(ce.exception.orig_err, TypeError)
+
+    def test_Union_list_is_anything(
+        self,
+    ):
+        """Test fails if List in union is anything other than List[Anything]"""
+        source_code = """
+from dataclasses import dataclass
+from typing import Dict, List, Union
+from pycardano import Datum as Anything, PlutusData
+
+@dataclass()
+class A(PlutusData):
+    CONSTR_ID = 0
+    foo: int
+
+def validator(x: Union[A, int, bytes, List[int], Dict[Anything, Anything]]) -> int:
+    k: int = 1
+    if isinstance(x, int):
+        k = 6
+    elif isinstance(x, Dict):
+        k = 9
+    elif isinstance(x, bytes):
+        k = 7
+    elif isinstance(x, A):
+        k = 5
+    elif isinstance(x, List):
+        k = 8
+    return k
+"""
+        with self.assertRaises(CompilerError) as ce:
+            res = eval_uplc_value(source_code, [1, 2, 3])
+        self.assertIsInstance(ce.exception.orig_err, AssertionError)
+
+    def test_Union_dict_is_anything(
+        self,
+    ):
+        """Test fails if Dict in union is anything other than Dict[Anything, Anything]"""
+        source_code = """
+from dataclasses import dataclass
+from typing import Dict, List, Union
+from pycardano import Datum as Anything, PlutusData
+
+@dataclass()
+class A(PlutusData):
+    CONSTR_ID = 0
+    foo: int
+
+def validator(x: Union[A, int, bytes, List[Anything], Dict[int, bytes]]) -> int:
+    k: int = 1
+    if isinstance(x, int):
+        k = 6
+    elif isinstance(x, Dict):
+        k = 9
+    elif isinstance(x, bytes):
+        k = 7
+    elif isinstance(x, A):
+        k = 5
+    elif isinstance(x, List):
+        k = 8
+    return k
+"""
+        with self.assertRaises(CompilerError) as ce:
+            res = eval_uplc_value(source_code, [1, 2, 3])
+        self.assertIsInstance(ce.exception.orig_err, AssertionError)
+
+    def test_same_constructor_fail(self):
+        @dataclass()
+        class B(PlutusData):
+            CONSTR_ID = 0
+            foo: int
+
+        @dataclass()
+        class C(PlutusData):
+            CONSTR_ID = 0
+            foo: int
+
+        source_code = """
+from dataclasses import dataclass
+from typing import Dict, List, Union
+from pycardano import Datum as Anything, PlutusData
+
+@dataclass()
+class B(PlutusData):
+    CONSTR_ID = 0
+    foo: int
+
+@dataclass()
+class C(PlutusData):
+    CONSTR_ID = 0
+    foo: int
+
+def validator(x: Union[B, C]) -> int:
+    return 100
+"""
+        with self.assertRaises(CompilerError) as ce:
+            res = eval_uplc_value(source_code, B(0))
+        self.assertIsInstance(ce.exception.orig_err, AssertionError)
+
+    def test_str_fail(self):
+        source_code = """
+def validator(x: Union[int, bytes, str]) -> int:
+    if isinstance(x, int):
+        return 5
+    elif isinstance(x, bytes):
+        return 6
+    elif isinstance(x, str):
+        return 7
+    return 100
+"""
+        with self.assertRaises(CompilerError) as ce:
+            res = eval_uplc_value(source_code, "test")
+        self.assertIsInstance(ce.exception.orig_err, AssertionError)
+
+    def test_bool_fail(self):
+        source_code = """
+def validator(x: Union[int, bytes, bool]) -> int:
+    if isinstance(x, int):
+        return 5
+    elif isinstance(x, bytes):
+        return 6
+    elif isinstance(x, bool):
+        return 7
+    return 100
+"""
+        with self.assertRaises(CompilerError) as ce:
+            res = eval_uplc_value(source_code, True)
+        self.assertIsInstance(ce.exception.orig_err, AssertionError)
diff --git a/opshin/type_inference.py b/opshin/type_inference.py
index 9edd745c..4c7c3d3d 100644
--- a/opshin/type_inference.py
+++ b/opshin/type_inference.py
@@ -135,13 +135,17 @@ def union_types(*ts: Type):
     for e in ts:
         for e2 in e.typs:
             assert isinstance(
-                e2, RecordType
+                e2, (RecordType, IntegerType, ByteStringType, ListType, DictType)
             ), f"Union must combine multiple PlutusData classes but found {e2.__class__.__name__}"
     union_set = OrderedSet()
     for t in ts:
         union_set.update(t.typs)
     assert distinct(
-        [e.record.constructor for e in union_set]
+        [
+            e.record.constructor
+            for e in union_set
+            if not isinstance(e, (ByteStringType, IntegerType, ListType, DictType))
+        ]
     ), "Union must combine PlutusData classes with unique constructors"
     return UnionType(frozenlist(union_set))
 
@@ -183,8 +187,8 @@ def visit_Call(self, node: Call) -> TypeMapPair:
                 "Target 0 of an isinstance cast must be a variable name for type casting to work. You can still proceed, but the inferred type of the isinstance cast will not be accurate."
             )
             return ({}, {})
-        assert isinstance(
-            node.args[1], Name
+        assert isinstance(node.args[1], Name) or isinstance(
+            node.args[1].typ, (ListType, DictType)
         ), "Target 1 of an isinstance cast must be a class name"
         target_class: RecordType = node.args[1].typ
         inst = node.args[0]
@@ -192,7 +196,7 @@ def visit_Call(self, node: Call) -> TypeMapPair:
         assert isinstance(
             inst_class, InstanceType
         ), "Can only cast instances, not classes"
-        assert isinstance(target_class, RecordType), "Can only cast to PlutusData"
+        # target_class may be int, bytes, List or Dict as well as a PlutusData record
         if isinstance(inst_class.typ, UnionType):
             assert (
                 target_class in inst_class.typ.typs
@@ -376,9 +380,35 @@ def type_from_annotation(self, ann: expr):
                 ann.value, Name
             ), "Only Union, Dict and List are allowed as Generic types"
             if ann.value.orig_id == "Union":
+                for elt in ann.slice.elts:
+                    if isinstance(elt, Subscript) and elt.value.id == "List":
+                        assert (
+                            isinstance(elt.slice, Name)
+                            and elt.slice.orig_id == "Anything"
+                        ), f"Only List[Anything] is supported in Unions. Received List[{getattr(elt.slice, 'orig_id', '...')}]."
+                    if isinstance(elt, Subscript) and elt.value.id == "Dict":
+                        assert all(
+                            isinstance(e, Name) and e.orig_id == "Anything"
+                            for e in elt.slice.elts
+                        ), f"Only Dict[Anything, Anything] or Dict is supported in Unions. Received Dict[{', '.join(str(getattr(e, 'orig_id', '...')) for e in elt.slice.elts)}]."
                 ann_types = frozenlist(
                     [self.type_from_annotation(e) for e in ann.slice.elts]
                 )
+                # check for unique constr_ids
+                constr_ids = [
+                    record.record.constructor
+                    for record in ann_types
+                    if isinstance(record, RecordType)
+                ]
+                assert len(constr_ids) == len(
+                    set(constr_ids)
+                ), f"Duplicate constr_ids for records in Union: " + str(
+                    {
+                        t.record.orig_name: t.record.constructor
+                        for t in ann_types
+                        if isinstance(t, RecordType)
+                    }
+                )
                 return union_types(*ann_types)
             if ann.value.orig_id == "List":
                 ann_type = self.type_from_annotation(ann.slice)
@@ -441,7 +471,10 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST:
                 assert (
                     func.returns is None or func.returns.id != n.name
                 ), "Invalid Python, class name is undefined at this stage"
-                func.args.args[0].annotation = ast.Name(id=n.name, ctx=ast.Load())
+                ann = ast.Name(id=n.name, ctx=ast.Load())
+                custom_fix_missing_locations(ann, attribute.args.args[0])
+                ann.orig_id = attribute.args.args[0].orig_arg
+                func.args.args[0].annotation = ann
                 additional_functions.append(func)
             n.body = non_method_attributes
         if additional_functions:
@@ -648,8 +681,14 @@ def visit_For(self, node: For) -> TypedFor:
 
     def visit_Name(self, node: Name) -> TypedName:
         tn = copy(node)
-        # Make sure that the rhs of an assign is evaluated first
-        tn.typ = self.variable_type(node.id)
+        # typing List and Dict are not present in scope we don't want to call variable_type
+        if node.orig_id == "List":
+            tn.typ = ListType(InstanceType(AnyType()))
+        elif node.orig_id == "Dict":
+            tn.typ = DictType(InstanceType(AnyType()), InstanceType(AnyType()))
+        else:
+            # Make sure that the rhs of an assign is evaluated first
+            tn.typ = self.variable_type(node.id)
         return tn
 
     def visit_keyword(self, node: keyword) -> Typedkeyword:
@@ -705,10 +744,15 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
 
         self.enter_scope()
         tfd.args = self.visit(node.args)
+
         functyp = FunctionType(
             frozenlist([t.typ for t in tfd.args.args]),
             InstanceType(self.type_from_annotation(tfd.returns)),
-            bound_vars={v: self.variable_type(v) for v in externally_bound_vars(node)},
+            bound_vars={
+                v: self.variable_type(v)
+                for v in externally_bound_vars(node)
+                if v not in ["List", "Dict"]
+            },
             bind_self=node.name if node.name in read_vars(node) else None,
         )
         tfd.typ = InstanceType(functyp)
@@ -807,6 +851,7 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript:
             "Dict",
             "List",
         ]:
+
             ts.value = ts.typ = self.type_from_annotation(ts)
             return ts
 
@@ -932,7 +977,25 @@ def visit_Call(self, node: Call) -> TypedCall:
         tc.args = [self.visit(a) for a in node.args]
 
         # might be isinstance
-        if isinstance(tc.func, Name) and tc.func.orig_id == "isinstance":
+        # Subscripts are not allowed in isinstance calls
+        if (
+            isinstance(tc.func, Name)
+            and tc.func.orig_id == "isinstance"
+            and isinstance(tc.args[1], Subscript)
+        ):
+            raise TypeError(
+                "Subscripted generics cannot be used with class and instance checks"
+            )
+
+        # Need to handle the presence of PlutusData classes
+        if (
+            isinstance(tc.func, Name)
+            and tc.func.orig_id == "isinstance"
+            and not isinstance(
+                tc.args[1].typ, (ByteStringType, IntegerType, ListType, DictType)
+            )
+            and not hasattr(node, "skip_next")
+        ):
             target_class = tc.args[1].typ
             if (
                 isinstance(tc.args[0].typ, InstanceType)
@@ -952,7 +1015,17 @@ def visit_Call(self, node: Call) -> TypedCall:
             ntc = self.visit(ntc)
             ntc.typ = BoolInstanceType
             ntc.typechecks = TypeCheckVisitor(self.allow_isinstance_anything).visit(tc)
-            return ntc
+            if isinstance(tc.args[0].typ.typ, UnionType) and any(
+                [
+                    isinstance(a, (IntegerType, ByteStringType, ListType, DictType))
+                    for a in tc.args[0].typ.typ.typs
+                ]
+            ):
+                n = copy(node)
+                n.skip_next = True
+                return self.visit(BoolOp(And(), [n, ntc]))
+            else:
+                return ntc
         try:
             tc.func = self.visit(node.func)
         except Exception as e:
@@ -966,9 +1039,12 @@ def visit_Call(self, node: Call) -> TypedCall:
                 except Exception:
                     # if this fails raise original error
                     raise e
-                tc.func = self.visit(ast.Name(id=method_name, ctx=ast.Load()))
+                n = ast.Name(id=method_name, ctx=ast.Load())
+                n.orig_id = node.func.attr
+                tc.func = self.visit(n)
                 tc.func.orig_id = node.func.attr
                 c_self = ast.Name(id=node.func.value.id, ctx=ast.Load())
+                c_self.orig_id = None
                 tc.args.insert(0, self.visit(c_self))
 
         # might be a class