From d6fa55596a258885deb98b78e9110a55291defe4 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 14:48:49 -0700 Subject: [PATCH] emulate wait capability Signed-off-by: Achille Roussel --- examples/github_stats/app.py | 27 +++++----------- examples/github_stats/test_app.py | 52 ------------------------------- src/dispatch/__init__.py | 49 ++++++++++++++++++----------- src/dispatch/function.py | 46 +++++++++++++++++---------- src/dispatch/http.py | 15 +++++++-- src/dispatch/test/__init__.py | 33 ++++++++++---------- 6 files changed, 96 insertions(+), 126 deletions(-) delete mode 100644 examples/github_stats/test_app.py diff --git a/examples/github_stats/app.py b/examples/github_stats/app.py index 513bb24..996743d 100644 --- a/examples/github_stats/app.py +++ b/examples/github_stats/app.py @@ -14,16 +14,9 @@ """ +import dispatch import httpx -from fastapi import FastAPI - from dispatch.error import ThrottleError -from dispatch.fastapi import Dispatch - -app = FastAPI() - -dispatch = Dispatch(app) - def get_gh_api(url): print(f"GET {url}") @@ -36,21 +29,21 @@ def get_gh_api(url): @dispatch.function -async def get_repo_info(repo_owner: str, repo_name: str): +async def get_repo_info(repo_owner: str, repo_name: str) -> dict: url = f"https://api.github.com/repos/{repo_owner}/{repo_name}" repo_info = get_gh_api(url) return repo_info @dispatch.function -async def get_contributors(repo_info: dict): +async def get_contributors(repo_info: dict) -> list[dict]: url = repo_info["contributors_url"] contributors = get_gh_api(url) return contributors @dispatch.function -async def main(): +async def main() -> list[dict]: repo_info = await get_repo_info("dispatchrun", "coroutine") print( f"""Repository: {repo_info['full_name']} @@ -58,13 +51,9 @@ async def main(): Watchers: {repo_info['watchers_count']} Forks: {repo_info['forks_count']}""" ) - - contributors = await get_contributors(repo_info) - print(f"Contributors: {len(contributors)}") - return + return await get_contributors(repo_info) -@app.get("/") -def root(): - main.dispatch() - return "OK" +if __name__ == "__main__": + contributors = dispatch.run(main()) + print(f"Contributors: {len(contributors)}") diff --git a/examples/github_stats/test_app.py b/examples/github_stats/test_app.py deleted file mode 100644 index 37ca0d8..0000000 --- a/examples/github_stats/test_app.py +++ /dev/null @@ -1,52 +0,0 @@ -# This file is not part of the example. It is a test file to ensure the example -# works as expected during the CI process. - - -import os -import unittest -from unittest import mock - -from dispatch.function import Client -from dispatch.test import DispatchServer, DispatchService, EndpointClient -from dispatch.test.fastapi import http_client - - -class TestGithubStats(unittest.TestCase): - @mock.patch.dict( - os.environ, - { - "DISPATCH_ENDPOINT_URL": "http://function-service", - "DISPATCH_API_KEY": "0000000000000000", - }, - ) - def test_app(self): - from .app import app, dispatch - - # Setup a fake Dispatch server. - app_client = http_client(app) - endpoint_client = EndpointClient(app_client) - dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) - with DispatchServer(dispatch_service) as dispatch_server: - # Use it when dispatching function calls. - dispatch.registry.client = Client(api_url=dispatch_server.url) - - response = app_client.get("/") - self.assertEqual(response.status_code, 200) - - while dispatch_service.queue: - dispatch_service.dispatch_calls() - - # Three unique functions were called, with five total round-trips. - # The main function is called initially, and then polls - # twice, for three total round-trips. There's one round-trip - # to get_repo_info and one round-trip to get_contributors. - self.assertEqual( - 3, len(dispatch_service.roundtrips) - ) # 3 unique functions were called - self.assertEqual( - 5, - sum( - len(roundtrips) - for roundtrips in dispatch_service.roundtrips.values() - ), - ) diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index bd50a27..ea0090a 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import os from http.server import ThreadingHTTPServer from typing import Any, Callable, Coroutine, Optional, TypeVar, overload @@ -20,7 +21,7 @@ Reset, default_registry, ) -from dispatch.http import Dispatch +from dispatch.http import Dispatch, Server from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output from dispatch.status import Status @@ -63,7 +64,21 @@ def function(func): return default_registry().function(func) -def run(init: Optional[Callable[P, None]] = None, *args: P.args, **kwargs: P.kwargs): +async def main(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T: + address = addr or str(os.environ.get("DISPATCH_ENDPOINT_ADDR")) or "localhost:8000" + parsed_url = urlsplit("//" + address) + + host = parsed_url.hostname or "" + port = parsed_url.port or 0 + + reg = default_registry() + app = Dispatch(reg) + + async with Server(host, port, app) as server: + return await coro + + +def run(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T: """Run the default dispatch server. The default server uses a function registry where functions tagged by the `@dispatch.function` decorator are registered. @@ -73,27 +88,23 @@ def run(init: Optional[Callable[P, None]] = None, *args: P.args, **kwargs: P.kwa to the Dispatch bridge API. Args: - init: An initialization function called after binding the server address - but before entering the event loop to handle requests. - - args: Positional arguments to pass to the entrypoint. + coro: The coroutine to run as the entrypoint, the function returns + when the coroutine returns. - kwargs: Keyword arguments to pass to the entrypoint. + addr: The address to bind the server to. If not provided, the server + will bind to the address specified by the `DISPATCH_ENDPOINT_ADDR` + environment variable. If the environment variable is not set, the + server will bind to `localhost:8000`. Returns: - The return value of the entrypoint function. + The value returned by the coroutine. """ - address = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000") - parsed_url = urlsplit("//" + address) - server_address = (parsed_url.hostname or "", parsed_url.port or 0) - server = ThreadingHTTPServer(server_address, Dispatch(default_registry())) - try: - if init is not None: - init(*args, **kwargs) - server.serve_forever() - finally: - server.shutdown() - server.server_close() + return asyncio.run(main(coro, addr)) + + +def run_forever(): + """Run the default dispatch server forever.""" + return run(asyncio.Event().wait()) def batch() -> Batch: diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 404225d..98442d0 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -382,6 +382,10 @@ def set_default_registry(reg: Registry): DEFAULT_REGISTRY_NAME = reg.name +# TODO: this is a temporary solution to track inflight tasks and allow waiting +# for results. +_calls: Dict[str, asyncio.Future] = {} + class Client: """Client for the Dispatch API.""" @@ -469,6 +473,11 @@ async def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]: resp = dispatch_pb.DispatchResponse() resp.ParseFromString(data) + # TODO: remove when we implemented the wait endpoint in the server + for dispatch_id in resp.dispatch_ids: + if dispatch_id not in _calls: + _calls[dispatch_id] = asyncio.Future() + dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids] if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -479,23 +488,26 @@ async def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]: return dispatch_ids async def wait(self, dispatch_id: DispatchID) -> Any: - (url, headers, timeout) = self.request("/dispatch.sdk.v1.DispatchService/Wait") - data = dispatch_id.encode("utf-8") - - async with self.session() as session: - async with session.post( - url, headers=headers, data=data, timeout=timeout - ) as res: - data = await res.read() - self._check_response(res.status, data) - - resp = call_pb.CallResult() - resp.ParseFromString(data) - - result = CallResult._from_proto(resp) - if result.error is not None: - raise result.error.to_exception() - return result.output + # (url, headers, timeout) = self.request("/dispatch.sdk.v1.DispatchService/Wait") + # data = dispatch_id.encode("utf-8") + + # async with self.session() as session: + # async with session.post( + # url, headers=headers, data=data, timeout=timeout + # ) as res: + # data = await res.read() + # self._check_response(res.status, data) + + # resp = call_pb.CallResult() + # resp.ParseFromString(data) + + # result = CallResult._from_proto(resp) + # if result.error is not None: + # raise result.error.to_exception() + # return result.output + + future = _calls[dispatch_id] + return await future def _check_response(self, status: int, data: bytes): if status == 200: diff --git a/src/dispatch/http.py b/src/dispatch/http.py index 14475ce..b9937d0 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -20,8 +20,8 @@ from http_message_signatures import InvalidSignature from typing_extensions import ParamSpec, TypeAlias -from dispatch.function import Batch, Function, Registry, default_registry -from dispatch.proto import Input +from dispatch.function import Batch, Function, Registry, default_registry, _calls +from dispatch.proto import CallResult, Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( CaseInsensitiveDict, @@ -78,7 +78,7 @@ def function(self, func): def batch(self) -> Batch: return self.registry.batch() - async def run(self, url, method, headers, data): + async def run(self, url: str, method: str, headers: Mapping[str, str], data: bytes) -> bytes: return await function_service_run( url, method, @@ -380,6 +380,9 @@ async def function_service_run( response = output._message status = Status(response.status) + if req.dispatch_id not in _calls: + _calls[req.dispatch_id] = asyncio.Future() + if response.HasField("poll"): logger.debug( "function '%s' polling with %d call(s)", @@ -392,6 +395,12 @@ async def function_service_run( logger.debug("function '%s' exiting with no result", req.function) else: result = exit.result + call_result = CallResult._from_proto(result) + call_future = _calls[req.dispatch_id] + if call_result.error is not None: + call_future.set_exception(call_result.error.to_exception()) + else: + call_future.set_result(call_result.output) if result.HasField("output"): logger.debug("function '%s' exiting with output value", req.function) elif result.HasField("error"): diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index ad58684..1b93ff4 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -58,6 +58,7 @@ DISPATCH_API_URL = "http://127.0.0.1:0" DISPATCH_API_KEY = "916CC3D280BB46DDBDA984B3DD10059A" +_dispatch_ids = (str(i) for i in range(2**32 - 1)) class Client(BaseClient): def session(self) -> aiohttp.ClientSession: @@ -75,14 +76,12 @@ def __init__(self, app: web.Application): def url(self): return f"http://{self.host}:{self.port}" - class Service(web.Application): tasks: Dict[str, asyncio.Task] _session: Optional[aiohttp.ClientSession] = None def __init__(self, session: Optional[aiohttp.ClientSession] = None): super().__init__() - self.dispatch_ids = (str(i) for i in range(2**32 - 1)) self.tasks = {} self.add_routes( [ @@ -126,7 +125,7 @@ async def handle_wait_request(self, request: web.Request): ) async def dispatch(self, req: DispatchRequest) -> DispatchResponse: - dispatch_ids = [next(self.dispatch_ids) for _ in req.calls] + dispatch_ids = [next(_dispatch_ids) for _ in req.calls] for call, dispatch_id in zip(req.calls, dispatch_ids): self.tasks[dispatch_id] = asyncio.create_task( @@ -208,19 +207,21 @@ def make_request(call: Call) -> RunRequest: ) # TODO: enforce poll limits - results = await asyncio.gather( - *[ - self.call( - call=subcall, - dispatch_id=subcall_dispatch_id, - parent_dispatch_id=dispatch_id, - root_dispatch_id=root_dispatch_id, - ) - for subcall, subcall_dispatch_id in zip( - res.poll.calls, next(self.dispatch_ids) - ) - ] - ) + subcall_dispatch_ids = [next(_dispatch_ids) for _ in res.poll.calls] + + subcalls = [ + self.call( + call=subcall, + dispatch_id=subcall_dispatch_id, + parent_dispatch_id=dispatch_id, + root_dispatch_id=root_dispatch_id, + ) + for subcall, subcall_dispatch_id in zip( + res.poll.calls, subcall_dispatch_ids + ) + ] + + results = await asyncio.gather(*subcalls) req = RunRequest( function=req.function,