Skip to content

Commit

Permalink
Coroutine that calls onto another
Browse files Browse the repository at this point in the history
  • Loading branch information
pelletier committed Jan 30, 2024
1 parent 354afb9 commit 3b34616
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 33 deletions.
114 changes: 107 additions & 7 deletions src/dispatch/coroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,45 @@
from ring.coroutine.v1 import coroutine_pb2


# Most types in this package are thin wrappers around the various protobuf
# messages of ring.coroutine.v1. They provide some safeguards and ergonomics.


class Coroutine:
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK."""

def __init__(self, func):
self._func = func

def __call__(self, *args, **kwargs):
return self._func(*args, **kwargs)

@property
def uri(self) -> str:
return self._func.__qualname__

def call_with(self, input: Any, correlation_id: int | None = None) -> Call:
"""Create a Call of this coroutine with the provided input. Useful to
generate calls during callbacks.
Args:
input: any pickle-able python value that will be passed as input to
this coroutine.
correlation_id: optional arbitrary integer the caller can use to
match this call to a call result.
Returns:
A Call object. It can likely be passed to Output.callback().
"""
return Call(
coroutine_uri=self.uri,
coroutine_version="v1",
correlation_id=correlation_id,
input=input,
)


class Input:
"""The input to a coroutine.
Expand Down Expand Up @@ -50,6 +89,7 @@ def __init__(self, req: coroutine_pb2.ExecuteRequest):
self._state = pickle.loads(state_bytes)
else:
self._state = None
self._calls = [CallResult(r) for r in req.poll_response.results]

@property
def is_first_call(self) -> bool:
Expand All @@ -61,15 +101,26 @@ def is_resume(self) -> bool:

@property
def input(self) -> Any:
if self.is_resume:
raise ValueError("This input is for a resumed coroutine")
self._assert_first_call()
return self._input

@property
def state(self) -> Any:
self._assert_resume()
return self._state

@property
def calls(self) -> Any:
self._assert_resume()
return self._calls

def _assert_first_call(self):
if self.is_resume:
raise ValueError("This input is for a resumed coroutine")

def _assert_resume(self):
if self.is_first_call:
raise ValueError("This input is for a first coroutine call")
return self._state


class Output:
Expand All @@ -93,14 +144,59 @@ def value(cls, value: Any) -> Output:
)

@classmethod
def callback(cls, state: Any) -> Output:
def callback(cls, state: Any, calls: None | list[Call] = None) -> Output:
"""Exit the coroutine instructing the orchestrator to call back this
coroutine with the provided state. The state will be made available in
Input.state."""
state_bytes = pickle.dumps(state)
return Output(
coroutine_pb2.ExecuteResponse(poll=coroutine_pb2.Poll(state=state_bytes))
)
poll = coroutine_pb2.Poll(state=state_bytes)

if calls is not None:
for c in calls:
input_bytes = _pb_any_pickle(c.input)
x = coroutine_pb2.Call(
coroutine_uri=c.coroutine_uri,
coroutine_version=c.coroutine_version,
correlation_id=c.correlation_id,
input=input_bytes,
)
poll.calls.append(x)

return Output(coroutine_pb2.ExecuteResponse(poll=poll))


# Note: contrary to other classes here Call is not just a wrapper around its
# associated protobuf class, because it is reasonable for a human to write the
# Call manually -- for example to call a coroutine that cannot be referenced in
# the current Python process.


@dataclass
class Call:
"""Instruction to invoke a coroutine.
Though this class can be built manually, it is recommended to use the
with_call method of a Coroutine instead.
"""

coroutine_uri: str
coroutine_version: str
correlation_id: int | None
input: Any


class CallResult:
def __init__(self, proto: coroutine_pb2.CallResult):
self.coroutine_uri = proto.coroutine_uri
self.coroutine_version = proto.coroutine_version
self.correlation_id = proto.correlation_id
self.result = _any_unpickle(proto.result.output)


def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
any.Unpack(value_bytes := google.protobuf.wrappers_pb2.BytesValue())
return pickle.loads(value_bytes.value)


def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
Expand All @@ -109,3 +205,7 @@ def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
pb_any = google.protobuf.any_pb2.Any()
pb_any.Pack(pb_bytes)
return pb_any


def _coroutine_uri_to_qualname(coroutine_uri: str) -> str:
return coroutine_uri.split("/")[-1]
21 changes: 9 additions & 12 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,12 @@ def dispatch_coroutine(self):
ValueError: If the coroutine is already registered.
"""

def wrap(coroutine: Callable[..., Any]):
if coroutine.__qualname__ in self._coroutines:
raise ValueError(
f"Coroutine {coroutine.__qualname__} already registered"
)
self._coroutines[coroutine.__qualname__] = coroutine
return coroutine
def wrap(func: Callable[[dispatch.coroutine.Input], dispatch.coroutine.Output]):
coro = dispatch.coroutine.Coroutine(func)
if coro.uri in self._coroutines:
raise ValueError(f"Coroutine {coro.uri} already registered")
self._coroutines[coro.uri] = coro
return coro

return wrap

Expand All @@ -93,10 +92,6 @@ class _GRPCResponse(fastapi.Response):
media_type = "application/grpc+proto"


def _coroutine_uri_to_qualname(coroutine_uri: str) -> str:
return coroutine_uri.split("/")[-1]


def _new_app():
app = _DispatchAPI()
app._coroutines = {}
Expand All @@ -122,7 +117,9 @@ async def execute(request: fastapi.Request):

# TODO: be more graceful. This will crash if the coroutine is not found,
# and the coroutine version is not taken into account.
coroutine = app._coroutines[_coroutine_uri_to_qualname(req.coroutine_uri)]
coroutine = app._coroutines[
dispatch.coroutine._coroutine_uri_to_qualname(req.coroutine_uri)
]

coro_input = dispatch.coroutine.Input(req)
output = coroutine(coro_input)
Expand Down
88 changes: 74 additions & 14 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import fastapi
from fastapi.testclient import TestClient
import google.protobuf.wrappers_pb2
import ring.coroutine.v1.coroutine_pb2
from ring.coroutine.v1 import coroutine_pb2
from . import executor_service


Expand Down Expand Up @@ -63,15 +63,15 @@ def my_cool_coroutine(input: Input) -> Output:
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest(
coroutine_uri=my_cool_coroutine.__qualname__,
req = coroutine_pb2.ExecuteRequest(
coroutine_uri=my_cool_coroutine.uri,
coroutine_version="1",
input=input_any,
)

resp = client.Execute(req)

self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse)
self.assertIsInstance(resp, coroutine_pb2.ExecuteResponse)
self.assertEqual(resp.coroutine_uri, req.coroutine_uri)
self.assertEqual(resp.coroutine_version, req.coroutine_version)

Expand All @@ -83,11 +83,8 @@ def my_cool_coroutine(input: Input) -> Output:
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")


def response_output(resp: ring.coroutine.v1.coroutine_pb2.ExecuteResponse) -> Any:
resp.exit.result.output.Unpack(
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
return pickle.loads(output_bytes.value)
def response_output(resp: coroutine_pb2.ExecuteResponse) -> Any:
return dispatch.coroutine._any_unpickle(resp.exit.result.output)


class TestCoroutine(unittest.TestCase):
Expand All @@ -97,11 +94,11 @@ def setUp(self):
self.client = executor_service.client(http_client)

def execute(
self, coroutine, input=None, state=None
) -> ring.coroutine.v1.coroutine_pb2.ExecuteResponse:
self, coroutine, input=None, state=None, calls=None
) -> coroutine_pb2.ExecuteResponse:
"""Test helper to invoke coroutines on the local server."""
req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest(
coroutine_uri=coroutine.__qualname__,
req = coroutine_pb2.ExecuteRequest(
coroutine_uri=coroutine.uri,
coroutine_version="1",
)

Expand All @@ -112,11 +109,34 @@ def execute(
req.input.CopyFrom(input_any)
if state is not None:
req.poll_response.state = state
if calls is not None:
for c in calls:
req.poll_response.results.append(c)

resp = self.client.Execute(req)
self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse)
self.assertIsInstance(resp, coroutine_pb2.ExecuteResponse)
return resp

def call(self, call: coroutine_pb2.Call) -> coroutine_pb2.CallResult:
req = coroutine_pb2.ExecuteRequest(
coroutine_uri=call.coroutine_uri,
coroutine_version=call.coroutine_version,
input=call.input,
)
resp = self.client.Execute(req)
self.assertIsInstance(resp, coroutine_pb2.ExecuteResponse)

# Assert the response is terminal. Good enough until the test client can
# orchestrate coroutines.
self.assertTrue(len(resp.poll.state) == 0)

return coroutine_pb2.CallResult(
coroutine_uri=resp.coroutine_uri,
coroutine_version=resp.coroutine_version,
correlation_id=call.correlation_id,
result=resp.exit.result,
)

def test_no_input(self):
@self.app.dispatch_coroutine()
def my_cool_coroutine(input: Input) -> Output:
Expand Down Expand Up @@ -186,3 +206,43 @@ def coroutine3(input: Input) -> Output:
self.assertTrue(len(state) == 0)
out = response_output(resp)
self.assertEqual(out, "done")

def test_coroutine_poll(self):
@self.app.dispatch_coroutine()
def coro_compute_len(input: Input) -> Output:
return Output.value(len(input.input))

@self.app.dispatch_coroutine()
def coroutine_main(input: Input) -> Output:
if input.is_first_call:
text: str = input.input
return Output.callback(
state=text, calls=[coro_compute_len.call_with(text)]
)
text = input.state
length = input.calls[0].result
return Output.value(f"length={length} text='{text}'")

resp = self.execute(coroutine_main, input="cool stuff")

# main saved some state
state = resp.poll.state
self.assertTrue(len(state) > 0)
# main asks for 1 call to compute_len
self.assertEqual(len(resp.poll.calls), 1)
call = resp.poll.calls[0]
self.assertEqual(call.coroutine_uri, coro_compute_len.uri)
self.assertEqual(dispatch.coroutine._any_unpickle(call.input), "cool stuff")

# make the requested compute_len
resp2 = self.call(call)
# check the result is the terminal expected response
len_resp = dispatch.coroutine._any_unpickle(resp2.result.output)
self.assertEqual(10, len_resp)

# resume main with the result
resp = self.execute(coroutine_main, state=state, calls=[resp2])
# validate the final result
self.assertTrue(len(resp.poll.state) == 0)
out = response_output(resp)
self.assertEqual("length=10 text='cool stuff'", out)

0 comments on commit 3b34616

Please sign in to comment.