From 432f6f47b094e94fc1bf967d20e714cdfd6316f0 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Tue, 30 Jan 2024 18:04:59 -0500 Subject: [PATCH] Tailcalls --- src/dispatch/coroutine.py | 15 +++++++++++++++ tests/test_fastapi.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/dispatch/coroutine.py b/src/dispatch/coroutine.py index bb05bdc0..43ca7202 100644 --- a/src/dispatch/coroutine.py +++ b/src/dispatch/coroutine.py @@ -231,6 +231,21 @@ def callback(cls, state: Any, calls: None | list[Call] = None) -> Output: return Output(coroutine_pb2.ExecuteResponse(poll=poll)) + @classmethod + def tailcall(cls, call: Call) -> Output: + """Exit the coroutine instructing the orchestrator to call the provided + coroutine.""" + input_bytes = _pb_any_pickle(call.input) + x = coroutine_pb2.Call( + coroutine_uri=call.coroutine_uri, + coroutine_version=call.coroutine_version, + correlation_id=call.correlation_id, + input=input_bytes, + ) + return Output( + coroutine_pb2.ExecuteResponse(exit=coroutine_pb2.Exit(tail_call=x)) + ) + # Note: contrary to other classes here Call is not just a wrapper around its # associated protobuf class, because it is reasonable for a human to write the diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 33f0c31b..d7c16a72 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -291,3 +291,18 @@ def mycoro(input: Input) -> Output: self.assertEqual("foo", resp.exit.result.error.type) self.assertEqual("bar", resp.exit.result.error.message) self.assertEqual(Status.THROTTLED, resp.status) + + def test_tailcall(self): + @self.app.dispatch_coroutine() + def other_coroutine(input: Input) -> Output: + return Output.value(f"Hello {input.input}") + + @self.app.dispatch_coroutine() + def mycoro(input: Input) -> Output: + return Output.tailcall(other_coroutine.call_with(42)) + + resp = self.execute(mycoro) + self.assertEqual(other_coroutine.uri, resp.exit.tail_call.coroutine_uri) + self.assertEqual( + 42, dispatch.coroutine._any_unpickle(resp.exit.tail_call.input) + )