From bd8a6aed76e0c09c0f338a854baa0b849209bd7a Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 21 Jan 2024 18:09:27 +0200 Subject: [PATCH] deserialize OrionRunner ReturnValue type --- osiris/app.py | 6 ----- osiris/cairo/serde/deserialize.py | 42 +++++++++++++++++++++++++++++- pyproject.toml | 2 +- tests/test_deserialize.py | 43 +++++++++++++++---------------- 4 files changed, 63 insertions(+), 30 deletions(-) diff --git a/osiris/app.py b/osiris/app.py index 1041cc6..9a09862 100644 --- a/osiris/app.py +++ b/osiris/app.py @@ -112,12 +112,6 @@ def deserialize(serialized: str, data_type: str, fp_impl: str = 'FP16x16'): """ typer.echo("🚀 Starting deserialization process...") - try: - serialized = json.loads(serialized) - except json.JSONDecodeError as e: - typer.echo(f"Error: Invalid JSON - {e}") - raise typer.Exit(code=1) from e - deserialized = deserializer(serialized, data_type, fp_impl) typer.echo("✅ Deserialization completed! 🎉") diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index bbf57be..afa7a91 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -1,9 +1,10 @@ +import json import numpy as np from .utils import from_fp -def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'): +def deserializer(serialized: str, data_type: str, fp_impl='FP16x16'): """ Main deserialization function that handles various data types. @@ -12,6 +13,9 @@ def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'): :param fp_impl: The implementation detail, used for fixed-point deserialization. :return: The deserialized data. """ + + serialized = convert_data(serialized) + if data_type == 'unsigned_int': return deserialize_unsigned_int(serialized) elif data_type == 'signed_int': @@ -28,6 +32,8 @@ def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'): return deserialize_tensor_uint(serialized) elif data_type == 'tensor_signed_int': return deserialize_tensor_signed_int(serialized) + elif data_type == 'tensor_fixed_point': + return deserialize_tensor_fixed_point(serialized) # TODO: Support Tuples # elif data_type == 'tensor_fixed_point': # return deserialize_tensor_fixed_point(serialized, fp_impl) @@ -46,6 +52,40 @@ def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'): else: raise ValueError(f"Unknown data type: {data_type}") + +def parse_return_value(return_value): + """ + Parse a ReturnValue dictionary to extract the integer value or recursively parse an array of ReturnValues (cf: OrionRunner ReturnValues). + """ + if 'Int' in return_value: + # Convert hexadecimal string to integer + return int(return_value['Int'], 16) + elif 'Array' in return_value: + # Recursively parse each item in the array + return [parse_return_value(item) for item in return_value['Array']] + else: + raise ValueError("Invalid ReturnValue format") + + +def convert_data(data): + """ + Convert the given JSON-like data structure to the desired format. + """ + parsed_data = json.loads(data) + result = [] + for item in parsed_data: + # Parse each item based on its keys + if 'Array' in item: + # Process array items + result.append(parse_return_value(item)) + elif 'Int' in item: + # Process single int items + result.append(parse_return_value(item)) + else: + raise ValueError("Invalid data format") + return result + + # ================= UNSIGNED INT ================= diff --git a/pyproject.toml b/pyproject.toml index 2174ace..4a9d58c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "giza-osiris" -version = "0.1.7" +version = "0.1.8" description = "Osiris is a Python library designed for efficient data conversion and management, primarily transforming data into Cairo programs" authors = ["Fran Algaba "] readme = "README.md" diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py index b6509dc..858e780 100644 --- a/tests/test_deserialize.py +++ b/tests/test_deserialize.py @@ -6,62 +6,61 @@ def test_deserialize_signed_int(): - serialized = [42, 0] - deserialized = deserialize_signed_int(serialized) + serialized = '[{"Int":"2A"}, {"Int":"0"}]' + deserialized = deserializer(serialized, 'signed_int') assert deserialized == 42 - serialized = [42, 1] - deserialized = deserialize_signed_int(serialized) + serialized = '[{"Int":"2A"}, {"Int":"0x1"}]' + deserialized = deserializer(serialized, 'signed_int') assert deserialized == -42 -def test_deserialize_signed_int(): - serialized = [2780037, 0] - deserialized = deserialize_fixed_point(serialized, 'FP16x16') +def test_deserialize_fp(): + serialized = '[{"Int":"2A6B85"}, {"Int":"0"}]' + deserialized = deserializer(serialized, 'fixed_point', 'FP16x16') assert isclose(deserialized, 42.42, rel_tol=1e-7) - serialized = [2780037, 1] - deserialized = deserialize_fixed_point(serialized, 'FP16x16') + serialized = '[{"Int":"2A6B85"}, {"Int":"1"}]' + deserialized = deserializer(serialized, 'fixed_point', 'FP16x16') assert isclose(deserialized, -42.42, rel_tol=1e-7) def test_deserialize_array_uint(): - serialized = [[1, 2]] - deserialized = deserialize_arr_uint(serialized) + serialized = '[{"Array": [{"Int": "0x1"}, {"Int": "0x2"}]}]' + deserialized = deserializer(serialized, 'arr_uint') assert np.array_equal(deserialized, np.array([1, 2], dtype=np.int64)) def test_deserialize_array_signed_int(): - serialized = [[42, 0, 42, 1]] - deserialized = deserialize_arr_signed_int(serialized) + serialized = '[{"Array": [{"Int": "2A"}, {"Int": "0"}, {"Int": "2A"}, {"Int": "0x1"}]}]' + deserialized = deserializer(serialized, 'arr_signed_int') assert np.array_equal(deserialized, np.array([42, -42], dtype=np.int64)) def test_deserialize_arr_fixed_point(): - serialized = [[2780037, 0, 2780037, 1]] - deserialized = deserialize_arr_fixed_point(serialized) + serialized = '[{"Array": [{"Int": "2A6B85"}, {"Int": "0"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]' + deserialized = deserializer(serialized, 'arr_fixed_point') expected = np.array([42.42, -42.42], dtype=np.float64) assert np.all(np.isclose(deserialized, expected, atol=1e-7)) def test_deserialize_tensor_uint(): - serialized = [[2, 2], [1, 2, 3, 4]] - deserialized = deserialize_tensor_uint(serialized) + serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "0x1"}, {"Int": "0x2"}, {"Int": "0x3"}, {"Int": "0x4"}]}]' + deserialized = deserializer(serialized, 'tensor_uint') assert np.array_equal(deserialized, np.array( ([1, 2], [3, 4]), dtype=np.int64)) def test_deserialize_tensor_signed_int(): - serialized_tensor = [[2, 2], [42, 0, 42, 0, 42, 1, 42, 1]] - deserialized = deserialize_tensor_signed_int(serialized_tensor) + serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A"}, {"Int": "0x0"}, {"Int": "2A"}, {"Int": "0x0"}, {"Int": "2A"}, {"Int": "0x1"}, {"Int": "2A"}, {"Int": "0x1"}]}]' + deserialized = deserializer(serialized, 'tensor_signed_int') assert np.array_equal(deserialized, np.array([[42, 42], [-42, -42]])) def test_deserialize_tensor_fixed_point(): - serialized_tensor = [[2, 2], [2780037, - 0, 2780037, 0, 2780037, 1, 2780037, 1]] + serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]' expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]]) - deserialized = deserialize_tensor_fixed_point(serialized_tensor) + deserialized = deserializer(serialized, 'tensor_fixed_point') assert np.allclose(deserialized, expected_array, atol=1e-7)