Skip to content

Commit

Permalink
deserialize OrionRunner ReturnValue type
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Jan 21, 2024
1 parent d4a241e commit bd8a6ae
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 30 deletions.
6 changes: 0 additions & 6 deletions osiris/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,6 @@ def deserialize(serialized: str, data_type: str, fp_impl: str = 'FP16x16'):
"""
typer.echo("🚀 Starting deserialization process...")

try:
serialized = json.loads(serialized)
except json.JSONDecodeError as e:
typer.echo(f"Error: Invalid JSON - {e}")
raise typer.Exit(code=1) from e

deserialized = deserializer(serialized, data_type, fp_impl)
typer.echo("✅ Deserialization completed! 🎉")

Expand Down
42 changes: 41 additions & 1 deletion osiris/cairo/serde/deserialize.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import numpy as np

from .utils import from_fp


def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'):
def deserializer(serialized: str, data_type: str, fp_impl='FP16x16'):
"""
Main deserialization function that handles various data types.
Expand All @@ -12,6 +13,9 @@ def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'):
:param fp_impl: The implementation detail, used for fixed-point deserialization.
:return: The deserialized data.
"""

serialized = convert_data(serialized)

if data_type == 'unsigned_int':
return deserialize_unsigned_int(serialized)
elif data_type == 'signed_int':
Expand All @@ -28,6 +32,8 @@ def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'):
return deserialize_tensor_uint(serialized)
elif data_type == 'tensor_signed_int':
return deserialize_tensor_signed_int(serialized)
elif data_type == 'tensor_fixed_point':
return deserialize_tensor_fixed_point(serialized)
# TODO: Support Tuples
# elif data_type == 'tensor_fixed_point':
# return deserialize_tensor_fixed_point(serialized, fp_impl)
Expand All @@ -46,6 +52,40 @@ def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'):
else:
raise ValueError(f"Unknown data type: {data_type}")


def parse_return_value(return_value):
"""
Parse a ReturnValue dictionary to extract the integer value or recursively parse an array of ReturnValues (cf: OrionRunner ReturnValues).
"""
if 'Int' in return_value:
# Convert hexadecimal string to integer
return int(return_value['Int'], 16)
elif 'Array' in return_value:
# Recursively parse each item in the array
return [parse_return_value(item) for item in return_value['Array']]
else:
raise ValueError("Invalid ReturnValue format")


def convert_data(data):
"""
Convert the given JSON-like data structure to the desired format.
"""
parsed_data = json.loads(data)
result = []
for item in parsed_data:
# Parse each item based on its keys
if 'Array' in item:
# Process array items
result.append(parse_return_value(item))
elif 'Int' in item:
# Process single int items
result.append(parse_return_value(item))
else:
raise ValueError("Invalid data format")
return result


# ================= UNSIGNED INT =================


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "giza-osiris"
version = "0.1.7"
version = "0.1.8"
description = "Osiris is a Python library designed for efficient data conversion and management, primarily transforming data into Cairo programs"
authors = ["Fran Algaba <fran@gizatech.xyz>"]
readme = "README.md"
Expand Down
43 changes: 21 additions & 22 deletions tests/test_deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,62 +6,61 @@


def test_deserialize_signed_int():
serialized = [42, 0]
deserialized = deserialize_signed_int(serialized)
serialized = '[{"Int":"2A"}, {"Int":"0"}]'
deserialized = deserializer(serialized, 'signed_int')
assert deserialized == 42

serialized = [42, 1]
deserialized = deserialize_signed_int(serialized)
serialized = '[{"Int":"2A"}, {"Int":"0x1"}]'
deserialized = deserializer(serialized, 'signed_int')
assert deserialized == -42


def test_deserialize_signed_int():
serialized = [2780037, 0]
deserialized = deserialize_fixed_point(serialized, 'FP16x16')
def test_deserialize_fp():
serialized = '[{"Int":"2A6B85"}, {"Int":"0"}]'
deserialized = deserializer(serialized, 'fixed_point', 'FP16x16')
assert isclose(deserialized, 42.42, rel_tol=1e-7)

serialized = [2780037, 1]
deserialized = deserialize_fixed_point(serialized, 'FP16x16')
serialized = '[{"Int":"2A6B85"}, {"Int":"1"}]'
deserialized = deserializer(serialized, 'fixed_point', 'FP16x16')
assert isclose(deserialized, -42.42, rel_tol=1e-7)


def test_deserialize_array_uint():
serialized = [[1, 2]]
deserialized = deserialize_arr_uint(serialized)
serialized = '[{"Array": [{"Int": "0x1"}, {"Int": "0x2"}]}]'
deserialized = deserializer(serialized, 'arr_uint')
assert np.array_equal(deserialized, np.array([1, 2], dtype=np.int64))


def test_deserialize_array_signed_int():
serialized = [[42, 0, 42, 1]]
deserialized = deserialize_arr_signed_int(serialized)
serialized = '[{"Array": [{"Int": "2A"}, {"Int": "0"}, {"Int": "2A"}, {"Int": "0x1"}]}]'
deserialized = deserializer(serialized, 'arr_signed_int')
assert np.array_equal(deserialized, np.array([42, -42], dtype=np.int64))


def test_deserialize_arr_fixed_point():
serialized = [[2780037, 0, 2780037, 1]]
deserialized = deserialize_arr_fixed_point(serialized)
serialized = '[{"Array": [{"Int": "2A6B85"}, {"Int": "0"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]'
deserialized = deserializer(serialized, 'arr_fixed_point')
expected = np.array([42.42, -42.42], dtype=np.float64)
assert np.all(np.isclose(deserialized, expected, atol=1e-7))


def test_deserialize_tensor_uint():
serialized = [[2, 2], [1, 2, 3, 4]]
deserialized = deserialize_tensor_uint(serialized)
serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "0x1"}, {"Int": "0x2"}, {"Int": "0x3"}, {"Int": "0x4"}]}]'
deserialized = deserializer(serialized, 'tensor_uint')
assert np.array_equal(deserialized, np.array(
([1, 2], [3, 4]), dtype=np.int64))


def test_deserialize_tensor_signed_int():
serialized_tensor = [[2, 2], [42, 0, 42, 0, 42, 1, 42, 1]]
deserialized = deserialize_tensor_signed_int(serialized_tensor)
serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A"}, {"Int": "0x0"}, {"Int": "2A"}, {"Int": "0x0"}, {"Int": "2A"}, {"Int": "0x1"}, {"Int": "2A"}, {"Int": "0x1"}]}]'
deserialized = deserializer(serialized, 'tensor_signed_int')
assert np.array_equal(deserialized, np.array([[42, 42], [-42, -42]]))


def test_deserialize_tensor_fixed_point():
serialized_tensor = [[2, 2], [2780037,
0, 2780037, 0, 2780037, 1, 2780037, 1]]
serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]'
expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]])
deserialized = deserialize_tensor_fixed_point(serialized_tensor)
deserialized = deserializer(serialized, 'tensor_fixed_point')
assert np.allclose(deserialized, expected_array, atol=1e-7)


Expand Down

0 comments on commit bd8a6ae

Please sign in to comment.