From e340938dfcaeb4bca36cd405d0579669a26234b0 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 18 Jun 2024 16:24:36 -0700 Subject: [PATCH] support both asyncio and blocking modes with different abstractions Signed-off-by: Achille Roussel --- src/dispatch/__init__.py | 2 +- .../{asyncio.py => asyncio/__init__.py} | 0 src/dispatch/asyncio/fastapi.py | 108 ++++++++++++++++++ src/dispatch/experimental/lambda_handler.py | 4 +- src/dispatch/fastapi.py | 95 +++------------ src/dispatch/flask.py | 8 +- src/dispatch/function.py | 54 +++++++-- src/dispatch/http.py | 51 +++++++-- src/dispatch/test.py | 7 +- tests/test_fastapi.py | 11 +- tests/test_http.py | 2 +- 11 files changed, 220 insertions(+), 122 deletions(-) rename src/dispatch/{asyncio.py => asyncio/__init__.py} (100%) create mode 100644 src/dispatch/asyncio/fastapi.py diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 95f95ca..08ad5fe 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -12,11 +12,11 @@ import dispatch.integrations from dispatch.coroutine import all, any, call, gather, race +from dispatch.function import AsyncFunction as Function from dispatch.function import ( Batch, Client, ClientError, - Function, Registry, Reset, default_registry, diff --git a/src/dispatch/asyncio.py b/src/dispatch/asyncio/__init__.py similarity index 100% rename from src/dispatch/asyncio.py rename to src/dispatch/asyncio/__init__.py diff --git a/src/dispatch/asyncio/fastapi.py b/src/dispatch/asyncio/fastapi.py new file mode 100644 index 0000000..d7cf9f0 --- /dev/null +++ b/src/dispatch/asyncio/fastapi.py @@ -0,0 +1,108 @@ +"""Integration of Dispatch functions with FastAPI for handlers using asyncio. + +Example: + + import fastapi + from dispatch.asyncio.fastapi import Dispatch + + app = fastapi.FastAPI() + dispatch = Dispatch(app) + + @dispatch.function + def my_function(): + return "Hello World!" + + @app.get("/") + async def read_root(): + await my_function.dispatch() +""" + +import logging +from typing import Optional, Union + +import fastapi +import fastapi.responses + +from dispatch.function import Registry +from dispatch.http import ( + AsyncFunctionService, + FunctionServiceError, + validate_content_length, +) +from dispatch.signature import Ed25519PublicKey, parse_verification_key + +logger = logging.getLogger(__name__) + + +class Dispatch(AsyncFunctionService): + """A Dispatch instance, powered by FastAPI.""" + + def __init__( + self, + app: fastapi.FastAPI, + registry: Optional[Registry] = None, + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, + ): + """Initialize a Dispatch endpoint, and integrate it into a FastAPI app. + + It mounts a sub-app that implements the Dispatch gRPC interface. + + Args: + app: The FastAPI app to configure. + + registry: A registry of functions to expose. If omitted, the default + registry is used. + + verification_key: Key to use when verifying signed requests. Uses + the value of the DISPATCH_VERIFICATION_KEY environment variable + if omitted. The environment variable is expected to carry an + Ed25519 public key in base64 or PEM format. + If not set, request signature verification is disabled (a warning + will be logged by the constructor). + + Raises: + ValueError: If any of the required arguments are missing. + """ + if not app: + raise ValueError( + "missing FastAPI app as first argument of the Dispatch constructor" + ) + super().__init__(registry, verification_key) + function_service = fastapi.FastAPI() + + @function_service.exception_handler(FunctionServiceError) + async def on_error(request: fastapi.Request, exc: FunctionServiceError): + # https://connectrpc.com/docs/protocol/#error-end-stream + return fastapi.responses.JSONResponse( + status_code=exc.status, + content={"code": exc.code, "message": exc.message}, + ) + + @function_service.post( + # The endpoint for execution is hardcoded at the moment. If the service + # gains more endpoints, this should be turned into a dynamic dispatch + # like the official gRPC server does. + "/Run", + ) + async def run(request: fastapi.Request): + valid, reason = validate_content_length( + int(request.headers.get("content-length", 0)) + ) + if not valid: + raise FunctionServiceError(400, "invalid_argument", reason) + + # Raw request body bytes are only available through the underlying + # starlette Request object's body method, which returns an awaitable, + # forcing execute() to be async. + data: bytes = await request.body() + + content = await self.run( + str(request.url), + request.method, + request.headers, + await request.body(), + ) + + return fastapi.Response(content=content, media_type="application/proto") + + app.mount("/dispatch.sdk.v1.FunctionService", function_service) diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py index 8990c6a..01a0968 100644 --- a/src/dispatch/experimental/lambda_handler.py +++ b/src/dispatch/experimental/lambda_handler.py @@ -27,7 +27,7 @@ def handler(event, context): from awslambdaric.lambda_context import LambdaContext from dispatch.function import Registry -from dispatch.http import FunctionService +from dispatch.http import BlockingFunctionService from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.status import Status @@ -35,7 +35,7 @@ def handler(event, context): logger = logging.getLogger(__name__) -class Dispatch(FunctionService): +class Dispatch(BlockingFunctionService): def __init__( self, registry: Optional[Registry] = None, diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 3abf7b1..7bf75f1 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -15,90 +15,29 @@ def my_function(): @app.get("/") def read_root(): my_function.dispatch() - """ +""" -import logging -from typing import Optional, Union +from typing import Any, Callable, Coroutine, TypeVar, overload -import fastapi -import fastapi.responses +from typing_extensions import ParamSpec -from dispatch.function import Registry -from dispatch.http import FunctionService, FunctionServiceError, validate_content_length -from dispatch.signature import Ed25519PublicKey, parse_verification_key +from dispatch.asyncio.fastapi import Dispatch as AsyncDispatch +from dispatch.function import BlockingFunction -logger = logging.getLogger(__name__) +__all__ = ["Dispatch", "AsyncDispatch"] +P = ParamSpec("P") +T = TypeVar("T") -class Dispatch(FunctionService): - """A Dispatch instance, powered by FastAPI.""" - def __init__( - self, - app: fastapi.FastAPI, - registry: Optional[Registry] = None, - verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, - ): - """Initialize a Dispatch endpoint, and integrate it into a FastAPI app. +class Dispatch(AsyncDispatch): + @overload # type: ignore + def function(self, func: Callable[P, T]) -> BlockingFunction[P, T]: ... - It mounts a sub-app that implements the Dispatch gRPC interface. + @overload # type: ignore + def function( + self, func: Callable[P, Coroutine[Any, Any, T]] + ) -> BlockingFunction[P, T]: ... - Args: - app: The FastAPI app to configure. - - registry: A registry of functions to expose. If omitted, the default - registry is used. - - verification_key: Key to use when verifying signed requests. Uses - the value of the DISPATCH_VERIFICATION_KEY environment variable - if omitted. The environment variable is expected to carry an - Ed25519 public key in base64 or PEM format. - If not set, request signature verification is disabled (a warning - will be logged by the constructor). - - Raises: - ValueError: If any of the required arguments are missing. - """ - if not app: - raise ValueError( - "missing FastAPI app as first argument of the Dispatch constructor" - ) - super().__init__(registry, verification_key) - function_service = fastapi.FastAPI() - - @function_service.exception_handler(FunctionServiceError) - async def on_error(request: fastapi.Request, exc: FunctionServiceError): - # https://connectrpc.com/docs/protocol/#error-end-stream - return fastapi.responses.JSONResponse( - status_code=exc.status, - content={"code": exc.code, "message": exc.message}, - ) - - @function_service.post( - # The endpoint for execution is hardcoded at the moment. If the service - # gains more endpoints, this should be turned into a dynamic dispatch - # like the official gRPC server does. - "/Run", - ) - async def run(request: fastapi.Request): - valid, reason = validate_content_length( - int(request.headers.get("content-length", 0)) - ) - if not valid: - raise FunctionServiceError(400, "invalid_argument", reason) - - # Raw request body bytes are only available through the underlying - # starlette Request object's body method, which returns an awaitable, - # forcing execute() to be async. - data: bytes = await request.body() - - content = await self.run( - str(request.url), - request.method, - request.headers, - await request.body(), - ) - - return fastapi.Response(content=content, media_type="application/proto") - - app.mount("/dispatch.sdk.v1.FunctionService", function_service) + def function(self, func): + return BlockingFunction(super().function(func)) diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index ffd6c92..0724b89 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -24,13 +24,17 @@ def read_root(): from flask import Flask, make_response, request from dispatch.function import Registry -from dispatch.http import FunctionService, FunctionServiceError, validate_content_length +from dispatch.http import ( + BlockingFunctionService, + FunctionServiceError, + validate_content_length, +) from dispatch.signature import Ed25519PublicKey, parse_verification_key logger = logging.getLogger(__name__) -class Dispatch(FunctionService): +class Dispatch(BlockingFunctionService): """A Dispatch instance, powered by Flask.""" def __init__( diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 5492581..a685a1d 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -18,6 +18,7 @@ Optional, Tuple, TypeVar, + Union, overload, ) from urllib.parse import urlparse @@ -111,7 +112,7 @@ def _build_primitive_call( ) -class Function(PrimitiveFunction, Generic[P, T]): +class AsyncFunction(PrimitiveFunction, Generic[P, T]): """Callable wrapper around a function meant to be used throughout the Dispatch Python SDK. """ @@ -157,7 +158,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: else: return self._call_dispatch(*args, **kwargs) - def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: + async def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: """Dispatch an asynchronous call to the function without waiting for a result. @@ -171,7 +172,7 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: Returns: DispatchID: ID of the dispatched call. """ - return asyncio.run(self._primitive_dispatch(Arguments(args, kwargs))) + return await self._primitive_dispatch(Arguments(args, kwargs)) def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call: """Create a Call for this function with the provided input. Useful to @@ -187,11 +188,38 @@ def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call: return self._build_primitive_call(Arguments(args, kwargs)) +class BlockingFunction(Generic[P, T]): + """BlockingFunction is like Function but exposes a blocking API instead of + functions that use asyncio. + + Applications typically don't create instances of BlockingFunction directly, + and instead use decorators from packages that provide integrations with + Python frameworks. + """ + + def __init__(self, func: AsyncFunction[P, T]): + self._func = func + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + return asyncio.run(self._func(*args, **kwargs)) + + def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: + return asyncio.run(self._func.dispatch(*args, **kwargs)) + + def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call: + return self._func.build_call(*args, **kwargs) + + class Reset(TailCall): """The current coroutine is aborted and scheduling reset to be replaced with the call embedded in this exception.""" - def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs): + def __init__( + self, + func: Union[AsyncFunction[P, T], BlockingFunction[P, T]], + *args: P.args, + **kwargs: P.kwargs, + ): super().__init__(call=func.build_call(*args, **kwargs)) @@ -267,10 +295,12 @@ def endpoint(self, value: str): self._endpoint = value @overload - def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... + def function( + self, func: Callable[P, Coroutine[Any, Any, T]] + ) -> AsyncFunction[P, T]: ... @overload - def function(self, func: Callable[P, T]) -> Function[P, T]: ... + def function(self, func: Callable[P, T]) -> AsyncFunction[P, T]: ... def function(self, func): """Decorator that registers functions.""" @@ -283,7 +313,9 @@ def function(self, func): logger.info("registering coroutine: %s", name) return self._register_coroutine(name, func) - def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]: + def _register_function( + self, name: str, func: Callable[P, T] + ) -> AsyncFunction[P, T]: func = durable(func) @wraps(func) @@ -296,7 +328,7 @@ async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def _register_coroutine( self, name: str, func: Callable[P, Coroutine[Any, Any, T]] - ) -> Function[P, T]: + ) -> AsyncFunction[P, T]: logger.info("registering coroutine: %s", name) func = durable(func) @@ -307,7 +339,7 @@ async def primitive_func(input: Input) -> Output: primitive_func.__qualname__ = f"{name}_primitive" durable_primitive_func = durable(primitive_func) - wrapped_func = Function[P, T]( + wrapped_func = AsyncFunction[P, T]( self, name, durable_primitive_func, @@ -555,7 +587,9 @@ def __init__(self, client: Client): self.client = client self.calls = [] - def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch: + def add( + self, func: AsyncFunction[P, T], *args: P.args, **kwargs: P.kwargs + ) -> Batch: """Add a call to the specified function to the batch.""" return self.add_call(func.build_call(*args, **kwargs)) diff --git a/src/dispatch/http.py b/src/dispatch/http.py index 591642b..e56cb26 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -20,7 +20,14 @@ from http_message_signatures import InvalidSignature from typing_extensions import ParamSpec, TypeAlias -from dispatch.function import Batch, Function, Registry, _calls, default_registry +from dispatch.function import ( + AsyncFunction, + Batch, + BlockingFunction, + Registry, + _calls, + default_registry, +) from dispatch.proto import CallResult, Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -41,7 +48,7 @@ T = TypeVar("T") -class FunctionService: +class BaseFunctionService: """FunctionService is an abstract class intended to be inherited by objects that integrate dispatch with other server application frameworks. @@ -68,16 +75,6 @@ def registry(self) -> Registry: def verification_key(self) -> Optional[Ed25519PublicKey]: return self._verification_key - @overload - def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... - - @overload - def function(self, func: Callable[P, T]) -> Function[P, T]: ... - - def function(self, func): - """Decorator that registers functions.""" - return self.registry.function(func) - def batch(self) -> Batch: """Create a new batch.""" return self.registry.batch() @@ -95,6 +92,36 @@ async def run( ) +class AsyncFunctionService(BaseFunctionService): + @overload + def function( + self, func: Callable[P, Coroutine[Any, Any, T]] + ) -> AsyncFunction[P, T]: ... + + @overload + def function(self, func: Callable[P, T]) -> AsyncFunction[P, T]: ... + + def function(self, func): + return self.registry.function(func) + + +class BlockingFunctionService(BaseFunctionService): + """BlockingFunctionService is a variant of FunctionService which decorates + dispatch functions with a synchronous API instead of using asyncio. + """ + + @overload + def function(self, func: Callable[P, T]) -> BlockingFunction[P, T]: ... + + @overload + def function( + self, func: Callable[P, Coroutine[Any, Any, T]] + ) -> BlockingFunction[P, T]: ... + + def function(self, func): + return BlockingFunction(self.registry.function(func)) + + class FunctionServiceError(Exception): __slots__ = ("status", "code", "message") diff --git a/src/dispatch/test.py b/src/dispatch/test.py index b3768d1..7c8291b 100644 --- a/src/dispatch/test.py +++ b/src/dispatch/test.py @@ -21,7 +21,7 @@ default_registry, set_default_registry, ) -from dispatch.http import Dispatch, FunctionService +from dispatch.http import Dispatch from dispatch.http import Server as BaseServer from dispatch.sdk.v1.call_pb2 import Call, CallResult from dispatch.sdk.v1.dispatch_pb2 import DispatchRequest, DispatchResponse @@ -290,19 +290,14 @@ async def main(coro: Coroutine[Any, Any, None]) -> None: api = Service() app = Dispatch(reg) try: - print("Starting bakend") async with Server(api) as backend: - print("Starting server") async with Server(app) as server: # Here we break through the abstraction layers a bit, it's not # ideal but it works for now. reg.client.api_url.value = backend.url reg.endpoint = server.url - print("BACKEND:", backend.url) - print("SERVER:", server.url) await coro finally: - print("DONE!") await api.close() # TODO: let's figure out how to get rid of this global registry # state at some point, which forces tests to be run sequentially. diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 2af40d3..97e9766 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,7 +1,6 @@ import asyncio import socket import sys -from typing import Any, Optional import fastapi import google.protobuf.any_pb2 @@ -19,15 +18,7 @@ from dispatch.asyncio import Runner from dispatch.experimental.durable.registry import clear_functions from dispatch.fastapi import Dispatch -from dispatch.function import ( - Arguments, - Client, - Error, - Function, - Input, - Output, - Registry, -) +from dispatch.function import Arguments, Client, Error, Input, Output, Registry from dispatch.proto import _any_unpickle as any_unpickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb diff --git a/tests/test_http.py b/tests/test_http.py index cba4742..cdfc277 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -6,7 +6,7 @@ import dispatch.test from dispatch.asyncio import Runner from dispatch.function import Registry -from dispatch.http import Dispatch, FunctionService, Server +from dispatch.http import Dispatch, Server class TestHTTP(dispatch.test.TestCase):