Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap pickled values in dispatch.sdk.python.v1 container #177

Merged
merged 7 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 47 additions & 26 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import tblib # type: ignore[import-untyped]
from google.protobuf import descriptor_pool, duration_pb2, message_factory

from dispatch.error import IncompatibleStateError, InvalidArgumentError
from dispatch.id import DispatchID
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import error_pb2 as error_pb
from dispatch.sdk.v1 import exit_pb2 as exit_pb
Expand Down Expand Up @@ -77,18 +79,11 @@ def __init__(self, req: function_pb.RunRequest):

self._has_input = req.HasField("input")
if self._has_input:
if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
input_pb = google.protobuf.wrappers_pb2.BytesValue()
req.input.Unpack(input_pb)
input_bytes = input_pb.value
try:
self._input = pickle.loads(input_bytes)
except Exception as e:
self._input = input_bytes
else:
self._input = _pb_any_unpack(req.input)
self._input = _pb_any_unpack(req.input)
else:
self._coroutine_state = req.poll_result.coroutine_state
if req.poll_result.coroutine_state:
raise IncompatibleStateError # coroutine_state is deprecated
self._coroutine_state = _any_unpickle(req.poll_result.typed_coroutine_state)
self._call_results = [
CallResult._from_proto(r) for r in req.poll_result.results
]
Expand Down Expand Up @@ -155,15 +150,15 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
def from_poll_results(
cls,
function: str,
coroutine_state: Optional[bytes],
coroutine_state: Any,
call_results: List[CallResult],
error: Optional[Error] = None,
):
return Input(
req=function_pb.RunRequest(
function=function,
poll_result=poll_pb.PollResult(
coroutine_state=coroutine_state,
typed_coroutine_state=_pb_any_pickle(coroutine_state),
results=[result._as_proto() for result in call_results],
error=error._as_proto() if error else None,
),
Expand Down Expand Up @@ -232,7 +227,7 @@ def exit(
@classmethod
def poll(
cls,
coroutine_state: Optional[bytes] = None,
coroutine_state: Any = None,
calls: Optional[List[Call]] = None,
min_results: int = 1,
max_results: int = 10,
Expand All @@ -247,7 +242,7 @@ def poll(
else None
)
poll = poll_pb.Poll(
coroutine_state=coroutine_state,
typed_coroutine_state=_pb_any_pickle(coroutine_state),
min_results=min_results,
max_results=max_results,
max_wait=max_wait,
Expand Down Expand Up @@ -447,21 +442,47 @@ def _as_proto(self) -> error_pb.Error:


def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
    """Decode a pickled Python value carried in a protobuf ``Any``.

    Three payload shapes are accepted:

    * a ``Pickled`` message: its ``pickled_value`` bytes are unpickled;
    * a ``BytesValue`` message (the legacy container): its ``value`` bytes
      are unpickled;
    * an ``Any`` with neither ``type_url`` nor ``value`` set: decodes to
      ``None``.

    Raises:
        InvalidArgumentError: if the ``Any`` wraps any other message type.

    NOTE(review): ``pickle.loads`` can execute arbitrary code while
    decoding, so this must only be fed payloads from a trusted source.
    """
    if any.Is(pickled_pb.Pickled.DESCRIPTOR):
        p = pickled_pb.Pickled()
        any.Unpack(p)
        return pickle.loads(p.pickled_value)

    if any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):  # legacy container
        b = google.protobuf.wrappers_pb2.BytesValue()
        any.Unpack(b)
        return pickle.loads(b.value)

    # A completely unset Any stands for "no value".
    if not any.type_url and not any.value:
        return None

    raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}")


def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
    """Pickle ``value`` and return it packed inside a protobuf ``Any``.

    The pickled bytes are stored in the ``pickled_value`` field of a
    ``Pickled`` message, which is then packed using the
    ``buf.build/stealthrocket/dispatch-proto/`` type URL prefix.
    """
    container = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
    packed = google.protobuf.any_pb2.Any()
    packed.Pack(
        container,
        type_url_prefix="buf.build/stealthrocket/dispatch-proto/",
    )
    return packed


def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
value_bytes = pickle.dumps(x)
pb_bytes = google.protobuf.wrappers_pb2.BytesValue(value=value_bytes)
pb_any = google.protobuf.any_pb2.Any()
pb_any.Pack(pb_bytes)
return pb_any
def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
    """Unpack a protobuf ``Any`` into a Python value or a concrete message.

    * ``Pickled`` payload: returns the unpickled ``pickled_value``.
    * ``BytesValue`` payload: tries to unpickle ``value`` (legacy container
      for pickled values); if that fails for any reason, returns the raw
      bytes unchanged.
    * Any other payload: the message class is looked up by type name in the
      default descriptor pool, and an instance of it holding the unpacked
      payload is returned.

    NOTE(review): ``pickle.loads`` can execute arbitrary code while
    decoding, so this must only be fed payloads from a trusted source.
    """
    if any.Is(pickled_pb.Pickled.DESCRIPTOR):
        p = pickled_pb.Pickled()
        any.Unpack(p)
        return pickle.loads(p.pickled_value)

    if any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
        b = google.protobuf.wrappers_pb2.BytesValue()
        any.Unpack(b)
        try:
            # Assume it's the legacy container for pickled values.
            return pickle.loads(b.value)
        except Exception:
            # Deliberately broad: bytes that are not a pickle can fail in
            # many different ways. Otherwise, return the literal bytes.
            return b.value

    # Generic path: resolve the concrete message class from the type name.
    pool = descriptor_pool.Default()
    msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
    proto = message_factory.GetMessageClass(msg_descriptor)()
    any.Unpack(proto)
    return proto
44 changes: 17 additions & 27 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,19 +357,17 @@ def _init_state(self, input: Input) -> State:
)

def _rebuild_state(self, input: Input):
    """Return the scheduler state carried by ``input`` after validating it.

    The value read from ``input.coroutine_state`` is used as-is; no
    unpickling happens here. It is accepted only when it is a ``State``
    instance whose ``version`` equals ``self.version``.

    Raises:
        IncompatibleStateError: if the value is not a ``State`` or its
            version differs from ``self.version``. The underlying
            ``ValueError`` is chained as the cause.
    """
    logger.info("resuming main coroutine")
    try:
        state = input.coroutine_state
        if not isinstance(state, State):
            raise ValueError("invalid state")
        if state.version != self.version:
            raise ValueError(
                f"version mismatch: '{state.version}' vs. current '{self.version}'"
            )
        return state
    except ValueError as e:
        # Log with traceback, then surface a single error type to callers.
        logger.warning("state is incompatible", exc_info=True)
        raise IncompatibleStateError from e

Expand Down Expand Up @@ -454,32 +452,24 @@ async def _run(self, input: Input) -> Output:
await asyncio.gather(*asyncio_tasks, return_exceptions=True)
return coroutine_result

# Serialize coroutines and scheduler state.
logger.debug("serializing state")
# Yield to Dispatch.
logger.debug("yielding to Dispatch with %d call(s)", len(pending_calls))
try:
serialized_state = pickle.dumps(state)
return Output.poll(
coroutine_state=state,
calls=pending_calls,
min_results=max(1, min(state.outstanding_calls, self.poll_min_results)),
max_results=max(1, min(state.outstanding_calls, self.poll_max_results)),
max_wait_seconds=self.poll_max_wait_seconds,
)
except pickle.PickleError as e:
logger.exception("state could not be serialized")
return Output.error(Error.from_exception(e, status=Status.PERMANENT_ERROR))

# Close coroutines before yielding.
for suspended in state.suspended.values():
suspended.coroutine.close()
state.suspended = {}

# Yield to Dispatch.
logger.debug(
"yielding to Dispatch with %d call(s) and %d bytes of state",
len(pending_calls),
len(serialized_state),
)
return Output.poll(
coroutine_state=serialized_state,
calls=pending_calls,
min_results=max(1, min(state.outstanding_calls, self.poll_min_results)),
max_results=max(1, min(state.outstanding_calls, self.poll_max_results)),
max_wait_seconds=self.poll_max_wait_seconds,
)
finally:
# Close coroutines.
for suspended in state.suspended.values():
suspended.coroutine.close()
state.suspended = {}


async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]):
Expand Down
Empty file.
Empty file.
32 changes: 32 additions & 0 deletions src/dispatch/sdk/python/v1/pickled_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions src/dispatch/sdk/python/v1/pickled_pb2.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import ClassVar as _ClassVar
from typing import Optional as _Optional

from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message

DESCRIPTOR: _descriptor.FileDescriptor

class Pickled(_message.Message):
    """Type stub for a protobuf message with a single ``bytes`` field, ``pickled_value``."""

    # Instance attributes are limited to the one declared field.
    __slots__ = ("pickled_value",)
    # presumably the protobuf field number of ``pickled_value`` -- confirm against the .proto
    PICKLED_VALUE_FIELD_NUMBER: _ClassVar[int]
    pickled_value: bytes
    def __init__(self, pickled_value: _Optional[bytes] = ...) -> None: ...
3 changes: 3 additions & 0 deletions src/dispatch/sdk/python/v1/pickled_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
8 changes: 4 additions & 4 deletions src/dispatch/sdk/v1/call_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 11 additions & 1 deletion src/dispatch/sdk/v1/call_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,34 @@ from dispatch.sdk.v1 import error_pb2 as _error_pb2
DESCRIPTOR: _descriptor.FileDescriptor

class Call(_message.Message):
    """Type stub for the ``Call`` protobuf message.

    Declares the message's six fields (``correlation_id``, ``endpoint``,
    ``function``, ``input``, ``expiration``, ``version``), one
    ``*_FIELD_NUMBER`` class-level ``int`` per field, and a constructor in
    which every field is optional.
    """

    # Single __slots__ declaration covering every field, including "version".
    __slots__ = (
        "correlation_id",
        "endpoint",
        "function",
        "input",
        "expiration",
        "version",
    )
    CORRELATION_ID_FIELD_NUMBER: _ClassVar[int]
    ENDPOINT_FIELD_NUMBER: _ClassVar[int]
    FUNCTION_FIELD_NUMBER: _ClassVar[int]
    INPUT_FIELD_NUMBER: _ClassVar[int]
    EXPIRATION_FIELD_NUMBER: _ClassVar[int]
    VERSION_FIELD_NUMBER: _ClassVar[int]
    correlation_id: int
    endpoint: str
    function: str
    input: _any_pb2.Any
    expiration: _duration_pb2.Duration
    version: str
    def __init__(
        self,
        correlation_id: _Optional[int] = ...,
        endpoint: _Optional[str] = ...,
        function: _Optional[str] = ...,
        input: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...,
        expiration: _Optional[_Union[_duration_pb2.Duration, _Mapping]] = ...,
        version: _Optional[str] = ...,
    ) -> None: ...

class CallResult(_message.Message):
Expand Down
14 changes: 7 additions & 7 deletions src/dispatch/sdk/v1/dispatch_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading