diff --git a/src/dispatch/any.py b/src/dispatch/any.py index d885709..84dbcb2 100644 --- a/src/dispatch/any.py +++ b/src/dispatch/any.py @@ -12,12 +12,15 @@ def marshal_any(value: Any) -> google.protobuf.any_pb2.Any: + if not isinstance(value, google.protobuf.message.Message): + value = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) + any = google.protobuf.any_pb2.Any() - if isinstance(value, google.protobuf.message.Message): - any.Pack(value) + if value.DESCRIPTOR.full_name.startswith("dispatch.sdk."): + any.Pack(value, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") else: - p = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) - any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") + any.Pack(value) + return any diff --git a/tests/dispatch/test_any.py b/tests/dispatch/test_any.py new file mode 100644 index 0000000..ee1d993 --- /dev/null +++ b/tests/dispatch/test_any.py @@ -0,0 +1,65 @@ +import pickle +from datetime import datetime, timedelta + +from dispatch.any import marshal_any, unmarshal_any +from dispatch.sdk.v1 import error_pb2 as error_pb + + +def test_unmarshal_none(): + boxed = marshal_any(None) + assert None == unmarshal_any(boxed) + + +def test_unmarshal_bool(): + boxed = marshal_any(True) + assert True == unmarshal_any(boxed) + + +def test_unmarshal_integer(): + boxed = marshal_any(1234) + assert 1234 == unmarshal_any(boxed) + + boxed = marshal_any(-1234) + assert -1234 == unmarshal_any(boxed) + + +def test_unmarshal_float(): + boxed = marshal_any(3.14) + assert 3.14 == unmarshal_any(boxed) + + +def test_unmarshal_string(): + boxed = marshal_any("foo") + assert "foo" == unmarshal_any(boxed) + + +def test_unmarshal_bytes(): + boxed = marshal_any(b"bar") + assert b"bar" == unmarshal_any(boxed) + + +def test_unmarshal_timestamp(): + ts = datetime.fromtimestamp( + 1719372909.641448 + ) # datetime.datetime(2024, 6, 26, 13, 35, 9, 641448) + boxed = marshal_any(ts) + assert ts == unmarshal_any(boxed) + + +def test_unmarshal_duration(): + d = timedelta(seconds=1, microseconds=1234) + boxed = marshal_any(d) + assert d == unmarshal_any(boxed) + + +def test_unmarshal_protobuf_message(): + message = error_pb.Error(type="internal", message="oops") + boxed = marshal_any(message) + + # Check the message isn't pickled (in which case the type_url would + # end with dispatch.sdk.python.v1.Pickled). + assert ( + "buf.build/stealthrocket/dispatch-proto/dispatch.sdk.v1.Error" == boxed.type_url + ) + + assert message == unmarshal_any(boxed)