Skip to content

Commit

Permalink
Union requires unique constr_id and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SCMusson committed Aug 23, 2024
1 parent 812cfb1 commit 22bec8f
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 3 deletions.
71 changes: 68 additions & 3 deletions opshin/tests/test_Unions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import unittest

import hypothesis
from hypothesis import given
from hypothesis import strategies as st
Expand All @@ -10,6 +9,9 @@
hypothesis.settings.load_profile(PLUTUS_VM_PROFILE)

from .test_misc import A
from typing import List, Dict

from ..ledger.api_v2 import *


def to_int(x):
Expand All @@ -19,9 +21,9 @@ def to_int(x):
return 6
elif isinstance(x, bytes):
return 7
elif isinstance(x, list):
elif isinstance(x, List):
return 8
elif isinstance(x, dict):
elif isinstance(x, Dict):
return 9
return False

Expand Down Expand Up @@ -240,3 +242,66 @@ def validator(x: Union[A, int, bytes, List[Anything], Dict[int, bytes]]) -> int:
with self.assertRaises(CompilerError) as ce:
res = eval_uplc_value(source_code, [1, 2, 3])
self.assertIsInstance(ce.exception.orig_err, AssertionError)

def test_same_constructor_fail(self):
@dataclass()
class B(PlutusData):
CONSTR_ID = 0
foo: int

@dataclass()
class C(PlutusData):
CONSTR_ID = 0
foo: int

source_code = """
from dataclasses import dataclass
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
@dataclass()
class B(PlutusData):
CONSTR_ID = 0
foo: int
@dataclass()
class C(PlutusData):
CONSTR_ID = 0
foo: int
def validator(x: Union[B, C]) -> int:
return 100
"""
with self.assertRaises(CompilerError) as ce:
res = eval_uplc_value(source_code, B(0))
self.assertIsInstance(ce.exception.orig_err, AssertionError)

def test_str_fail(self):
source_code = """
def validator(x: Union[int, bytes, str]) -> int:
if isinstance(x, int):
return 5
elif isinstance(x, bytes):
return 6
elif isinstance(x, str):
return 7
return 100
"""
with self.assertRaises(CompilerError) as ce:
res = eval_uplc_value(source_code, "test")
self.assertIsInstance(ce.exception.orig_err, AssertionError)

def test_bool_fail(self):
source_code = """
def validator(x: Union[int, bytes, bool]) -> int:
if isinstance(x, int):
return 5
elif isinstance(x, bytes):
return 6
elif isinstance(x, bool):
return 7
return 100
"""
with self.assertRaises(CompilerError) as ce:
res = eval_uplc_value(source_code, True)
self.assertIsInstance(ce.exception.orig_err, AssertionError)
16 changes: 16 additions & 0 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,21 @@ def type_from_annotation(self, ann: expr):
ann_types = frozenlist(
[self.type_from_annotation(e) for e in ann.slice.elts]
)
# check for unique constr_ids
constr_ids = [
record.record.constructor
for record in ann_types
if isinstance(record, RecordType)
]
assert len(constr_ids) == len(
set(constr_ids)
), f"Duplicate constr_ids for records in Union: " + str(
{
t.record.orig_name: t.record.constructor
for t in ann_types
if isinstance(t, RecordType)
}
)
return union_types(*ann_types)
if ann.value.orig_id == "List":
ann_type = self.type_from_annotation(ann.slice)
Expand Down Expand Up @@ -836,6 +851,7 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript:
"Dict",
"List",
]:

ts.value = ts.typ = self.type_from_annotation(ts)
return ts

Expand Down

0 comments on commit 22bec8f

Please sign in to comment.