Skip to content

Commit

Permalink
emulate wait capability
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
  • Loading branch information
achille-roussel committed Jun 17, 2024
1 parent c053298 commit d6fa555
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 126 deletions.
27 changes: 8 additions & 19 deletions examples/github_stats/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,9 @@
"""

import dispatch
import httpx
from fastapi import FastAPI

from dispatch.error import ThrottleError
from dispatch.fastapi import Dispatch

app = FastAPI()

dispatch = Dispatch(app)


def get_gh_api(url):
print(f"GET {url}")
Expand All @@ -36,35 +29,31 @@ def get_gh_api(url):


@dispatch.function
async def get_repo_info(repo_owner: str, repo_name: str):
async def get_repo_info(repo_owner: str, repo_name: str) -> dict:
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}"
repo_info = get_gh_api(url)
return repo_info


@dispatch.function
async def get_contributors(repo_info: dict):
async def get_contributors(repo_info: dict) -> list[dict]:
url = repo_info["contributors_url"]
contributors = get_gh_api(url)
return contributors


@dispatch.function
async def main():
async def main() -> list[dict]:
repo_info = await get_repo_info("dispatchrun", "coroutine")
print(
f"""Repository: {repo_info['full_name']}
Stars: {repo_info['stargazers_count']}
Watchers: {repo_info['watchers_count']}
Forks: {repo_info['forks_count']}"""
)

contributors = await get_contributors(repo_info)
print(f"Contributors: {len(contributors)}")
return
return await get_contributors(repo_info)


@app.get("/")
def root():
main.dispatch()
return "OK"
if __name__ == "__main__":
contributors = dispatch.run(main())
print(f"Contributors: {len(contributors)}")
52 changes: 0 additions & 52 deletions examples/github_stats/test_app.py

This file was deleted.

49 changes: 30 additions & 19 deletions src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import os
from http.server import ThreadingHTTPServer
from typing import Any, Callable, Coroutine, Optional, TypeVar, overload
Expand All @@ -20,7 +21,7 @@
Reset,
default_registry,
)
from dispatch.http import Dispatch
from dispatch.http import Dispatch, Server
from dispatch.id import DispatchID
from dispatch.proto import Call, Error, Input, Output
from dispatch.status import Status
Expand Down Expand Up @@ -63,7 +64,21 @@ def function(func):
return default_registry().function(func)


def run(init: Optional[Callable[P, None]] = None, *args: P.args, **kwargs: P.kwargs):
async def main(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T:
address = addr or str(os.environ.get("DISPATCH_ENDPOINT_ADDR")) or "localhost:8000"
parsed_url = urlsplit("//" + address)

host = parsed_url.hostname or ""
port = parsed_url.port or 0

reg = default_registry()
app = Dispatch(reg)

async with Server(host, port, app) as server:
return await coro


def run(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T:
"""Run the default dispatch server. The default server uses a function
registry where functions tagged by the `@dispatch.function` decorator are
registered.
Expand All @@ -73,27 +88,23 @@ def run(init: Optional[Callable[P, None]] = None, *args: P.args, **kwargs: P.kwa
to the Dispatch bridge API.
Args:
init: An initialization function called after binding the server address
but before entering the event loop to handle requests.
args: Positional arguments to pass to the entrypoint.
coro: The coroutine to run as the entrypoint, the function returns
when the coroutine returns.
kwargs: Keyword arguments to pass to the entrypoint.
addr: The address to bind the server to. If not provided, the server
will bind to the address specified by the `DISPATCH_ENDPOINT_ADDR`
environment variable. If the environment variable is not set, the
server will bind to `localhost:8000`.
Returns:
The return value of the entrypoint function.
The value returned by the coroutine.
"""
address = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")
parsed_url = urlsplit("//" + address)
server_address = (parsed_url.hostname or "", parsed_url.port or 0)
server = ThreadingHTTPServer(server_address, Dispatch(default_registry()))
try:
if init is not None:
init(*args, **kwargs)
server.serve_forever()
finally:
server.shutdown()
server.server_close()
return asyncio.run(main(coro, addr))


def run_forever():
"""Run the default dispatch server forever."""
return run(asyncio.Event().wait())


def batch() -> Batch:
Expand Down
46 changes: 29 additions & 17 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ def set_default_registry(reg: Registry):
DEFAULT_REGISTRY_NAME = reg.name


# TODO: this is a temporary solution to track inflight tasks and allow waiting
# for results.
_calls: Dict[str, asyncio.Future] = {}

class Client:
"""Client for the Dispatch API."""

Expand Down Expand Up @@ -469,6 +473,11 @@ async def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]:
resp = dispatch_pb.DispatchResponse()
resp.ParseFromString(data)

# TODO: remove when we implemented the wait endpoint in the server
for dispatch_id in resp.dispatch_ids:
if dispatch_id not in _calls:
_calls[dispatch_id] = asyncio.Future()

dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids]
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
Expand All @@ -479,23 +488,26 @@ async def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]:
return dispatch_ids

async def wait(self, dispatch_id: DispatchID) -> Any:
(url, headers, timeout) = self.request("/dispatch.sdk.v1.DispatchService/Wait")
data = dispatch_id.encode("utf-8")

async with self.session() as session:
async with session.post(
url, headers=headers, data=data, timeout=timeout
) as res:
data = await res.read()
self._check_response(res.status, data)

resp = call_pb.CallResult()
resp.ParseFromString(data)

result = CallResult._from_proto(resp)
if result.error is not None:
raise result.error.to_exception()
return result.output
# (url, headers, timeout) = self.request("/dispatch.sdk.v1.DispatchService/Wait")
# data = dispatch_id.encode("utf-8")

# async with self.session() as session:
# async with session.post(
# url, headers=headers, data=data, timeout=timeout
# ) as res:
# data = await res.read()
# self._check_response(res.status, data)

# resp = call_pb.CallResult()
# resp.ParseFromString(data)

# result = CallResult._from_proto(resp)
# if result.error is not None:
# raise result.error.to_exception()
# return result.output

future = _calls[dispatch_id]
return await future

def _check_response(self, status: int, data: bytes):
if status == 200:
Expand Down
15 changes: 12 additions & 3 deletions src/dispatch/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from http_message_signatures import InvalidSignature
from typing_extensions import ParamSpec, TypeAlias

from dispatch.function import Batch, Function, Registry, default_registry
from dispatch.proto import Input
from dispatch.function import Batch, Function, Registry, default_registry, _calls
from dispatch.proto import CallResult, Input
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import (
CaseInsensitiveDict,
Expand Down Expand Up @@ -78,7 +78,7 @@ def function(self, func):
def batch(self) -> Batch:
return self.registry.batch()

async def run(self, url, method, headers, data):
async def run(self, url: str, method: str, headers: Mapping[str, str], data: bytes) -> bytes:
return await function_service_run(
url,
method,
Expand Down Expand Up @@ -380,6 +380,9 @@ async def function_service_run(
response = output._message
status = Status(response.status)

if req.dispatch_id not in _calls:
_calls[req.dispatch_id] = asyncio.Future()

if response.HasField("poll"):
logger.debug(
"function '%s' polling with %d call(s)",
Expand All @@ -392,6 +395,12 @@ async def function_service_run(
logger.debug("function '%s' exiting with no result", req.function)
else:
result = exit.result
call_result = CallResult._from_proto(result)
call_future = _calls[req.dispatch_id]
if call_result.error is not None:
call_future.set_exception(call_result.error.to_exception())
else:
call_future.set_result(call_result.output)
if result.HasField("output"):
logger.debug("function '%s' exiting with output value", req.function)
elif result.HasField("error"):
Expand Down
33 changes: 17 additions & 16 deletions src/dispatch/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
DISPATCH_API_URL = "http://127.0.0.1:0"
DISPATCH_API_KEY = "916CC3D280BB46DDBDA984B3DD10059A"

_dispatch_ids = (str(i) for i in range(2**32 - 1))

class Client(BaseClient):
def session(self) -> aiohttp.ClientSession:
Expand All @@ -75,14 +76,12 @@ def __init__(self, app: web.Application):
def url(self):
return f"http://{self.host}:{self.port}"


class Service(web.Application):
tasks: Dict[str, asyncio.Task]
_session: Optional[aiohttp.ClientSession] = None

def __init__(self, session: Optional[aiohttp.ClientSession] = None):
super().__init__()
self.dispatch_ids = (str(i) for i in range(2**32 - 1))
self.tasks = {}
self.add_routes(
[
Expand Down Expand Up @@ -126,7 +125,7 @@ async def handle_wait_request(self, request: web.Request):
)

async def dispatch(self, req: DispatchRequest) -> DispatchResponse:
dispatch_ids = [next(self.dispatch_ids) for _ in req.calls]
dispatch_ids = [next(_dispatch_ids) for _ in req.calls]

for call, dispatch_id in zip(req.calls, dispatch_ids):
self.tasks[dispatch_id] = asyncio.create_task(
Expand Down Expand Up @@ -208,19 +207,21 @@ def make_request(call: Call) -> RunRequest:
)

# TODO: enforce poll limits
results = await asyncio.gather(
*[
self.call(
call=subcall,
dispatch_id=subcall_dispatch_id,
parent_dispatch_id=dispatch_id,
root_dispatch_id=root_dispatch_id,
)
for subcall, subcall_dispatch_id in zip(
res.poll.calls, next(self.dispatch_ids)
)
]
)
subcall_dispatch_ids = [next(_dispatch_ids) for _ in res.poll.calls]

subcalls = [
self.call(
call=subcall,
dispatch_id=subcall_dispatch_id,
parent_dispatch_id=dispatch_id,
root_dispatch_id=root_dispatch_id,
)
for subcall, subcall_dispatch_id in zip(
res.poll.calls, subcall_dispatch_ids
)
]

results = await asyncio.gather(*subcalls)

req = RunRequest(
function=req.function,
Expand Down

0 comments on commit d6fa555

Please sign in to comment.