Skip to content

Commit

Permalink
Tests and Dict/List only in isinstance
Browse files Browse the repository at this point in the history
  • Loading branch information
SCMusson committed Aug 18, 2024
1 parent e12dd63 commit 812cfb1
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 4 deletions.
242 changes: 242 additions & 0 deletions opshin/tests/test_Unions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
import unittest

import hypothesis
from hypothesis import given
from hypothesis import strategies as st
from .utils import eval_uplc_value
from . import PLUTUS_VM_PROFILE
from opshin.util import CompilerError

hypothesis.settings.load_profile(PLUTUS_VM_PROFILE)

from .test_misc import A


def to_int(x):
if isinstance(x, A):
return 5
elif isinstance(x, int):
return 6
elif isinstance(x, bytes):
return 7
elif isinstance(x, list):
return 8
elif isinstance(x, dict):
return 9
return False


union_types = st.sampled_from([A(0), 10, b"foo", [1, 2, 3, 4, 5], {1: 2, 2: 3}])


class Union_tests(unittest.TestCase):
@hypothesis.given(union_types)
def test_Union_types(self, x):
source_code = """
from dataclasses import dataclass
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int:
k: int = 0
if isinstance(x, A):
k = 5
elif isinstance(x, bytes):
k = 7
elif isinstance(x, int):
k = 6
elif isinstance(x, List):
k = 8
elif isinstance(x, Dict):
k = 9
return k
"""
res = eval_uplc_value(source_code, x)
self.assertEqual(res, to_int(x))

@hypothesis.given(union_types)
def test_Union_types_different_order(self, x):
source_code = """
from dataclasses import dataclass
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int:
k: int = 1
if isinstance(x, int):
k = 6
elif isinstance(x, Dict):
k = 9
elif isinstance(x, bytes):
k = 7
elif isinstance(x, A):
k = 5
elif isinstance(x, List):
k = 8
return k
"""
res = eval_uplc_value(source_code, x)
self.assertEqual(res, to_int(x))

@unittest.expectedFailure
def test_incorrect_Union_types(
self,
):
source_code = """
from dataclasses import dataclass
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
def validator(x: Union[A, bytes,]) -> int:
k: int = 0
if isinstance(x, A):
k = 5
elif isinstance(x, bytes):
k = 7
elif isinstance(x, int):
k = 6
return k
"""
with self.AssertRaises(CompilerError):
res = eval_uplc_value(source_code, 2)

def test_isinstance_Dict_subscript_fail(
self,
):
source_code = """
from dataclasses import dataclass
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int:
k: int = 1
if isinstance(x, int):
k = 6
elif isinstance(x, Dict[Anything, Anything]):
k = 9
elif isinstance(x, bytes):
k = 7
elif isinstance(x, A):
k = 5
elif isinstance(x, List):
k = 8
return k
"""
with self.assertRaises(CompilerError) as ce:
res = eval_uplc_value(source_code, [1, 2, 3])
self.assertIsInstance(ce.exception.orig_err, TypeError)

def test_isinstance_List_subscript_fail(
self,
):
source_code = """
from dataclasses import dataclass
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
def validator(x: Union[A, int, bytes, List[Anything], Dict[Anything, Anything]]) -> int:
k: int = 1
if isinstance(x, int):
k = 6
elif isinstance(x, Dict):
k = 9
elif isinstance(x, bytes):
k = 7
elif isinstance(x, A):
k = 5
elif isinstance(x, List[Anything]):
k = 8
return k
"""
with self.assertRaises(CompilerError) as ce:
res = eval_uplc_value(source_code, [1, 2, 3])
self.assertIsInstance(ce.exception.orig_err, TypeError)

def test_Union_list_is_anything(
self,
):
"""Test fails if List in union is anything other than List[Anything]"""
source_code = """
from dataclasses import dataclass
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
def validator(x: Union[A, int, bytes, List[int], Dict[Anything, Anything]]) -> int:
k: int = 1
if isinstance(x, int):
k = 6
elif isinstance(x, Dict):
k = 9
elif isinstance(x, bytes):
k = 7
elif isinstance(x, A):
k = 5
elif isinstance(x, List):
k = 8
return k
"""
with self.assertRaises(CompilerError) as ce:
res = eval_uplc_value(source_code, [1, 2, 3])
self.assertIsInstance(ce.exception.orig_err, AssertionError)

def test_Union_dict_is_anything(
self,
):
"""Test fails if Dict in union is anything other than Dict[Anything, Anything]"""
source_code = """
from dataclasses import dataclass
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
def validator(x: Union[A, int, bytes, List[Anything], Dict[int, bytes]]) -> int:
k: int = 1
if isinstance(x, int):
k = 6
elif isinstance(x, Dict):
k = 9
elif isinstance(x, bytes):
k = 7
elif isinstance(x, A):
k = 5
elif isinstance(x, List):
k = 8
return k
"""
with self.assertRaises(CompilerError) as ce:
res = eval_uplc_value(source_code, [1, 2, 3])
self.assertIsInstance(ce.exception.orig_err, AssertionError)
31 changes: 27 additions & 4 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,10 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST:
assert (
func.returns is None or func.returns.id != n.name
), "Invalid Python, class name is undefined at this stage"
func.args.args[0].annotation = ast.Name(id=n.name, ctx=ast.Load())
ann = ast.Name(id=n.name, ctx=ast.Load())
custom_fix_missing_locations(ann, attribute.args.args[0])
ann.orig_id = attribute.args.args[0].orig_arg
func.args.args[0].annotation = ann
additional_functions.append(func)
n.body = non_method_attributes
if additional_functions:
Expand Down Expand Up @@ -663,8 +666,14 @@ def visit_For(self, node: For) -> TypedFor:

def visit_Name(self, node: Name) -> TypedName:
tn = copy(node)
# Make sure that the rhs of an assign is evaluated first
tn.typ = self.variable_type(node.id)
# typing List and Dict are not present in scope we don't want to call variable_type
if node.orig_id == "List":
tn.typ = ListType(InstanceType(AnyType()))
elif node.orig_id == "Dict":
tn.typ = DictType(InstanceType(AnyType()), InstanceType(AnyType()))
else:
# Make sure that the rhs of an assign is evaluated first
tn.typ = self.variable_type(node.id)
return tn

def visit_keyword(self, node: keyword) -> Typedkeyword:
Expand Down Expand Up @@ -952,6 +961,17 @@ def visit_Call(self, node: Call) -> TypedCall:
tc.args = [self.visit(a) for a in node.args]

# might be isinstance
# Subscripts are not allowed in isinstance calls
if (
isinstance(tc.func, Name)
and tc.func.orig_id == "isinstance"
and isinstance(tc.args[1], Subscript)
):
raise TypeError(
"Subscripted generics cannot be used with class and instance checks"
)

# Need to handle the presence of PlutusData classes
if (
isinstance(tc.func, Name)
and tc.func.orig_id == "isinstance"
Expand Down Expand Up @@ -1003,9 +1023,12 @@ def visit_Call(self, node: Call) -> TypedCall:
except Exception:
# if this fails raise original error
raise e
tc.func = self.visit(ast.Name(id=method_name, ctx=ast.Load()))
n = ast.Name(id=method_name, ctx=ast.Load())
n.orig_id = node.func.attr
tc.func = self.visit(n)
tc.func.orig_id = node.func.attr
c_self = ast.Name(id=node.func.value.id, ctx=ast.Load())
c_self.orig_id = None
tc.args.insert(0, self.visit(c_self))

# might be a class
Expand Down

0 comments on commit 812cfb1

Please sign in to comment.