Skip to content

Commit

Permalink
Use the new container for pickled values
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 13, 2024
1 parent 5221d30 commit 84c5a34
Showing 1 changed file with 40 additions and 25 deletions.
65 changes: 40 additions & 25 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import tblib # type: ignore[import-untyped]
from google.protobuf import descriptor_pool, duration_pb2, message_factory

from dispatch.error import IncompatibleStateError
from dispatch.error import IncompatibleStateError, InvalidArgumentError
from dispatch.id import DispatchID
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import error_pb2 as error_pb
from dispatch.sdk.v1 import exit_pb2 as exit_pb
Expand Down Expand Up @@ -78,16 +79,7 @@ def __init__(self, req: function_pb.RunRequest):

self._has_input = req.HasField("input")
if self._has_input:
if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
input_pb = google.protobuf.wrappers_pb2.BytesValue()
req.input.Unpack(input_pb)
input_bytes = input_pb.value
try:
self._input = pickle.loads(input_bytes)
except Exception as e:
self._input = input_bytes
else:
self._input = _pb_any_unpack(req.input)
self._input = _pb_any_unpack(req.input)
else:
if req.poll_result.coroutine_state:
raise IncompatibleStateError # coroutine_state is deprecated
Expand Down Expand Up @@ -450,21 +442,44 @@ def _as_proto(self) -> error_pb.Error:


def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
any.Unpack(value_bytes := google.protobuf.wrappers_pb2.BytesValue())
return pickle.loads(value_bytes.value)


def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
value_bytes = pickle.dumps(x)
pb_bytes = google.protobuf.wrappers_pb2.BytesValue(value=value_bytes)
pb_any = google.protobuf.any_pb2.Any()
pb_any.Pack(pb_bytes)
return pb_any

if any.Is(pickled_pb.Pickled.DESCRIPTOR):
p = pickled_pb.Pickled()
any.Unpack(p)
return pickle.loads(p.pickled_value)

elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container
b = google.protobuf.wrappers_pb2.BytesValue()
any.Unpack(b)
return pickle.loads(b.value)

raise InvalidArgumentError("unsupported pickled value container")


def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
any = google.protobuf.any_pb2.Any()
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
return any


def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
p = pickled_pb.Pickled()
any.Unpack(p)
return pickle.loads(p.pickled_value)

elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
b = google.protobuf.wrappers_pb2.BytesValue()
any.Unpack(b)
try:
# Assume it's the legacy container for pickled values.
return pickle.loads(b.value)
except Exception as e:
# Otherwise, return the literal bytes.
return b.value

def _pb_any_unpack(x: google.protobuf.any_pb2.Any) -> Any:
pool = descriptor_pool.Default()
msg_descriptor = pool.FindMessageTypeByName(x.TypeName())
msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
proto = message_factory.GetMessageClass(msg_descriptor)()
x.Unpack(proto)
any.Unpack(proto)
return proto

0 comments on commit 84c5a34

Please sign in to comment.