Skip to content

Commit

Permalink
Merge pull request #387 from OpShin/feat/allow_class_keywords
Browse files Browse the repository at this point in the history
Allow passing keywords for constructors
  • Loading branch information
nielstron authored Jul 19, 2024
2 parents 8c9b34b + 8336050 commit eda1e3a
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 9 deletions.
78 changes: 69 additions & 9 deletions opshin/tests/test_keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def validator(a: int, b: int, c: int) -> int:
ret = eval_uplc_value(source_code, x, y, z)
self.assertEqual(ret, (x - y) * z)

@given(x=st.integers(), y=st.integers(), z=st.integers())
def test_arg_after_keyword_failure(self, x: int, y: int, z: int):
def test_arg_after_keyword_failure(self):
source_code = """
def simple_example(x: int, y: int, z:int) -> int:
return (x-y)*z
Expand All @@ -57,10 +56,9 @@ def validator(a: int, b: int, c: int) -> int:
return simple_example(x=a, y=b, c)
"""
with self.assertRaises(Exception):
ret = eval_uplc_value(source_code, x, y, z)
ret = eval_uplc_value(source_code, 1, 2, 3)

@given(x=st.integers(), y=st.integers(), z=st.integers())
def test_too_many_keywords_failure(self, x: int, y: int, z: int):
def test_too_many_keywords_failure(self):
source_code = """
def simple_example(x: int, y: int) -> int:
return x-y
Expand All @@ -69,10 +67,9 @@ def validator(a: int, b: int, c: int) -> int:
return simple_example(x=a, y=b, z=c)
"""
with self.assertRaises(Exception):
ret = eval_uplc_value(source_code, x, y, z)
ret = eval_uplc_value(source_code, 1, 2, 3)

@given(x=st.integers(), y=st.integers(), z=st.integers())
def test_incorrect_keywords_failure(self, x: int, y: int, z: int):
def test_incorrect_keywords_failure(self):
source_code = """
def simple_example(x: int, y: int, z: int) -> int:
return (x-y)*z
Expand All @@ -81,7 +78,7 @@ def validator(a: int, b: int, c: int) -> int:
return simple_example(x=a, y=b, k=c)
"""
with self.assertRaises(Exception):
ret = eval_uplc_value(source_code, x, y, z)
ret = eval_uplc_value(source_code, 1, 2, 3)

@given(x=st.integers(), y=st.integers(), z=st.integers())
def test_correct_scope(self, x: int, y: int, z: int):
Expand All @@ -96,3 +93,66 @@ def validator(a: int, b: int, c: int) -> int:
"""
ret = eval_uplc_value(source_code, x, y, z)
self.assertEqual(ret, (x - z) * y)

def test_type_mismatch(self):
source_code = """
def simple_example(x: int, y: int, z: int) -> int:
return x * y + z
def validator(a: int, b: bytes, c: int) -> int:
return simple_example(x=a, y=b, z=c)
"""
with self.assertRaises(Exception):
ret = eval_uplc_value(source_code, 1, 2, 3)

@given(x=st.integers())
def test_class_keywords(self, x: int):
source_code = """
from opshin.prelude import *
@dataclass
class A(PlutusData):
x: int
y: int
z: int
def validator(a: int, b: int, c: int) -> int:
return A(x=a, y=b, z=c).x
"""
ret = eval_uplc_value(source_code, x, 2, 3)
self.assertEqual(ret, x)

@given(x=st.integers())
def test_class_keywords_reorder(self, x: int):
source_code = """
from opshin.prelude import *
@dataclass
class A(PlutusData):
x: int
y: int
z: int
def validator(a: int, b: int, c: int) -> int:
return A(y=a, z=b, x=c).x
"""
ret = eval_uplc_value(source_code, 1, 2, x)
self.assertEqual(ret, x)

def test_class_keywords_invalid(self):
source_code = """
from opshin.prelude import *
@dataclass
class A(PlutusData):
x: int
y: int
z: int
def validator(a: int, b: bytes, c: int) -> int:
return A(x=a, y=b, z=c).x
"""
with self.assertRaises(Exception):
ret = eval_uplc_value(source_code, 1, b"2", 3)
4 changes: 4 additions & 0 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,10 @@ def visit_ClassDef(self, node: ClassDef) -> TypedClassDef:
class_record = RecordReader.extract(node, self)
typ = RecordType(class_record)
self.set_variable_type(node.name, typ)
self.FUNCTION_ARGUMENT_REGISTRY[node.name] = [
typedarg(arg=field, typ=field_typ, orig_arg=field)
for field, field_typ in class_record.fields
]
typed_node = copy(node)
typed_node.class_typ = typ
return typed_node
Expand Down

0 comments on commit eda1e3a

Please sign in to comment.