Skip to content

Commit

Permalink
refactor: use composition, default registry, function service
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
  • Loading branch information
achille-roussel committed Jun 14, 2024
1 parent d855b3d commit f1337e2
Show file tree
Hide file tree
Showing 16 changed files with 739 additions and 656 deletions.
2 changes: 1 addition & 1 deletion examples/auto_retry/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_app(self):
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))
dispatch.registry.client = Client(api_url=dispatch_server.url)

response = app_client.get("/")
self.assertEqual(response.status_code, 200)
Expand Down
2 changes: 1 addition & 1 deletion examples/getting_started/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_app(self):
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))
dispatch.registry.client = Client(api_url=dispatch_server.url)

response = app_client.get("/")
self.assertEqual(response.status_code, 200)
Expand Down
2 changes: 1 addition & 1 deletion examples/github_stats/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_app(self):
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))
dispatch.registry.client = Client(api_url=dispatch_server.url)

response = app_client.get("/")
self.assertEqual(response.status_code, 200)
Expand Down
19 changes: 9 additions & 10 deletions src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@

import dispatch.integrations
from dispatch.coroutine import all, any, call, gather, race
from dispatch.function import Batch, Client, ClientError, Function, Registry, Reset
from dispatch.function import (
Batch,
Client,
ClientError,
Function,
Registry,
Reset,
default_registry,
)
from dispatch.http import Dispatch
from dispatch.id import DispatchID
from dispatch.proto import Call, Error, Input, Output
Expand Down Expand Up @@ -43,15 +51,6 @@
P = ParamSpec("P")
T = TypeVar("T")

_registry: Optional[Registry] = None


def default_registry():
global _registry
if not _registry:
_registry = Registry()
return _registry


@overload
def function(func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ...
Expand Down
10 changes: 9 additions & 1 deletion src/dispatch/experimental/durable/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
)

from . import frame as ext
from .registry import RegisteredFunction, lookup_function, register_function
from .registry import (
RegisteredFunction,
lookup_function,
register_function,
unregister_function,
)

TRACE = os.getenv("DISPATCH_TRACE", False)

Expand Down Expand Up @@ -58,6 +63,9 @@ def __call__(self, *args, **kwargs):
def __repr__(self) -> str:
return f"DurableFunction({self.__qualname__})"

def unregister(self):
unregister_function(self.registered_fn.key)


def durable(fn: Callable) -> Callable:
"""Returns a "durable" function that creates serializable
Expand Down
12 changes: 12 additions & 0 deletions src/dispatch/experimental/durable/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ def lookup_function(key: str) -> RegisteredFunction:
return _REGISTRY[key]


def unregister_function(key: str):
"""Unregister a function by key.
Args:
key: Unique identifier for the function.
Raises:
KeyError: A function has not been registered with this key.
"""
del _REGISTRY[key]


def clear_functions():
"""Clear functions clears the registry."""
_REGISTRY.clear()
33 changes: 11 additions & 22 deletions src/dispatch/experimental/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,32 @@ def handler(event, context):
dispatch.handle(event, context, entrypoint="entrypoint")
"""

import asyncio
import base64
import json
import logging
from typing import Optional
from typing import Optional, Union

from awslambdaric.lambda_context import LambdaContext

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.http import FunctionService
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.status import Status

logger = logging.getLogger(__name__)


class Dispatch(Registry):
class Dispatch(FunctionService):
def __init__(
self,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
registry: Optional[Registry] = None,
):
"""Initializes a Dispatch Lambda handler.
Args:
api_key: Dispatch API key to use for authentication. Uses the value
of the DISPATCH_API_KEY environment variable by default.
api_url: The URL of the Dispatch API to use. Uses the value of the
DISPATCH_API_URL environment variable if set, otherwise
defaults to the public Dispatch API (DEFAULT_API_URL).
"""

"""Initializes a Dispatch Lambda handler."""
# We use a fake endpoint to initialize the base class. The actual endpoint (the Lambda ARN)
# is only known when the handler is invoked.
super().__init__(endpoint="http://lambda", api_key=api_key, api_url=api_url)
super().__init__(registry)

def handle(
self, event: str, context: LambdaContext, entrypoint: Optional[str] = None
Expand All @@ -63,7 +52,8 @@ def handle(
# We override the endpoint of all registered functions before any execution.
if context.invoked_function_arn:
self.endpoint = context.invoked_function_arn
self.override_endpoint(self.endpoint)
# TODO: this might mutate the default registry, we should figure out a better way.
self.registry.endpoint = self.endpoint

if not event:
raise ValueError("event is required")
Expand All @@ -87,14 +77,13 @@ def handle(
)

try:
func = self.functions[req.function]
func = self.registry.functions[req.function]
except KeyError:
raise ValueError(f"function {req.function} not found")

input = Input(req)
try:
with Runner() as runner:
output = runner.run(func._primitive_call(input))
output = asyncio.run(func._primitive_call(input))
except Exception:
logger.error("function '%s' fatal error", req.function, exc_info=True)
raise # FIXME
Expand Down
101 changes: 40 additions & 61 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,20 @@ def read_root():
import fastapi.responses

from dispatch.function import Registry
from dispatch.http import (
FunctionServiceError,
function_service_run,
validate_content_length,
)
from dispatch.http import FunctionService, FunctionServiceError, validate_content_length
from dispatch.signature import Ed25519PublicKey, parse_verification_key

logger = logging.getLogger(__name__)


class Dispatch(Registry):
class Dispatch(FunctionService):
"""A Dispatch instance, powered by FastAPI."""

def __init__(
self,
app: fastapi.FastAPI,
endpoint: Optional[str] = None,
registry: Optional[Registry] = None,
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
):
"""Initialize a Dispatch endpoint, and integrate it into a FastAPI app.
Expand All @@ -53,9 +47,8 @@ def __init__(
Args:
app: The FastAPI app to configure.
endpoint: Full URL of the application the Dispatch instance will
be running on. Uses the value of the DISPATCH_ENDPOINT_URL
environment variable by default.
registry: A registry of functions to expose. If omitted, the default
registry is used.
verification_key: Key to use when verifying signed requests. Uses
the value of the DISPATCH_VERIFICATION_KEY environment variable
Expand All @@ -64,63 +57,49 @@ def __init__(
If not set, request signature verification is disabled (a warning
will be logged by the constructor).
api_key: Dispatch API key to use for authentication. Uses the value of
the DISPATCH_API_KEY environment variable by default.
api_url: The URL of the Dispatch API to use. Uses the value of the
DISPATCH_API_URL environment variable if set, otherwise
defaults to the public Dispatch API (DEFAULT_API_URL).
Raises:
ValueError: If any of the required arguments are missing.
"""
if not app:
raise ValueError(
"missing FastAPI app as first argument of the Dispatch constructor"
)
super().__init__(endpoint, api_key=api_key, api_url=api_url)
verification_key = parse_verification_key(verification_key, endpoint=endpoint)
function_service = _new_app(self, verification_key)
app.mount("/dispatch.sdk.v1.FunctionService", function_service)


def _new_app(function_registry: Registry, verification_key: Optional[Ed25519PublicKey]):
app = fastapi.FastAPI()

@app.exception_handler(FunctionServiceError)
async def on_error(request: fastapi.Request, exc: FunctionServiceError):
# https://connectrpc.com/docs/protocol/#error-end-stream
return fastapi.responses.JSONResponse(
status_code=exc.status, content={"code": exc.code, "message": exc.message}
)
super().__init__(registry, verification_key)
function_service = fastapi.FastAPI()

@function_service.exception_handler(FunctionServiceError)
async def on_error(request: fastapi.Request, exc: FunctionServiceError):
# https://connectrpc.com/docs/protocol/#error-end-stream
return fastapi.responses.JSONResponse(
status_code=exc.status,
content={"code": exc.code, "message": exc.message},
)

@app.post(
# The endpoint for execution is hardcoded at the moment. If the service
# gains more endpoints, this should be turned into a dynamic dispatch
# like the official gRPC server does.
"/Run",
)
async def execute(request: fastapi.Request):
valid, reason = validate_content_length(
int(request.headers.get("content-length", 0))
)
if not valid:
raise FunctionServiceError(400, "invalid_argument", reason)

# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
data: bytes = await request.body()

content = await function_service_run(
str(request.url),
request.method,
request.headers,
data,
function_registry,
verification_key,
@function_service.post(
# The endpoint for execution is hardcoded at the moment. If the service
# gains more endpoints, this should be turned into a dynamic dispatch
# like the official gRPC server does.
"/Run",
)
async def execute(request: fastapi.Request):
valid, reason = validate_content_length(
int(request.headers.get("content-length", 0))
)
if not valid:
raise FunctionServiceError(400, "invalid_argument", reason)

# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
data: bytes = await request.body()

content = await self.run(
str(request.url),
request.method,
request.headers,
await request.body(),
)

return fastapi.Response(content=content, media_type="application/proto")
return fastapi.Response(content=content, media_type="application/proto")

return app
app.mount("/dispatch.sdk.v1.FunctionService", function_service)
Loading

0 comments on commit f1337e2

Please sign in to comment.