From bceb6cfbb448f62cd6cfcfc8a592f84c13fedfe5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 19 Jul 2024 20:59:45 +0200 Subject: [PATCH 1/2] Allow passing keywords for constructors --- opshin/tests/test_keywords.py | 78 +++++++++++++++++++++++++++++++---- opshin/type_inference.py | 4 ++ 2 files changed, 73 insertions(+), 9 deletions(-) diff --git a/opshin/tests/test_keywords.py b/opshin/tests/test_keywords.py index 91957733..10c9c989 100644 --- a/opshin/tests/test_keywords.py +++ b/opshin/tests/test_keywords.py @@ -47,8 +47,7 @@ def validator(a: int, b: int, c: int) -> int: ret = eval_uplc_value(source_code, x, y, z) self.assertEqual(ret, (x - y) * z) - @given(x=st.integers(), y=st.integers(), z=st.integers()) - def test_arg_after_keyword_failure(self, x: int, y: int, z: int): + def test_arg_after_keyword_failure(self): source_code = """ def simple_example(x: int, y: int, z:int) -> int: return (x-y)*z @@ -57,10 +56,9 @@ def validator(a: int, b: int, c: int) -> int: return simple_example(x=a, y=b, c) """ with self.assertRaises(Exception): - ret = eval_uplc_value(source_code, x, y, z) + ret = eval_uplc_value(source_code, 1, 2, 3) - @given(x=st.integers(), y=st.integers(), z=st.integers()) - def test_too_many_keywords_failure(self, x: int, y: int, z: int): + def test_too_many_keywords_failure(self): source_code = """ def simple_example(x: int, y: int) -> int: return x-y @@ -69,10 +67,9 @@ def validator(a: int, b: int, c: int) -> int: return simple_example(x=a, y=b, z=c) """ with self.assertRaises(Exception): - ret = eval_uplc_value(source_code, x, y, z) + ret = eval_uplc_value(source_code, 1, 2, 3) - @given(x=st.integers(), y=st.integers(), z=st.integers()) - def test_incorrect_keywords_failure(self, x: int, y: int, z: int): + def test_incorrect_keywords_failure(self): source_code = """ def simple_example(x: int, y: int, z: int) -> int: return (x-y)*z @@ -81,7 +78,7 @@ def validator(a: int, b: int, c: int) -> int: return simple_example(x=a, y=b, k=c) """ with self.assertRaises(Exception): - ret = eval_uplc_value(source_code, x, y, z) + ret = eval_uplc_value(source_code, 1, 2, 3) @given(x=st.integers(), y=st.integers(), z=st.integers()) def test_correct_scope(self, x: int, y: int, z: int): @@ -96,3 +93,66 @@ def validator(a: int, b: int, c: int) -> int: """ ret = eval_uplc_value(source_code, x, y, z) self.assertEqual(ret, (x - z) * y) + + def test_type_mismatch(self): + source_code = """ +def simple_example(x: int, y: int, z: int) -> int: + return x * y + z + +def validator(a: int, b: bytes, c: int) -> int: + return simple_example(x=a, y=b, z=c) +""" + with self.assertRaises(Exception): + ret = eval_uplc_value(source_code, 1, 2, 3) + + @given(x=st.integers()) + def test_class_keywords(self, x: int): + source_code = """ +from opshin.prelude import * + +@dataclass +class A(PlutusData): + x: int + y: int + z: int + + + +def validator(a: int, b: int, c: int) -> int: + return A(x=a, y=b, z=c).x +""" + ret = eval_uplc_value(source_code, x, 2, 3) + self.assertEqual(ret, x) + + @given(x=st.integers()) + def test_class_keywords_reorder(self, x: int): + source_code = """ +from opshin.prelude import * + +@dataclass +class A(PlutusData): + x: int + y: int + z: int + +def validator(a: int, b: int, c: int) -> int: + return A(y=a, z=b, x=c).x +""" + ret = eval_uplc_value(source_code, 1, 2, x) + self.assertEqual(ret, x) + + def test_class_keywords_invalid(self): + source_code = """ +from opshin.prelude import * + +@dataclass +class A(PlutusData): + x: int + y: int + z: int + +def validator(a: int, b: bytes, c: int) -> int: + return A(x=a, y=b, z=c).x +""" + with self.assertRaises(Exception): + ret = eval_uplc_value(source_code, 1, b"2", 3) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 5a10d064..9d8a4820 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -361,6 +361,10 @@ def visit_ClassDef(self, node: ClassDef) -> TypedClassDef: class_record = RecordReader.extract(node, self) typ = RecordType(class_record) self.set_variable_type(node.name, typ) + self.FUNCTION_ARGUMENT_REGISTRY[node.name] = [ + typedarg(arg=field, typ=typ, orig_arg=field) + for field, typ in class_record.fields + ] typed_node = copy(node) typed_node.class_typ = typ return typed_node From 8336050e8b7a989ec10f84787e300510914e795e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 19 Jul 2024 21:12:11 +0200 Subject: [PATCH 2/2] Fix varname --- opshin/type_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 9d8a4820..51b5cc7e 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -362,8 +362,8 @@ def visit_ClassDef(self, node: ClassDef) -> TypedClassDef: typ = RecordType(class_record) self.set_variable_type(node.name, typ) self.FUNCTION_ARGUMENT_REGISTRY[node.name] = [ - typedarg(arg=field, typ=typ, orig_arg=field) - for field, typ in class_record.fields + typedarg(arg=field, typ=field_typ, orig_arg=field) + for field, field_typ in class_record.fields ] typed_node = copy(node) typed_node.class_typ = typ