From 3b34616de73041b6cbcd4a90d0602cb26412a711 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Tue, 30 Jan 2024 15:27:13 -0500 Subject: [PATCH] Coroutine that calls onto another --- src/dispatch/coroutine.py | 114 +++++++++++++++++++++++++++++++++++--- src/dispatch/fastapi.py | 21 +++---- tests/test_fastapi.py | 88 ++++++++++++++++++++++++----- 3 files changed, 190 insertions(+), 33 deletions(-) diff --git a/src/dispatch/coroutine.py b/src/dispatch/coroutine.py index 531d1cd3..f693d5df 100644 --- a/src/dispatch/coroutine.py +++ b/src/dispatch/coroutine.py @@ -18,6 +18,45 @@ from ring.coroutine.v1 import coroutine_pb2 +# Most types in this package are thin wrappers around the various protobuf +# messages of ring.coroutine.v1. They provide some safeguards and ergonomics. + + +class Coroutine: + """Callable wrapper around a function meant to be used throughout the + Dispatch Python SDK.""" + + def __init__(self, func): + self._func = func + + def __call__(self, *args, **kwargs): + return self._func(*args, **kwargs) + + @property + def uri(self) -> str: + return self._func.__qualname__ + + def call_with(self, input: Any, correlation_id: int | None = None) -> Call: + """Create a Call of this coroutine with the provided input. Useful to + generate calls during callbacks. + + Args: + input: any pickle-able python value that will be passed as input to + this coroutine. + correlation_id: optional arbitrary integer the caller can use to + match this call to a call result. + + Returns: + A Call object. It can likely be passed to Output.callback(). + """ + return Call( + coroutine_uri=self.uri, + coroutine_version="v1", + correlation_id=correlation_id, + input=input, + ) + + class Input: """The input to a coroutine. @@ -50,6 +89,7 @@ def __init__(self, req: coroutine_pb2.ExecuteRequest): self._state = pickle.loads(state_bytes) else: self._state = None + self._calls = [CallResult(r) for r in req.poll_response.results] @property def is_first_call(self) -> bool: @@ -61,15 +101,26 @@ def is_resume(self) -> bool: @property def input(self) -> Any: - if self.is_resume: - raise ValueError("This input is for a resumed coroutine") + self._assert_first_call() return self._input @property def state(self) -> Any: + self._assert_resume() + return self._state + + @property + def calls(self) -> Any: + self._assert_resume() + return self._calls + + def _assert_first_call(self): + if self.is_resume: + raise ValueError("This input is for a resumed coroutine") + + def _assert_resume(self): if self.is_first_call: raise ValueError("This input is for a first coroutine call") - return self._state class Output: @@ -93,14 +144,59 @@ def value(cls, value: Any) -> Output: ) @classmethod - def callback(cls, state: Any) -> Output: + def callback(cls, state: Any, calls: None | list[Call] = None) -> Output: """Exit the coroutine instructing the orchestrator to call back this coroutine with the provided state. The state will be made available in Input.state.""" state_bytes = pickle.dumps(state) - return Output( - coroutine_pb2.ExecuteResponse(poll=coroutine_pb2.Poll(state=state_bytes)) - ) + poll = coroutine_pb2.Poll(state=state_bytes) + + if calls is not None: + for c in calls: + input_bytes = _pb_any_pickle(c.input) + x = coroutine_pb2.Call( + coroutine_uri=c.coroutine_uri, + coroutine_version=c.coroutine_version, + correlation_id=c.correlation_id, + input=input_bytes, + ) + poll.calls.append(x) + + return Output(coroutine_pb2.ExecuteResponse(poll=poll)) + + +# Note: contrary to other classes here Call is not just a wrapper around its +# associated protobuf class, because it is reasonable for a human to write the +# Call manually -- for example to call a coroutine that cannot be referenced in +# the current Python process. + + +@dataclass +class Call: + """Instruction to invoke a coroutine. + + Though this class can be built manually, it is recommended to use the + with_call method of a Coroutine instead. + + """ + + coroutine_uri: str + coroutine_version: str + correlation_id: int | None + input: Any + + +class CallResult: + def __init__(self, proto: coroutine_pb2.CallResult): + self.coroutine_uri = proto.coroutine_uri + self.coroutine_version = proto.coroutine_version + self.correlation_id = proto.correlation_id + self.result = _any_unpickle(proto.result.output) + + +def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any: + any.Unpack(value_bytes := google.protobuf.wrappers_pb2.BytesValue()) + return pickle.loads(value_bytes.value) def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any: @@ -109,3 +205,7 @@ def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any: pb_any = google.protobuf.any_pb2.Any() pb_any.Pack(pb_bytes) return pb_any + + +def _coroutine_uri_to_qualname(coroutine_uri: str) -> str: + return coroutine_uri.split("/")[-1] diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 1e2f708d..9657578e 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -78,13 +78,12 @@ def dispatch_coroutine(self): ValueError: If the coroutine is already registered. """ - def wrap(coroutine: Callable[..., Any]): - if coroutine.__qualname__ in self._coroutines: - raise ValueError( - f"Coroutine {coroutine.__qualname__} already registered" - ) - self._coroutines[coroutine.__qualname__] = coroutine - return coroutine + def wrap(func: Callable[[dispatch.coroutine.Input], dispatch.coroutine.Output]): + coro = dispatch.coroutine.Coroutine(func) + if coro.uri in self._coroutines: + raise ValueError(f"Coroutine {coro.uri} already registered") + self._coroutines[coro.uri] = coro + return coro return wrap @@ -93,10 +92,6 @@ class _GRPCResponse(fastapi.Response): media_type = "application/grpc+proto" -def _coroutine_uri_to_qualname(coroutine_uri: str) -> str: - return coroutine_uri.split("/")[-1] - - def _new_app(): app = _DispatchAPI() app._coroutines = {} @@ -122,7 +117,9 @@ async def execute(request: fastapi.Request): # TODO: be more graceful. This will crash if the coroutine is not found, # and the coroutine version is not taken into account. - coroutine = app._coroutines[_coroutine_uri_to_qualname(req.coroutine_uri)] + coroutine = app._coroutines[ + dispatch.coroutine._coroutine_uri_to_qualname(req.coroutine_uri) + ] coro_input = dispatch.coroutine.Input(req) output = coroutine(coro_input) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index ec33490d..7cd28513 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -7,7 +7,7 @@ import fastapi from fastapi.testclient import TestClient import google.protobuf.wrappers_pb2 -import ring.coroutine.v1.coroutine_pb2 +from ring.coroutine.v1 import coroutine_pb2 from . import executor_service @@ -63,15 +63,15 @@ def my_cool_coroutine(input: Input) -> Output: input_any = google.protobuf.any_pb2.Any() input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) - req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest( - coroutine_uri=my_cool_coroutine.__qualname__, + req = coroutine_pb2.ExecuteRequest( + coroutine_uri=my_cool_coroutine.uri, coroutine_version="1", input=input_any, ) resp = client.Execute(req) - self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse) + self.assertIsInstance(resp, coroutine_pb2.ExecuteResponse) self.assertEqual(resp.coroutine_uri, req.coroutine_uri) self.assertEqual(resp.coroutine_version, req.coroutine_version) @@ -83,11 +83,8 @@ def my_cool_coroutine(input: Input) -> Output: self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") -def response_output(resp: ring.coroutine.v1.coroutine_pb2.ExecuteResponse) -> Any: - resp.exit.result.output.Unpack( - output_bytes := google.protobuf.wrappers_pb2.BytesValue() - ) - return pickle.loads(output_bytes.value) +def response_output(resp: coroutine_pb2.ExecuteResponse) -> Any: + return dispatch.coroutine._any_unpickle(resp.exit.result.output) class TestCoroutine(unittest.TestCase): @@ -97,11 +94,11 @@ def setUp(self): self.client = executor_service.client(http_client) def execute( - self, coroutine, input=None, state=None - ) -> ring.coroutine.v1.coroutine_pb2.ExecuteResponse: + self, coroutine, input=None, state=None, calls=None + ) -> coroutine_pb2.ExecuteResponse: """Test helper to invoke coroutines on the local server.""" - req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest( - coroutine_uri=coroutine.__qualname__, + req = coroutine_pb2.ExecuteRequest( + coroutine_uri=coroutine.uri, coroutine_version="1", ) @@ -112,11 +109,34 @@ def execute( req.input.CopyFrom(input_any) if state is not None: req.poll_response.state = state + if calls is not None: + for c in calls: + req.poll_response.results.append(c) resp = self.client.Execute(req) - self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse) + self.assertIsInstance(resp, coroutine_pb2.ExecuteResponse) return resp + def call(self, call: coroutine_pb2.Call) -> coroutine_pb2.CallResult: + req = coroutine_pb2.ExecuteRequest( + coroutine_uri=call.coroutine_uri, + coroutine_version=call.coroutine_version, + input=call.input, + ) + resp = self.client.Execute(req) + self.assertIsInstance(resp, coroutine_pb2.ExecuteResponse) + + # Assert the response is terminal. Good enough until the test client can + # orchestrate coroutines. + self.assertTrue(len(resp.poll.state) == 0) + + return coroutine_pb2.CallResult( + coroutine_uri=resp.coroutine_uri, + coroutine_version=resp.coroutine_version, + correlation_id=call.correlation_id, + result=resp.exit.result, + ) + def test_no_input(self): @self.app.dispatch_coroutine() def my_cool_coroutine(input: Input) -> Output: @@ -186,3 +206,43 @@ def coroutine3(input: Input) -> Output: self.assertTrue(len(state) == 0) out = response_output(resp) self.assertEqual(out, "done") + + def test_coroutine_poll(self): + @self.app.dispatch_coroutine() + def coro_compute_len(input: Input) -> Output: + return Output.value(len(input.input)) + + @self.app.dispatch_coroutine() + def coroutine_main(input: Input) -> Output: + if input.is_first_call: + text: str = input.input + return Output.callback( + state=text, calls=[coro_compute_len.call_with(text)] + ) + text = input.state + length = input.calls[0].result + return Output.value(f"length={length} text='{text}'") + + resp = self.execute(coroutine_main, input="cool stuff") + + # main saved some state + state = resp.poll.state + self.assertTrue(len(state) > 0) + # main asks for 1 call to compute_len + self.assertEqual(len(resp.poll.calls), 1) + call = resp.poll.calls[0] + self.assertEqual(call.coroutine_uri, coro_compute_len.uri) + self.assertEqual(dispatch.coroutine._any_unpickle(call.input), "cool stuff") + + # make the requested compute_len + resp2 = self.call(call) + # check the result is the terminal expected response + len_resp = dispatch.coroutine._any_unpickle(resp2.result.output) + self.assertEqual(10, len_resp) + + # resume main with the result + resp = self.execute(coroutine_main, state=state, calls=[resp2]) + # validate the final result + self.assertTrue(len(resp.poll.state) == 0) + out = response_output(resp) + self.assertEqual("length=10 text='cool stuff'", out)