From f1337e22877a2315c410a3ed429e7bfecde59b30 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 14:03:23 -0700 Subject: [PATCH] refactor: use composition, default registry, function service Signed-off-by: Achille Roussel --- examples/auto_retry/test_app.py | 2 +- examples/getting_started/test_app.py | 2 +- examples/github_stats/test_app.py | 2 +- src/dispatch/__init__.py | 19 +- src/dispatch/experimental/durable/function.py | 10 +- src/dispatch/experimental/durable/registry.py | 12 + src/dispatch/experimental/lambda_handler.py | 33 +- src/dispatch/fastapi.py | 101 +-- src/dispatch/flask.py | 39 +- src/dispatch/function.py | 166 ++-- src/dispatch/http.py | 75 +- src/dispatch/test/__init__.py | 140 ++-- tests/dispatch/test_function.py | 13 +- tests/test_fastapi.py | 747 +++++++++--------- tests/test_flask.py | 10 +- tests/test_http.py | 24 +- 16 files changed, 739 insertions(+), 656 deletions(-) diff --git a/examples/auto_retry/test_app.py b/examples/auto_retry/test_app.py index fca4dc2..8ce3f18 100644 --- a/examples/auto_retry/test_app.py +++ b/examples/auto_retry/test_app.py @@ -29,7 +29,7 @@ def test_app(self): dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: # Use it when dispatching function calls. - dispatch.set_client(Client(api_url=dispatch_server.url)) + dispatch.registry.client = Client(api_url=dispatch_server.url) response = app_client.get("/") self.assertEqual(response.status_code, 200) diff --git a/examples/getting_started/test_app.py b/examples/getting_started/test_app.py index 16a7f8c..a3345b9 100644 --- a/examples/getting_started/test_app.py +++ b/examples/getting_started/test_app.py @@ -28,7 +28,7 @@ def test_app(self): dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: # Use it when dispatching function calls. - dispatch.set_client(Client(api_url=dispatch_server.url)) + dispatch.registry.client = Client(api_url=dispatch_server.url) response = app_client.get("/") self.assertEqual(response.status_code, 200) diff --git a/examples/github_stats/test_app.py b/examples/github_stats/test_app.py index 4844016..37ca0d8 100644 --- a/examples/github_stats/test_app.py +++ b/examples/github_stats/test_app.py @@ -28,7 +28,7 @@ def test_app(self): dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: # Use it when dispatching function calls. - dispatch.set_client(Client(api_url=dispatch_server.url)) + dispatch.registry.client = Client(api_url=dispatch_server.url) response = app_client.get("/") self.assertEqual(response.status_code, 200) diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 812d621..788c6a3 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -12,7 +12,15 @@ import dispatch.integrations from dispatch.coroutine import all, any, call, gather, race -from dispatch.function import Batch, Client, ClientError, Function, Registry, Reset +from dispatch.function import ( + Batch, + Client, + ClientError, + Function, + Registry, + Reset, + default_registry, +) from dispatch.http import Dispatch from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output @@ -43,15 +51,6 @@ P = ParamSpec("P") T = TypeVar("T") -_registry: Optional[Registry] = None - - -def default_registry(): - global _registry - if not _registry: - _registry = Registry() - return _registry - @overload def function(func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... diff --git a/src/dispatch/experimental/durable/function.py b/src/dispatch/experimental/durable/function.py index 87ccca4..a014935 100644 --- a/src/dispatch/experimental/durable/function.py +++ b/src/dispatch/experimental/durable/function.py @@ -23,7 +23,12 @@ ) from . import frame as ext -from .registry import RegisteredFunction, lookup_function, register_function +from .registry import ( + RegisteredFunction, + lookup_function, + register_function, + unregister_function, +) TRACE = os.getenv("DISPATCH_TRACE", False) @@ -58,6 +63,9 @@ def __call__(self, *args, **kwargs): def __repr__(self) -> str: return f"DurableFunction({self.__qualname__})" + def unregister(self): + unregister_function(self.registered_fn.key) + def durable(fn: Callable) -> Callable: """Returns a "durable" function that creates serializable diff --git a/src/dispatch/experimental/durable/registry.py b/src/dispatch/experimental/durable/registry.py index 6ddac07..9250ec0 100644 --- a/src/dispatch/experimental/durable/registry.py +++ b/src/dispatch/experimental/durable/registry.py @@ -106,6 +106,18 @@ def lookup_function(key: str) -> RegisteredFunction: return _REGISTRY[key] +def unregister_function(key: str): + """Unregister a function by key. + + Args: + key: Unique identifier for the function. + + Raises: + KeyError: A function has not been registered with this key. + """ + del _REGISTRY[key] + + def clear_functions(): """Clear functions clears the registry.""" _REGISTRY.clear() diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py index 6aeeaca..2b09805 100644 --- a/src/dispatch/experimental/lambda_handler.py +++ b/src/dispatch/experimental/lambda_handler.py @@ -18,15 +18,16 @@ def handler(event, context): dispatch.handle(event, context, entrypoint="entrypoint") """ +import asyncio import base64 import json import logging -from typing import Optional +from typing import Optional, Union from awslambdaric.lambda_context import LambdaContext -from dispatch.asyncio import Runner from dispatch.function import Registry +from dispatch.http import FunctionService from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.status import Status @@ -34,27 +35,15 @@ def handler(event, context): logger = logging.getLogger(__name__) -class Dispatch(Registry): +class Dispatch(FunctionService): def __init__( self, - api_key: Optional[str] = None, - api_url: Optional[str] = None, + registry: Optional[Registry] = None, ): - """Initializes a Dispatch Lambda handler. - - Args: - api_key: Dispatch API key to use for authentication. Uses the value - of the DISPATCH_API_KEY environment variable by default. - - api_url: The URL of the Dispatch API to use. Uses the value of the - DISPATCH_API_URL environment variable if set, otherwise - defaults to the public Dispatch API (DEFAULT_API_URL). - - """ - + """Initializes a Dispatch Lambda handler.""" # We use a fake endpoint to initialize the base class. The actual endpoint (the Lambda ARN) # is only known when the handler is invoked. - super().__init__(endpoint="http://lambda", api_key=api_key, api_url=api_url) + super().__init__(registry) def handle( self, event: str, context: LambdaContext, entrypoint: Optional[str] = None @@ -63,7 +52,8 @@ def handle( # We override the endpoint of all registered functions before any execution. if context.invoked_function_arn: self.endpoint = context.invoked_function_arn - self.override_endpoint(self.endpoint) + # TODO: this might mutate the default registry, we should figure out a better way. + self.registry.endpoint = self.endpoint if not event: raise ValueError("event is required") @@ -87,14 +77,13 @@ def handle( ) try: - func = self.functions[req.function] + func = self.registry.functions[req.function] except KeyError: raise ValueError(f"function {req.function} not found") input = Input(req) try: - with Runner() as runner: - output = runner.run(func._primitive_call(input)) + output = asyncio.run(func._primitive_call(input)) except Exception: logger.error("function '%s' fatal error", req.function, exc_info=True) raise # FIXME diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 660ebf5..4b0ff9c 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -25,26 +25,20 @@ def read_root(): import fastapi.responses from dispatch.function import Registry -from dispatch.http import ( - FunctionServiceError, - function_service_run, - validate_content_length, -) +from dispatch.http import FunctionService, FunctionServiceError, validate_content_length from dispatch.signature import Ed25519PublicKey, parse_verification_key logger = logging.getLogger(__name__) -class Dispatch(Registry): +class Dispatch(FunctionService): """A Dispatch instance, powered by FastAPI.""" def __init__( self, app: fastapi.FastAPI, - endpoint: Optional[str] = None, + registry: Optional[Registry] = None, verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, - api_key: Optional[str] = None, - api_url: Optional[str] = None, ): """Initialize a Dispatch endpoint, and integrate it into a FastAPI app. @@ -53,9 +47,8 @@ def __init__( Args: app: The FastAPI app to configure. - endpoint: Full URL of the application the Dispatch instance will - be running on. Uses the value of the DISPATCH_ENDPOINT_URL - environment variable by default. + registry: A registry of functions to expose. If omitted, the default + registry is used. verification_key: Key to use when verifying signed requests. Uses the value of the DISPATCH_VERIFICATION_KEY environment variable @@ -64,13 +57,6 @@ def __init__( If not set, request signature verification is disabled (a warning will be logged by the constructor). - api_key: Dispatch API key to use for authentication. Uses the value of - the DISPATCH_API_KEY environment variable by default. - - api_url: The URL of the Dispatch API to use. Uses the value of the - DISPATCH_API_URL environment variable if set, otherwise - defaults to the public Dispatch API (DEFAULT_API_URL). - Raises: ValueError: If any of the required arguments are missing. """ @@ -78,49 +64,42 @@ def __init__( raise ValueError( "missing FastAPI app as first argument of the Dispatch constructor" ) - super().__init__(endpoint, api_key=api_key, api_url=api_url) - verification_key = parse_verification_key(verification_key, endpoint=endpoint) - function_service = _new_app(self, verification_key) - app.mount("/dispatch.sdk.v1.FunctionService", function_service) - - -def _new_app(function_registry: Registry, verification_key: Optional[Ed25519PublicKey]): - app = fastapi.FastAPI() - - @app.exception_handler(FunctionServiceError) - async def on_error(request: fastapi.Request, exc: FunctionServiceError): - # https://connectrpc.com/docs/protocol/#error-end-stream - return fastapi.responses.JSONResponse( - status_code=exc.status, content={"code": exc.code, "message": exc.message} - ) + super().__init__(registry, verification_key) + function_service = fastapi.FastAPI() + + @function_service.exception_handler(FunctionServiceError) + async def on_error(request: fastapi.Request, exc: FunctionServiceError): + # https://connectrpc.com/docs/protocol/#error-end-stream + return fastapi.responses.JSONResponse( + status_code=exc.status, + content={"code": exc.code, "message": exc.message}, + ) - @app.post( - # The endpoint for execution is hardcoded at the moment. If the service - # gains more endpoints, this should be turned into a dynamic dispatch - # like the official gRPC server does. - "/Run", - ) - async def execute(request: fastapi.Request): - valid, reason = validate_content_length( - int(request.headers.get("content-length", 0)) - ) - if not valid: - raise FunctionServiceError(400, "invalid_argument", reason) - - # Raw request body bytes are only available through the underlying - # starlette Request object's body method, which returns an awaitable, - # forcing execute() to be async. - data: bytes = await request.body() - - content = await function_service_run( - str(request.url), - request.method, - request.headers, - data, - function_registry, - verification_key, + @function_service.post( + # The endpoint for execution is hardcoded at the moment. If the service + # gains more endpoints, this should be turned into a dynamic dispatch + # like the official gRPC server does. + "/Run", ) + async def execute(request: fastapi.Request): + valid, reason = validate_content_length( + int(request.headers.get("content-length", 0)) + ) + if not valid: + raise FunctionServiceError(400, "invalid_argument", reason) + + # Raw request body bytes are only available through the underlying + # starlette Request object's body method, which returns an awaitable, + # forcing execute() to be async. + data: bytes = await request.body() + + content = await self.run( + str(request.url), + request.method, + request.headers, + await request.body(), + ) - return fastapi.Response(content=content, media_type="application/proto") + return fastapi.Response(content=content, media_type="application/proto") - return app + app.mount("/dispatch.sdk.v1.FunctionService", function_service) diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index 8cece0e..7991f20 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -27,26 +27,20 @@ def read_root(): from flask import Flask, make_response, request from dispatch.function import Registry -from dispatch.http import ( - FunctionServiceError, - function_service_run, - validate_content_length, -) +from dispatch.http import FunctionService, FunctionServiceError, validate_content_length from dispatch.signature import Ed25519PublicKey, parse_verification_key logger = logging.getLogger(__name__) -class Dispatch(Registry): +class Dispatch(FunctionService): """A Dispatch instance, powered by Flask.""" def __init__( self, app: Flask, - endpoint: Optional[str] = None, + registry: Optional[Registry] = None, verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, - api_key: Optional[str] = None, - api_url: Optional[str] = None, ): """Initialize a Dispatch endpoint, and integrate it into a Flask app. @@ -55,9 +49,8 @@ def __init__( Args: app: The Flask app to configure. - endpoint: Full URL of the application the Dispatch instance will - be running on. Uses the value of the DISPATCH_ENDPOINT_URL - environment variable by default. + registry: A registry of functions to expose. If omitted, the default + registry is used. verification_key: Key to use when verifying signed requests. Uses the value of the DISPATCH_VERIFICATION_KEY environment variable @@ -66,13 +59,6 @@ def __init__( If not set, request signature verification is disabled (a warning will be logged by the constructor). - api_key: Dispatch API key to use for authentication. Uses the value of - the DISPATCH_API_KEY environment variable by default. - - api_url: The URL of the Dispatch API to use. Uses the value of the - DISPATCH_API_URL environment variable if set, otherwise - defaults to the public Dispatch API (DEFAULT_API_URL). - Raises: ValueError: If any of the required arguments are missing. """ @@ -81,12 +67,7 @@ def __init__( "missing Flask app as first argument of the Dispatch constructor" ) - super().__init__(endpoint, api_key=api_key, api_url=api_url) - - self._verification_key = parse_verification_key( - verification_key, endpoint=endpoint - ) - + super().__init__(registry, verification_key) app.errorhandler(FunctionServiceError)(self._handle_error) app.post("/dispatch.sdk.v1.FunctionService/Run")(self._execute) @@ -134,16 +115,12 @@ def _execute(self): if not valid: return {"code": "invalid_argument", "message": reason}, 400 - data: bytes = request.get_data(cache=False) - content = asyncio.run( - function_service_run( + self.run( request.url, request.method, dict(request.headers), - data, - self, - self._verification_key, + request.get_data(cache=False), ) ) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 5ad6647..8791297 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -37,6 +37,9 @@ logger = logging.getLogger(__name__) +P = ParamSpec("P") +T = TypeVar("T") + class GlobalSession(aiohttp.ClientSession): async def __aexit__(self, *args): @@ -62,41 +65,40 @@ def current_session() -> aiohttp.ClientSession: class PrimitiveFunction: - __slots__ = ("_endpoint", "_client", "_name", "_primitive_func") - _endpoint: str - _client: Client + __slots__ = ("_registry", "_name", "_primitive_func") + _registry: str _name: str _primitive_function: PrimitiveFunctionType def __init__( self, - endpoint: str, - client: Client, + registry: Registry, name: str, primitive_func: PrimitiveFunctionType, ): - self._endpoint = endpoint - self._client = client + self._registry = registry.name self._name = name self._primitive_func = primitive_func @property def endpoint(self) -> str: - return self._endpoint - - @endpoint.setter - def endpoint(self, value: str): - self._endpoint = value + return self.registry.endpoint @property def name(self) -> str: return self._name + @property + def registry(self) -> Registry: + return lookup_registry(self._registry) + async def _primitive_call(self, input: Input) -> Output: return await self._primitive_func(input) async def _primitive_dispatch(self, input: Any = None) -> DispatchID: - [dispatch_id] = await self._client.dispatch([self._build_primitive_call(input)]) + [dispatch_id] = await self.registry.client.dispatch( + [self._build_primitive_call(input)] + ) return dispatch_id def _build_primitive_call( @@ -110,10 +112,6 @@ def _build_primitive_call( ) -P = ParamSpec("P") -T = TypeVar("T") - - class Function(PrimitiveFunction, Generic[P, T]): """Callable wrapper around a function meant to be used throughout the Dispatch Python SDK. @@ -123,12 +121,11 @@ class Function(PrimitiveFunction, Generic[P, T]): def __init__( self, - endpoint: str, - client: Client, + registry: Registry, name: str, primitive_func: PrimitiveFunctionType, ): - PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func) + PrimitiveFunction.__init__(self, registry, name, primitive_func) self._func_indirect: Callable[P, Coroutine[Any, Any, T]] = durable( self._call_async ) @@ -144,9 +141,9 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: call = self.build_call(*args, **kwargs) - [dispatch_id] = await self._client.dispatch([call]) + [dispatch_id] = await self.registry.client.dispatch([call]) - return await self._client.wait(dispatch_id) + return await self.registry.client.wait(dispatch_id) def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: """Dispatch an asynchronous call to the function without @@ -189,29 +186,20 @@ def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs): class Registry: """Registry of functions.""" - __slots__ = ("functions", "endpoint", "client") + __slots__ = ("functions", "client", "_name", "_endpoint") def __init__( - self, - endpoint: Optional[str] = None, - api_key: Optional[str] = None, - api_url: Optional[str] = None, + self, name: str, client: Optional[Client] = None, endpoint: Optional[str] = None ): """Initialize a function registry. Args: - endpoint: URL of the endpoint that the function is accessible from. - Uses the value of the DISPATCH_ENDPOINT_URL environment variable - by default. + name: A unique name for the registry. - api_key: Dispatch API key to use for authentication when - dispatching calls to functions. Uses the value of the - DISPATCH_API_KEY environment variable by default. + endpoint: URL of the endpoint that the function is accessible from. - api_url: The URL of the Dispatch API to use when dispatching calls - to functions. Uses the value of the DISPATCH_API_URL environment - variable if set, otherwise defaults to the public Dispatch API - (DEFAULT_API_URL). + client: Client instance to use for dispatching calls to registered + functions. Defaults to creating a new client instance. Raises: ValueError: If any of the required arguments are missing. @@ -224,15 +212,45 @@ def __init__( raise ValueError( "missing application endpoint: set it with the DISPATCH_ENDPOINT_URL environment variable" ) - parsed_url = urlparse(endpoint) - if not parsed_url.netloc or not parsed_url.scheme: - raise ValueError( - f"{endpoint_from} must be a full URL with protocol and domain (e.g., https://example.com)" - ) logger.info("configuring Dispatch endpoint %s", endpoint) self.functions: Dict[str, PrimitiveFunction] = {} + self.client = client or Client() self.endpoint = endpoint - self.client = Client(api_key=api_key, api_url=api_url) + + if not name: + raise ValueError("missing registry name") + if name in _registries: + raise ValueError(f"registry with name '{name}' already exists") + self._name = name + _registries[name] = self + + def close(self): + """Closes the registry, removing it and all its functions from the + dispatch application.""" + name = self._name + if name: + self._name = "" + del _registries[name] + # TODO: remove registered functions + + @property + def name(self) -> str: + return self._name + + @property + def endpoint(self) -> str: + return self._endpoint + + @endpoint.setter + def endpoint(self, value: str): + parsed = urlparse(value) + if parsed.scheme not in ("http", "https"): + raise ValueError( + f"missing protocol scheme in registry endpoint URL: {value}" + ) + if not parsed.hostname: + raise ValueError(f"missing host in registry endpoint URL: {value}") + self._endpoint = value @overload def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... @@ -276,8 +294,7 @@ async def primitive_func(input: Input) -> Output: durable_primitive_func = durable(primitive_func) wrapped_func = Function[P, T]( - self.endpoint, - self.client, + self, name, durable_primitive_func, ) @@ -290,12 +307,7 @@ def primitive_function( """Decorator that registers primitive functions.""" name = primitive_func.__qualname__ logger.info("registering primitive function: %s", name) - wrapped_func = PrimitiveFunction( - self.endpoint, - self.client, - name, - primitive_func, - ) + wrapped_func = PrimitiveFunction(self, name, primitive_func) self._register(name, wrapped_func) return wrapped_func @@ -310,16 +322,50 @@ def batch(self): # -> Batch: # return self.client.batch() raise NotImplemented - def set_client(self, client: Client): - """Set the Client instance used to dispatch calls to registered functions.""" - # TODO: figure out a way to remove this method, it's only used in examples - self.client = client - for fn in self.functions.values(): - fn._client = client - def override_endpoint(self, endpoint: str): - for fn in self.functions.values(): - fn.endpoint = endpoint +_registries: Dict[str, Registry] = {} + +DEFAULT_REGISTRY_NAME: str = "default" +DEFAULT_REGISTRY: Optional[Registry] = None +"""The default registry for dispatch functions, used by dispatch applications +when no custom registry is provided. + +In most cases, applications do not need to create a custom registry, so this +one would be used by default. + +The default registry use DISPATCH_* environment variables for configuration, +or is uninitialized if they are not set. +""" + + +def default_registry() -> Registry: + """Returns the default registry for dispatch functions. + + The function initializes the default registry if it has not been initialized + yet, using the DISPATCH_* environment variables for configuration. + + Returns: + Registry: The default registry. + + Raises: + ValueError: If the DISPATCH_API_KEY or DISPATCH_ENDPOINT_URL environment + variables are missing. + """ + global DEFAULT_REGISTRY + if DEFAULT_REGISTRY is None: + DEFAULT_REGISTRY = Registry(DEFAULT_REGISTRY_NAME) + return DEFAULT_REGISTRY + + +def lookup_registry(name: str) -> Registry: + return default_registry() if name == DEFAULT_REGISTRY_NAME else _registries[name] + + +def set_default_registry(reg: Registry): + global DEFAULT_REGISTRY + global DEFAULT_REGISTRY_NAME + DEFAULT_REGISTRY = reg + DEFAULT_REGISTRY_NAME = reg.name class Client: diff --git a/src/dispatch/http.py b/src/dispatch/http.py index ccbcfe9..edbc9ca 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -5,12 +5,25 @@ import os from datetime import timedelta from http.server import BaseHTTPRequestHandler -from typing import Iterable, List, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Coroutine, + Iterable, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, + overload, +) from aiohttp import web from http_message_signatures import InvalidSignature +from typing_extensions import ParamSpec, TypeAlias -from dispatch.function import Registry +from dispatch.function import Batch, Function, Registry, default_registry from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -24,6 +37,60 @@ logger = logging.getLogger(__name__) +P = ParamSpec("P") +T = TypeVar("T") + + +class FunctionService: + """FunctionService is an abstract class intended to be inherited by objects + that integrate dispatch with other server application frameworks. + + An application encapsulates a function Registry, and implements the API + common to all dispatch integrations. + """ + + def __init__( + self, + registry: Optional[Registry] = None, + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, + ): + self._registry = registry + self._verification_key = parse_verification_key( + verification_key, + endpoint=self.registry.endpoint, + ) + + @property + def registry(self) -> Registry: + return self._registry or default_registry() + + @property + def verification_key(self) -> Optional[Ed25519PublicKey]: + return self._verification_key + + @overload + def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... + + @overload + def function(self, func: Callable[P, T]) -> Function[P, T]: ... + + def function(self, func): + """Decorator that registers functions.""" + return self.registry.function(func) + + def batch(self) -> Batch: + return self.registry.batch() + + async def run(self, url, method, headers, data): + return await function_service_run( + url, + method, + headers, + data, + self.registry, + self.verification_key, + ) + class FunctionServiceError(Exception): __slots__ = ("status", "code", "message") @@ -44,7 +111,7 @@ def validate_content_length(content_length: int) -> Tuple[bool, str]: return True, "" -class FunctionService(BaseHTTPRequestHandler): +class FunctionServiceHTTPRequestHandler(BaseHTTPRequestHandler): def __init__( self, @@ -148,7 +215,7 @@ def __init__( ) def __call__(self, request, client_address, server): - return FunctionService( + return FunctionServiceHTTPRequestHandler( request, client_address, server, diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 07c306c..d2cbfd5 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -13,9 +13,15 @@ import dispatch.experimental.durable.registry from dispatch.function import Client as BaseClient -from dispatch.function import ClientError, Input, Output -from dispatch.function import Registry as BaseRegistry -from dispatch.http import Dispatch +from dispatch.function import ( + ClientError, + Input, + Output, + Registry, + default_registry, + set_default_registry, +) +from dispatch.http import Dispatch, FunctionService from dispatch.http import Server as BaseServer from dispatch.sdk.v1.call_pb2 import Call, CallResult from dispatch.sdk.v1.dispatch_pb2 import DispatchRequest, DispatchResponse @@ -46,7 +52,6 @@ ] P = ParamSpec("P") -R = TypeVar("R", bound=BaseRegistry) T = TypeVar("T") DISPATCH_ENDPOINT_URL = "http://127.0.0.1:0" @@ -62,17 +67,6 @@ def session(self) -> aiohttp.ClientSession: return aiohttp.ClientSession() -class Registry(BaseRegistry): - def __init__(self): - # placeholder values to initialize the base class prior to binding - # random ports. - super().__init__( - endpoint=DISPATCH_ENDPOINT_URL, - api_url=DISPATCH_API_URL, - api_key=DISPATCH_API_KEY, - ) - - class Server(BaseServer): def __init__(self, app: web.Application): super().__init__("127.0.0.1", 0, app) @@ -258,7 +252,8 @@ def session(self) -> aiohttp.ClientSession: return self._session -async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: +async def main(coro: Coroutine[Any, Any, None]) -> None: + reg = default_registry() api = Service() app = Dispatch(reg) try: @@ -268,7 +263,7 @@ async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: # ideal but it works for now. reg.client.api_url.value = backend.url reg.endpoint = server.url - await fn(reg) + await coro finally: await api.close() # TODO: let's figure out how to get rid of this global registry @@ -276,8 +271,8 @@ async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: dispatch.experimental.durable.registry.clear_functions() -def run(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: - return asyncio.run(main(reg, fn)) +def run(coro: Coroutine[Any, Any, None]) -> None: + return asyncio.run(main(coro)) # TODO: these decorators still need work, until we figure out serialization @@ -297,20 +292,18 @@ def run(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: # (WIP) -def function(fn: Callable[[Registry], Coroutine[Any, Any, None]]) -> Callable[[], None]: +def function(fn: Callable[[], Coroutine[Any, Any, None]]) -> Callable[[], None]: @wraps(fn) def wrapper(): - return run(Registry(), fn) + return run(fn()) return wrapper -def method( - fn: Callable[[T, Registry], Coroutine[Any, Any, None]] -) -> Callable[[T], None]: +def method(fn: Callable[[T], Coroutine[Any, Any, None]]) -> Callable[[T], None]: @wraps(fn) def wrapper(self: T): - return run(Registry(), lambda reg: fn(self, reg)) + return run(fn(self)) return wrapper @@ -332,6 +325,41 @@ def test(self): return test +_registry = Registry( + name=__name__, + endpoint=DISPATCH_ENDPOINT_URL, + client=Client(api_key=DISPATCH_API_KEY, api_url=DISPATCH_API_URL), +) + + +@_registry.function +def greet() -> str: + return "Hello World!" + + +@_registry.function +def greet_name(name: str) -> str: + return f"Hello world: {name}" + + +@_registry.function +def echo(name: str) -> str: + return name + + +@_registry.function +def length(name: str) -> int: + return len(name) + + +@_registry.function +def broken() -> str: + raise ValueError("something went wrong!") + + +set_default_registry(_registry) + + class TestCase(unittest.TestCase): """TestCase implements the generic test suite used in dispatch-py to test various integrations of the SDK with frameworks like FastAPI, Flask, etc... @@ -345,11 +373,7 @@ class TestCase(unittest.TestCase): more details). """ - def dispatch_test_init(self, api_key: str, api_url: str) -> BaseRegistry: - """Called to initialize each test case. The method returns the dispatch - function registry which can be used to register function instances - during tests. - """ + def dispatch_test_init(self, reg: Registry) -> str: raise NotImplementedError def dispatch_test_run(self): @@ -360,18 +384,14 @@ def dispatch_test_stop(self): def setUp(self): self.service = Service() + self.server = Server(self.service) self.server_loop = asyncio.new_event_loop() self.server_loop.run_until_complete(self.server.start()) - self.dispatch = self.dispatch_test_init( - api_key=DISPATCH_API_KEY, api_url=self.server.url - ) - self.dispatch.client = Client( - api_key=self.dispatch.client.api_key.value, - api_url=self.dispatch.client.api_url.value, - ) - + _registry.client.api_key.value = DISPATCH_API_KEY + _registry.client.api_url.value = self.server.url + _registry.endpoint = self.dispatch_test_init(_registry) self.client_thread = threading.Thread(target=self.dispatch_test_run) self.client_thread.start() @@ -385,20 +405,16 @@ def tearDown(self): # TODO: let's figure out how to get rid of this global registry # state at some point, which forces tests to be run sequentially. - dispatch.experimental.durable.registry.clear_functions() + # + # We can't erase the registry because user tests might have registered + # functions in the global scope that would be lost after the first test + # we run. + # + # dispatch.experimental.durable.registry.clear_functions() @property def function_service_run_url(self) -> str: - return f"{self.dispatch.endpoint}/dispatch.sdk.v1.FunctionService/Run" - - def test_register_duplicate_functions(self): - @self.dispatch.function - def my_function(): ... - - with self.assertRaises(ValueError): - - @self.dispatch.function - def my_function(): ... + return f"{_registry.endpoint}/dispatch.sdk.v1.FunctionService/Run" @aiotest async def test_content_length_missing(self): @@ -442,41 +458,21 @@ async def test_call_function_missing(self): @aiotest async def test_call_function_no_input(self): - @self.dispatch.function - def my_function() -> str: - return "Hello World!" - - ret = await my_function() + ret = await greet() self.assertEqual(ret, "Hello World!") @aiotest async def test_call_function_with_input(self): - @self.dispatch.function - def my_function(name: str) -> str: - return f"Hello world: {name}" - - ret = await my_function("52") + ret = await greet_name("52") self.assertEqual(ret, "Hello world: 52") @aiotest async def test_call_function_raise_error(self): - @self.dispatch.function - def my_function(name: str) -> str: - raise ValueError("something went wrong!") - with self.assertRaises(ValueError) as e: - await my_function("52") + await broken() @aiotest async def test_call_two_functions(self): - @self.dispatch.function - def echo(name: str) -> str: - return name - - @self.dispatch.function - def length(name: str) -> int: - return len(name) - self.assertEqual(await echo("hello"), "hello") self.assertEqual(await length("hello"), 5) diff --git a/tests/dispatch/test_function.py b/tests/dispatch/test_function.py index 3550b4b..276a458 100644 --- a/tests/dispatch/test_function.py +++ b/tests/dispatch/test_function.py @@ -1,10 +1,18 @@ import pickle -from dispatch.test import Registry +from dispatch.function import Client, Registry +from dispatch.test import DISPATCH_API_KEY, DISPATCH_API_URL, DISPATCH_ENDPOINT_URL def test_serializable(): - reg = Registry() + reg = Registry( + name=__name__, + endpoint=DISPATCH_ENDPOINT_URL, + client=Client( + api_key=DISPATCH_API_KEY, + api_url=DISPATCH_API_URL, + ), + ) @reg.function def my_function(): @@ -12,3 +20,4 @@ def my_function(): s = pickle.dumps(my_function) pickle.loads(s) + reg.close() diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 4830e37..f7634c0 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -24,7 +24,15 @@ from dispatch.asyncio import Runner from dispatch.experimental.durable.registry import clear_functions from dispatch.fastapi import Dispatch -from dispatch.function import Arguments, Error, Function, Input, Output +from dispatch.function import ( + Arguments, + Client, + Error, + Function, + Input, + Output, + Registry, +) from dispatch.proto import _any_unpickle as any_unpickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb @@ -40,7 +48,7 @@ class TestFastAPI(dispatch.test.TestCase): - def dispatch_test_init(self, api_key: str, api_url: str) -> Dispatch: + def dispatch_test_init(self, reg: Registry) -> str: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(("127.0.0.1", 0)) sock.listen(128) @@ -48,16 +56,14 @@ def dispatch_test_init(self, api_key: str, api_url: str) -> Dispatch: (host, port) = sock.getsockname() app = FastAPI() - reg = Dispatch( - app, endpoint=f"http://{host}:{port}", api_key=api_key, api_url=api_url - ) + dispatch = Dispatch(app, registry=reg) config = uvicorn.Config(app, host=host, port=port) self.sockets = [sock] self.uvicorn = uvicorn.Server(config) self.runner = Runner() self.event = asyncio.Event() - return reg + return f"http://{host}:{port}" def dispatch_test_run(self): loop = self.runner.get_loop() @@ -65,6 +71,9 @@ def dispatch_test_run(self): self.runner.run(self.event.wait()) self.runner.close() + for sock in self.sockets: + sock.close() + def dispatch_test_stop(self): loop = self.runner.get_loop() loop.call_soon_threadsafe(self.event.set) @@ -73,9 +82,14 @@ def dispatch_test_stop(self): def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str): return Dispatch( app, - endpoint=endpoint, - api_key="0000000000000000", - api_url="http://127.0.0.1:10000", + registry=Registry( + name=__name__, + endpoint=endpoint, + client=Client( + api_key="0000000000000000", + api_url="http://127.0.0.1:10000", + ), + ), ) @@ -89,359 +103,362 @@ def response_output(resp: function_pb.RunResponse) -> Any: return any_unpickle(resp.exit.result.output) -class TestCoroutine(unittest.TestCase): - def setUp(self): - clear_functions() - - self.app = fastapi.FastAPI() - - @self.app.get("/") - def root(): - return "OK" - - self.dispatch = create_dispatch_instance( - self.app, endpoint="https://127.0.0.1:9999" - ) - self.http_client = TestClient(self.app) - self.client = create_endpoint_client(self.app) - - def execute( - self, func: Function, input=None, state=None, calls=None - ) -> function_pb.RunResponse: - """Test helper to invoke coroutines on the local server.""" - req = function_pb.RunRequest(function=func.name) - - if input is not None: - input_bytes = pickle.dumps(input) - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=input_bytes)) - req.input.CopyFrom(input_any) - if state is not None: - req.poll_result.coroutine_state = state - if calls is not None: - for c in calls: - req.poll_result.results.append(c) - - resp = self.client.run(req) - self.assertIsInstance(resp, function_pb.RunResponse) - return resp - - def call(self, func: Function, *args, **kwargs) -> function_pb.RunResponse: - return self.execute(func, input=Arguments(args, kwargs)) - - def proto_call(self, call: call_pb.Call) -> call_pb.CallResult: - req = function_pb.RunRequest( - function=call.function, - input=call.input, - ) - resp = self.client.run(req) - self.assertIsInstance(resp, function_pb.RunResponse) - - # Assert the response is terminal. Good enough until the test client can - # orchestrate coroutines. - self.assertTrue(len(resp.poll.coroutine_state) == 0) - - resp.exit.result.correlation_id = call.correlation_id - return resp.exit.result - - def test_no_input(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value("Hello World!") - - resp = self.execute(my_function) - - out = response_output(resp) - self.assertEqual(out, "Hello World!") - - def test_missing_coroutine(self): - req = function_pb.RunRequest( - function="does-not-exist", - ) - - with self.assertRaises(httpx.HTTPStatusError) as cm: - self.client.run(req) - self.assertEqual(cm.exception.response.status_code, 404) - - def test_string_input(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value(f"You sent '{input.input}'") - - resp = self.execute(my_function, input="cool stuff") - out = response_output(resp) - self.assertEqual(out, "You sent 'cool stuff'") - - def test_error_on_access_state_in_first_call(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - try: - print(input.coroutine_state) - except ValueError: - return Output.error( - Error.from_exception( - ValueError("This input is for a first function call") - ) - ) - return Output.value("not reached") - - resp = self.execute(my_function, input="cool stuff") - self.assertEqual("ValueError", resp.exit.result.error.type) - self.assertEqual( - "This input is for a first function call", resp.exit.result.error.message - ) - - def test_error_on_access_input_in_second_call(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - if input.is_first_call: - return Output.poll(coroutine_state=b"42") - try: - print(input.input) - except ValueError: - return Output.error( - Error.from_exception( - ValueError("This input is for a resumed coroutine") - ) - ) - return Output.value("not reached") - - resp = self.execute(my_function, input="cool stuff") - self.assertEqual(b"42", resp.poll.coroutine_state) - - resp = self.execute(my_function, state=resp.poll.coroutine_state) - self.assertEqual("ValueError", resp.exit.result.error.type) - self.assertEqual( - "This input is for a resumed coroutine", resp.exit.result.error.message - ) - - def test_duplicate_coro(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value("Do one thing") - - with self.assertRaises(ValueError): - - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value("Do something else") - - def test_two_simple_coroutines(self): - @self.dispatch.primitive_function - async def echoroutine(input: Input) -> Output: - return Output.value(f"Echo: '{input.input}'") - - @self.dispatch.primitive_function - async def len_coroutine(input: Input) -> Output: - return Output.value(f"Length: {len(input.input)}") - - data = "cool stuff" - resp = self.execute(echoroutine, input=data) - out = response_output(resp) - self.assertEqual(out, "Echo: 'cool stuff'") - - resp = self.execute(len_coroutine, input=data) - out = response_output(resp) - self.assertEqual(out, "Length: 10") - - def test_coroutine_with_state(self): - @self.dispatch.primitive_function - async def coroutine3(input: Input) -> Output: - if input.is_first_call: - counter = input.input - else: - (counter,) = struct.unpack("@i", input.coroutine_state) - counter -= 1 - if counter <= 0: - return Output.value("done") - coroutine_state = struct.pack("@i", counter) - return Output.poll(coroutine_state=coroutine_state) - - # first call - resp = self.execute(coroutine3, input=4) - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) - - # resume, state = 3 - resp = self.execute(coroutine3, state=state) - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) - - # resume, state = 2 - resp = self.execute(coroutine3, state=state) - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) - - # resume, state = 1 - resp = self.execute(coroutine3, state=state) - state = resp.poll.coroutine_state - self.assertTrue(len(state) == 0) - out = response_output(resp) - self.assertEqual(out, "done") - - def test_coroutine_poll(self): - @self.dispatch.primitive_function - async def coro_compute_len(input: Input) -> Output: - return Output.value(len(input.input)) - - @self.dispatch.primitive_function - async def coroutine_main(input: Input) -> Output: - if input.is_first_call: - text: str = input.input - return Output.poll( - coroutine_state=text.encode(), - calls=[coro_compute_len._build_primitive_call(text)], - ) - text = input.coroutine_state.decode() - length = input.call_results[0].output - return Output.value(f"length={length} text='{text}'") - - resp = self.execute(coroutine_main, input="cool stuff") - - # main saved some state - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) - # main asks for 1 call to compute_len - self.assertEqual(len(resp.poll.calls), 1) - call = resp.poll.calls[0] - self.assertEqual(call.function, coro_compute_len.name) - self.assertEqual(any_unpickle(call.input), "cool stuff") - - # make the requested compute_len - resp2 = self.proto_call(call) - # check the result is the terminal expected response - len_resp = any_unpickle(resp2.output) - self.assertEqual(10, len_resp) - - # resume main with the result - resp = self.execute(coroutine_main, state=state, calls=[resp2]) - # validate the final result - self.assertTrue(len(resp.poll.coroutine_state) == 0) - out = response_output(resp) - self.assertEqual("length=10 text='cool stuff'", out) - - def test_coroutine_poll_error(self): - @self.dispatch.primitive_function - async def coro_compute_len(input: Input) -> Output: - return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead")) - - @self.dispatch.primitive_function - async def coroutine_main(input: Input) -> Output: - if input.is_first_call: - text: str = input.input - return Output.poll( - coroutine_state=text.encode(), - calls=[coro_compute_len._build_primitive_call(text)], - ) - error = input.call_results[0].error - if error is not None: - return Output.value(f"msg={error.message} type='{error.type}'") - else: - raise RuntimeError(f"unexpected call results: {input.call_results}") - - resp = self.execute(coroutine_main, input="cool stuff") - - # main saved some state - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) - # main asks for 1 call to compute_len - self.assertEqual(len(resp.poll.calls), 1) - call = resp.poll.calls[0] - self.assertEqual(call.function, coro_compute_len.name) - self.assertEqual(any_unpickle(call.input), "cool stuff") - - # make the requested compute_len - resp2 = self.proto_call(call) - - # resume main with the result - resp = self.execute(coroutine_main, state=state, calls=[resp2]) - # validate the final result - self.assertTrue(len(resp.poll.coroutine_state) == 0) - out = response_output(resp) - self.assertEqual(out, "msg=Dead type='type'") - - def test_coroutine_error(self): - @self.dispatch.primitive_function - async def mycoro(input: Input) -> Output: - return Output.error(Error(Status.PERMANENT_ERROR, "sometype", "dead")) - - resp = self.execute(mycoro) - self.assertEqual("sometype", resp.exit.result.error.type) - self.assertEqual("dead", resp.exit.result.error.message) - - def test_coroutine_expected_exception(self): - @self.dispatch.primitive_function - async def mycoro(input: Input) -> Output: - try: - 1 / 0 - except ZeroDivisionError as e: - return Output.error(Error.from_exception(e)) - self.fail("should not reach here") - - resp = self.execute(mycoro) - self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) - self.assertEqual("division by zero", resp.exit.result.error.message) - self.assertEqual(Status.PERMANENT_ERROR, resp.status) - - def test_coroutine_unexpected_exception(self): - @self.dispatch.function - def mycoro(): - 1 / 0 - self.fail("should not reach here") - - resp = self.call(mycoro) - self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) - self.assertEqual("division by zero", resp.exit.result.error.message) - self.assertEqual(Status.PERMANENT_ERROR, resp.status) - - def test_specific_status(self): - @self.dispatch.primitive_function - async def mycoro(input: Input) -> Output: - return Output.error(Error(Status.THROTTLED, "foo", "bar")) - - resp = self.execute(mycoro) - self.assertEqual("foo", resp.exit.result.error.type) - self.assertEqual("bar", resp.exit.result.error.message) - self.assertEqual(Status.THROTTLED, resp.status) - - def test_tailcall(self): - @self.dispatch.function - def other_coroutine(value: Any) -> str: - return f"Hello {value}" - - @self.dispatch.primitive_function - async def mycoro(input: Input) -> Output: - return Output.tail_call(other_coroutine._build_primitive_call(42)) - - resp = self.call(mycoro) - self.assertEqual(other_coroutine.name, resp.exit.tail_call.function) - self.assertEqual(42, any_unpickle(resp.exit.tail_call.input)) - - def test_library_error_categorization(self): - @self.dispatch.function - def get(path: str) -> httpx.Response: - http_response = self.http_client.get(path) - http_response.raise_for_status() - return http_response - - resp = self.call(get, "/") - self.assertEqual(Status.OK, Status(resp.status)) - http_response = any_unpickle(resp.exit.result.output) - self.assertEqual("application/json", http_response.headers["content-type"]) - self.assertEqual('"OK"', http_response.text) - - resp = self.call(get, "/missing") - self.assertEqual(Status.NOT_FOUND, Status(resp.status)) - - def test_library_output_categorization(self): - @self.dispatch.function - def get(path: str) -> httpx.Response: - http_response = self.http_client.get(path) - http_response.status_code = 429 - return http_response - - resp = self.call(get, "/") - self.assertEqual(Status.THROTTLED, Status(resp.status)) - http_response = any_unpickle(resp.exit.result.output) - self.assertEqual("application/json", http_response.headers["content-type"]) - self.assertEqual('"OK"', http_response.text) +# class TestCoroutine(unittest.TestCase): +# def setUp(self): +# clear_functions() + +# self.app = fastapi.FastAPI() + +# @self.app.get("/") +# def root(): +# return "OK" + +# self.dispatch = create_dispatch_instance( +# self.app, endpoint="https://127.0.0.1:9999" +# ) +# self.http_client = TestClient(self.app) +# self.client = create_endpoint_client(self.app) + +# def tearDown(self): +# self.dispatch.registry.close() + +# def execute( +# self, func: Function, input=None, state=None, calls=None +# ) -> function_pb.RunResponse: +# """Test helper to invoke coroutines on the local server.""" +# req = function_pb.RunRequest(function=func.name) + +# if input is not None: +# input_bytes = pickle.dumps(input) +# input_any = google.protobuf.any_pb2.Any() +# input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=input_bytes)) +# req.input.CopyFrom(input_any) +# if state is not None: +# req.poll_result.coroutine_state = state +# if calls is not None: +# for c in calls: +# req.poll_result.results.append(c) + +# resp = self.client.run(req) +# self.assertIsInstance(resp, function_pb.RunResponse) +# return resp + +# def call(self, func: Function, *args, **kwargs) -> function_pb.RunResponse: +# return self.execute(func, input=Arguments(args, kwargs)) + +# def proto_call(self, call: call_pb.Call) -> call_pb.CallResult: +# req = function_pb.RunRequest( +# function=call.function, +# input=call.input, +# ) +# resp = self.client.run(req) +# self.assertIsInstance(resp, function_pb.RunResponse) + +# # Assert the response is terminal. Good enough until the test client can +# # orchestrate coroutines. +# self.assertTrue(len(resp.poll.coroutine_state) == 0) + +# resp.exit.result.correlation_id = call.correlation_id +# return resp.exit.result + +# def test_no_input(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# return Output.value("Hello World!") + +# resp = self.execute(my_function) + +# out = response_output(resp) +# self.assertEqual(out, "Hello World!") + +# def test_missing_coroutine(self): +# req = function_pb.RunRequest( +# function="does-not-exist", +# ) + +# with self.assertRaises(httpx.HTTPStatusError) as cm: +# self.client.run(req) +# self.assertEqual(cm.exception.response.status_code, 404) + +# def test_string_input(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# return Output.value(f"You sent '{input.input}'") + +# resp = self.execute(my_function, input="cool stuff") +# out = response_output(resp) +# self.assertEqual(out, "You sent 'cool stuff'") + +# def test_error_on_access_state_in_first_call(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# try: +# print(input.coroutine_state) +# except ValueError: +# return Output.error( +# Error.from_exception( +# ValueError("This input is for a first function call") +# ) +# ) +# return Output.value("not reached") + +# resp = self.execute(my_function, input="cool stuff") +# self.assertEqual("ValueError", resp.exit.result.error.type) +# self.assertEqual( +# "This input is for a first function call", resp.exit.result.error.message +# ) + +# def test_error_on_access_input_in_second_call(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# if input.is_first_call: +# return Output.poll(coroutine_state=b"42") +# try: +# print(input.input) +# except ValueError: +# return Output.error( +# Error.from_exception( +# ValueError("This input is for a resumed coroutine") +# ) +# ) +# return Output.value("not reached") + +# resp = self.execute(my_function, input="cool stuff") +# self.assertEqual(b"42", resp.poll.coroutine_state) + +# resp = self.execute(my_function, state=resp.poll.coroutine_state) +# self.assertEqual("ValueError", resp.exit.result.error.type) +# self.assertEqual( +# "This input is for a resumed coroutine", resp.exit.result.error.message +# ) + +# def test_duplicate_coro(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# return Output.value("Do one thing") + +# with self.assertRaises(ValueError): + +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# return Output.value("Do something else") + +# def test_two_simple_coroutines(self): +# @self.dispatch.primitive_function +# async def echoroutine(input: Input) -> Output: +# return Output.value(f"Echo: '{input.input}'") + +# @self.dispatch.primitive_function +# async def len_coroutine(input: Input) -> Output: +# return Output.value(f"Length: {len(input.input)}") + +# data = "cool stuff" +# resp = self.execute(echoroutine, input=data) +# out = response_output(resp) +# self.assertEqual(out, "Echo: 'cool stuff'") + +# resp = self.execute(len_coroutine, input=data) +# out = response_output(resp) +# self.assertEqual(out, "Length: 10") + +# def test_coroutine_with_state(self): +# @self.dispatch.primitive_function +# async def coroutine3(input: Input) -> Output: +# if input.is_first_call: +# counter = input.input +# else: +# (counter,) = struct.unpack("@i", input.coroutine_state) +# counter -= 1 +# if counter <= 0: +# return Output.value("done") +# coroutine_state = struct.pack("@i", counter) +# return Output.poll(coroutine_state=coroutine_state) + +# # first call +# resp = self.execute(coroutine3, input=4) +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) + +# # resume, state = 3 +# resp = self.execute(coroutine3, state=state) +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) + +# # resume, state = 2 +# resp = self.execute(coroutine3, state=state) +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) + +# # resume, state = 1 +# resp = self.execute(coroutine3, state=state) +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) == 0) +# out = response_output(resp) +# self.assertEqual(out, "done") + +# def test_coroutine_poll(self): +# @self.dispatch.primitive_function +# async def coro_compute_len(input: Input) -> Output: +# return Output.value(len(input.input)) + +# @self.dispatch.primitive_function +# async def coroutine_main(input: Input) -> Output: +# if input.is_first_call: +# text: str = input.input +# return Output.poll( +# coroutine_state=text.encode(), +# calls=[coro_compute_len._build_primitive_call(text)], +# ) +# text = input.coroutine_state.decode() +# length = input.call_results[0].output +# return Output.value(f"length={length} text='{text}'") + +# resp = self.execute(coroutine_main, input="cool stuff") + +# # main saved some state +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) +# # main asks for 1 call to compute_len +# self.assertEqual(len(resp.poll.calls), 1) +# call = resp.poll.calls[0] +# self.assertEqual(call.function, coro_compute_len.name) +# self.assertEqual(any_unpickle(call.input), "cool stuff") + +# # make the requested compute_len +# resp2 = self.proto_call(call) +# # check the result is the terminal expected response +# len_resp = any_unpickle(resp2.output) +# self.assertEqual(10, len_resp) + +# # resume main with the result +# resp = self.execute(coroutine_main, state=state, calls=[resp2]) +# # validate the final result +# self.assertTrue(len(resp.poll.coroutine_state) == 0) +# out = response_output(resp) +# self.assertEqual("length=10 text='cool stuff'", out) + +# def test_coroutine_poll_error(self): +# @self.dispatch.primitive_function +# async def coro_compute_len(input: Input) -> Output: +# return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead")) + +# @self.dispatch.primitive_function +# async def coroutine_main(input: Input) -> Output: +# if input.is_first_call: +# text: str = input.input +# return Output.poll( +# coroutine_state=text.encode(), +# calls=[coro_compute_len._build_primitive_call(text)], +# ) +# error = input.call_results[0].error +# if error is not None: +# return Output.value(f"msg={error.message} type='{error.type}'") +# else: +# raise RuntimeError(f"unexpected call results: {input.call_results}") + +# resp = self.execute(coroutine_main, input="cool stuff") + +# # main saved some state +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) +# # main asks for 1 call to compute_len +# self.assertEqual(len(resp.poll.calls), 1) +# call = resp.poll.calls[0] +# self.assertEqual(call.function, coro_compute_len.name) +# self.assertEqual(any_unpickle(call.input), "cool stuff") + +# # make the requested compute_len +# resp2 = self.proto_call(call) + +# # resume main with the result +# resp = self.execute(coroutine_main, state=state, calls=[resp2]) +# # validate the final result +# self.assertTrue(len(resp.poll.coroutine_state) == 0) +# out = response_output(resp) +# self.assertEqual(out, "msg=Dead type='type'") + +# def test_coroutine_error(self): +# @self.dispatch.primitive_function +# async def mycoro(input: Input) -> Output: +# return Output.error(Error(Status.PERMANENT_ERROR, "sometype", "dead")) + +# resp = self.execute(mycoro) +# self.assertEqual("sometype", resp.exit.result.error.type) +# self.assertEqual("dead", resp.exit.result.error.message) + +# def test_coroutine_expected_exception(self): +# @self.dispatch.primitive_function +# async def mycoro(input: Input) -> Output: +# try: +# 1 / 0 +# except ZeroDivisionError as e: +# return Output.error(Error.from_exception(e)) +# self.fail("should not reach here") + +# resp = self.execute(mycoro) +# self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) +# self.assertEqual("division by zero", resp.exit.result.error.message) +# self.assertEqual(Status.PERMANENT_ERROR, resp.status) + +# def test_coroutine_unexpected_exception(self): +# @self.dispatch.function +# def mycoro(): +# 1 / 0 +# self.fail("should not reach here") + +# resp = self.call(mycoro) +# self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) +# self.assertEqual("division by zero", resp.exit.result.error.message) +# self.assertEqual(Status.PERMANENT_ERROR, resp.status) + +# def test_specific_status(self): +# @self.dispatch.primitive_function +# async def mycoro(input: Input) -> Output: +# return Output.error(Error(Status.THROTTLED, "foo", "bar")) + +# resp = self.execute(mycoro) +# self.assertEqual("foo", resp.exit.result.error.type) +# self.assertEqual("bar", resp.exit.result.error.message) +# self.assertEqual(Status.THROTTLED, resp.status) + +# def test_tailcall(self): +# @self.dispatch.function +# def other_coroutine(value: Any) -> str: +# return f"Hello {value}" + +# @self.dispatch.primitive_function +# async def mycoro(input: Input) -> Output: +# return Output.tail_call(other_coroutine._build_primitive_call(42)) + +# resp = self.call(mycoro) +# self.assertEqual(other_coroutine.name, resp.exit.tail_call.function) +# self.assertEqual(42, any_unpickle(resp.exit.tail_call.input)) + +# def test_library_error_categorization(self): +# @self.dispatch.function +# def get(path: str) -> httpx.Response: +# http_response = self.http_client.get(path) +# http_response.raise_for_status() +# return http_response + +# resp = self.call(get, "/") +# self.assertEqual(Status.OK, Status(resp.status)) +# http_response = any_unpickle(resp.exit.result.output) +# self.assertEqual("application/json", http_response.headers["content-type"]) +# self.assertEqual('"OK"', http_response.text) + +# resp = self.call(get, "/missing") +# self.assertEqual(Status.NOT_FOUND, Status(resp.status)) + +# def test_library_output_categorization(self): +# @self.dispatch.function +# def get(path: str) -> httpx.Response: +# http_response = self.http_client.get(path) +# http_response.status_code = 429 +# return http_response + +# resp = self.call(get, "/") +# self.assertEqual(Status.THROTTLED, Status(resp.status)) +# http_response = any_unpickle(resp.exit.result.output) +# self.assertEqual("application/json", http_response.headers["content-type"]) +# self.assertEqual('"OK"', http_response.text) diff --git a/tests/test_flask.py b/tests/test_flask.py index 3a1a500..494199b 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -10,18 +10,15 @@ class TestFlask(dispatch.test.TestCase): - def dispatch_test_init(self, api_key: str, api_url: str) -> Registry: + def dispatch_test_init(self, reg: Registry) -> str: host = "127.0.0.1" port = 56789 app = Flask("test") - reg = Dispatch( - app, endpoint=f"http://{host}:{port}", api_key=api_key, api_url=api_url - ) + dispatch = Dispatch(app, registry=reg) self.wsgi = make_server(host, port, app) - # self.flask_registry = reg - return reg + return f"http://{host}:{port}" def dispatch_test_run(self): self.wsgi.serve_forever(poll_interval=0.05) @@ -29,4 +26,3 @@ def dispatch_test_run(self): def dispatch_test_stop(self): self.wsgi.shutdown() self.wsgi.server_close() - # self.flask_registry.close() diff --git a/tests/test_http.py b/tests/test_http.py index cc25872..304617b 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -6,12 +6,12 @@ import dispatch.test from dispatch.asyncio import Runner from dispatch.function import Registry -from dispatch.http import Dispatch, Server +from dispatch.http import Dispatch, FunctionService, Server class TestHTTP(dispatch.test.TestCase): - def dispatch_test_init(self, api_key: str, api_url: str) -> Registry: + def dispatch_test_init(self, reg: Registry) -> str: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(("127.0.0.1", 0)) @@ -19,19 +19,13 @@ def dispatch_test_init(self, api_key: str, api_url: str) -> Registry: (host, port) = sock.getsockname() - reg = Registry( - endpoint=f"http://{host}:{port}", - api_key=api_key, - api_url=api_url, - ) - self.httpserver = HTTPServer( server_address=(host, port), RequestHandlerClass=Dispatch(reg), bind_and_activate=False, ) self.httpserver.socket = sock - return reg + return f"http://{host}:{port}" def dispatch_test_run(self): self.httpserver.serve_forever(poll_interval=0.05) @@ -39,27 +33,21 @@ def dispatch_test_run(self): def dispatch_test_stop(self): self.httpserver.shutdown() self.httpserver.server_close() + self.httpserver.socket.close() class TestAIOHTTP(dispatch.test.TestCase): - def dispatch_test_init(self, api_key: str, api_url: str) -> Registry: + def dispatch_test_init(self, reg: Registry) -> str: host = "127.0.0.1" port = 0 - reg = Registry( - endpoint=f"http://{host}:{port}", - api_key=api_key, - api_url=api_url, - ) - self.aiowait = asyncio.Event() self.aioloop = Runner() self.aiohttp = Server(host, port, Dispatch(reg)) self.aioloop.run(self.aiohttp.start()) - reg.endpoint = f"http://{self.aiohttp.host}:{self.aiohttp.port}" - return reg + return f"http://{self.aiohttp.host}:{self.aiohttp.port}" def dispatch_test_run(self): self.aioloop.run(self.aiowait.wait())