diff --git a/examples/auto_retry/test_app.py b/examples/auto_retry/test_app.py index ba3a440..fca4dc2 100644 --- a/examples/auto_retry/test_app.py +++ b/examples/auto_retry/test_app.py @@ -6,11 +6,10 @@ import unittest from unittest import mock -from fastapi.testclient import TestClient - from dispatch import Client from dispatch.sdk.v1 import status_pb2 as status_pb from dispatch.test import DispatchServer, DispatchService, EndpointClient +from dispatch.test.fastapi import http_client class TestAutoRetry(unittest.TestCase): @@ -25,14 +24,14 @@ def test_app(self): from .app import app, dispatch # Setup a fake Dispatch server. - endpoint_client = EndpointClient(TestClient(app)) + app_client = http_client(app) + endpoint_client = EndpointClient(app_client) dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: # Use it when dispatching function calls. dispatch.set_client(Client(api_url=dispatch_server.url)) - http_client = TestClient(app) - response = http_client.get("/") + response = app_client.get("/") self.assertEqual(response.status_code, 200) dispatch_service.dispatch_calls() diff --git a/examples/getting_started/test_app.py b/examples/getting_started/test_app.py index 39e04ef..16a7f8c 100644 --- a/examples/getting_started/test_app.py +++ b/examples/getting_started/test_app.py @@ -6,10 +6,9 @@ import unittest from unittest import mock -from fastapi.testclient import TestClient - from dispatch import Client from dispatch.test import DispatchServer, DispatchService, EndpointClient +from dispatch.test.fastapi import http_client class TestGettingStarted(unittest.TestCase): @@ -24,14 +23,14 @@ def test_app(self): from .app import app, dispatch # Setup a fake Dispatch server. - endpoint_client = EndpointClient(TestClient(app)) + app_client = http_client(app) + endpoint_client = EndpointClient(app_client) dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: # Use it when dispatching function calls. dispatch.set_client(Client(api_url=dispatch_server.url)) - http_client = TestClient(app) - response = http_client.get("/") + response = app_client.get("/") self.assertEqual(response.status_code, 200) dispatch_service.dispatch_calls() diff --git a/examples/github_stats/test_app.py b/examples/github_stats/test_app.py index 08b7b24..4844016 100644 --- a/examples/github_stats/test_app.py +++ b/examples/github_stats/test_app.py @@ -6,10 +6,9 @@ import unittest from unittest import mock -from fastapi.testclient import TestClient - from dispatch.function import Client from dispatch.test import DispatchServer, DispatchService, EndpointClient +from dispatch.test.fastapi import http_client class TestGithubStats(unittest.TestCase): @@ -24,14 +23,14 @@ def test_app(self): from .app import app, dispatch # Setup a fake Dispatch server. - endpoint_client = EndpointClient(TestClient(app)) + app_client = http_client(app) + endpoint_client = EndpointClient(app_client) dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: # Use it when dispatching function calls. dispatch.set_client(Client(api_url=dispatch_server.url)) - http_client = TestClient(app) - response = http_client.get("/") + response = app_client.get("/") self.assertEqual(response.status_code, 200) while dispatch_service.queue: diff --git a/pyproject.toml b/pyproject.toml index f55def8..8dc15e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,15 +15,16 @@ dependencies = [ "grpc-stubs >= 1.53.0.5", "http-message-signatures >= 0.4.4", "tblib >= 3.0.0", - "httpx >= 0.27.0", "typing_extensions >= 4.10" ] [project.optional-dependencies] fastapi = ["fastapi", "httpx"] +flask = ["flask"] lambda = ["awslambdaric"] dev = [ + "httpx >= 0.27.0", "black >= 24.1.0", "isort >= 5.13.2", "mypy >= 1.10.0", diff --git a/src/dispatch/test/client.py b/src/dispatch/test/client.py index 04d2fa9..b9f56ab 100644 --- a/src/dispatch/test/client.py +++ b/src/dispatch/test/client.py @@ -1,8 +1,7 @@ from datetime import datetime -from typing import Optional +from typing import Mapping, Optional, Protocol, Union import grpc -import httpx from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.sdk.v1 import function_pb2_grpc as function_grpc @@ -12,6 +11,7 @@ Request, sign_request, ) +from dispatch.test.http import HttpClient class EndpointClient: @@ -24,7 +24,7 @@ class EndpointClient: """ def __init__( - self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None + self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None ): """Initialize the client. @@ -32,7 +32,7 @@ def __init__( http_client: Client to use to make HTTP requests. signing_key: Optional Ed25519 private key to use to sign requests. """ - channel = _HttpxGrpcChannel(http_client, signing_key=signing_key) + channel = _HttpGrpcChannel(http_client, signing_key=signing_key) self._stub = function_grpc.FunctionServiceStub(channel) def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse: @@ -46,16 +46,10 @@ def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse: """ return self._stub.Run(request) - @classmethod - def from_url(cls, url: str, signing_key: Optional[Ed25519PrivateKey] = None): - """Returns an EndpointClient for a Dispatch endpoint URL.""" - http_client = httpx.Client(base_url=url) - return EndpointClient(http_client, signing_key) - -class _HttpxGrpcChannel(grpc.Channel): +class _HttpGrpcChannel(grpc.Channel): def __init__( - self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None + self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None ): self.http_client = http_client self.signing_key = signing_key @@ -120,9 +114,11 @@ def __call__( wait_for_ready=None, compression=None, ): + url = self.client.url_for(self.method) # note: method==path in gRPC parlance + request = Request( method="POST", - url=str(httpx.URL(self.client.base_url).join(self.method)), + url=url, body=self.request_serializer(request), headers=CaseInsensitiveDict({"Content-Type": "application/grpc+proto"}), ) @@ -131,10 +127,10 @@ def __call__( sign_request(request, self.signing_key, datetime.now()) response = self.client.post( - request.url, content=request.body, headers=request.headers + request.url, body=request.body, headers=request.headers ) response.raise_for_status() - return self.response_deserializer(response.content) + return self.response_deserializer(response.body) def with_call( self, diff --git a/src/dispatch/test/fastapi.py b/src/dispatch/test/fastapi.py new file mode 100644 index 0000000..381b180 --- /dev/null +++ b/src/dispatch/test/fastapi.py @@ -0,0 +1,10 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient + +import dispatch.test.httpx +from dispatch.test.client import HttpClient + + +def http_client(app: FastAPI) -> HttpClient: + """Build a client for a FastAPI app.""" + return dispatch.test.httpx.Client(TestClient(app)) diff --git a/src/dispatch/test/http.py b/src/dispatch/test/http.py new file mode 100644 index 0000000..d811a76 --- /dev/null +++ b/src/dispatch/test/http.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from typing import Mapping, Protocol + + +@dataclass +class HttpResponse(Protocol): + status_code: int + body: bytes + + def raise_for_status(self): + """Raise an exception on non-2xx responses.""" + ... + + +class HttpClient(Protocol): + """Protocol for HTTP clients.""" + + def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse: + """Make a GET request.""" + ... + + def post( + self, url: str, body: bytes, headers: Mapping[str, str] = {} + ) -> HttpResponse: + """Make a POST request.""" + ... + + def url_for(self, path: str) -> str: + """Get the fully-qualified URL for a path.""" + ... diff --git a/src/dispatch/test/httpx.py b/src/dispatch/test/httpx.py new file mode 100644 index 0000000..9d9f7c5 --- /dev/null +++ b/src/dispatch/test/httpx.py @@ -0,0 +1,39 @@ +from typing import Mapping + +import httpx + +from dispatch.test.http import HttpClient, HttpResponse + + +class Client(HttpClient): + def __init__(self, client: httpx.Client): + self.client = client + + def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse: + response = self.client.get(url, headers=headers) + return Response(response) + + def post( + self, url: str, body: bytes, headers: Mapping[str, str] = {} + ) -> HttpResponse: + response = self.client.post(url, content=body, headers=headers) + return Response(response) + + def url_for(self, path: str) -> str: + return str(httpx.URL(self.client.base_url).join(path)) + + +class Response(HttpResponse): + def __init__(self, response: httpx.Response): + self.response = response + + @property + def status_code(self): + return self.response.status_code + + @property + def body(self): + return self.response.content + + def raise_for_status(self): + self.response.raise_for_status() diff --git a/src/dispatch/test/service.py b/src/dispatch/test/service.py index 5edf397..195c4d1 100644 --- a/src/dispatch/test/service.py +++ b/src/dispatch/test/service.py @@ -8,7 +8,6 @@ from typing import Dict, List, Optional, Set, Tuple import grpc -import httpx from typing_extensions import TypeAlias import dispatch.sdk.v1.call_pb2 as call_pb @@ -325,17 +324,6 @@ def _dispatch_continuously(self): try: self.dispatch_calls() - except httpx.HTTPStatusError as e: - if e.response.status_code == 403: - logger.error( - "error dispatching function call to endpoint (403). Is the endpoint's DISPATCH_VERIFICATION_KEY correct?" - ) - else: - logger.exception(e) - except httpx.ConnectError as e: - logger.error( - "error connecting to the endpoint. Is it running and accessible from DISPATCH_ENDPOINT_URL?" - ) except Exception as e: logger.exception(e) diff --git a/tests/test_client.py b/tests/test_client.py index 09f96b5..c04945b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,6 +2,9 @@ import unittest from unittest import mock +import httpx + +import dispatch.test.httpx from dispatch import Call, Client from dispatch.proto import _any_unpickle as any_unpickle from dispatch.test import DispatchServer, DispatchService, EndpointClient @@ -9,7 +12,10 @@ class TestClient(unittest.TestCase): def setUp(self): - endpoint_client = EndpointClient.from_url("http://function-service") + http_client = dispatch.test.httpx.Client( + httpx.Client(base_url="http://function-service") + ) + endpoint_client = EndpointClient(http_client) api_key = "0000000000000000" self.dispatch_service = DispatchService(endpoint_client, api_key) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 5c2135d..dee353f 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -30,6 +30,7 @@ ) from dispatch.status import Status from dispatch.test import DispatchServer, DispatchService, EndpointClient +from dispatch.test.fastapi import http_client def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str): @@ -44,8 +45,7 @@ def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str): def create_endpoint_client( app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None ): - http_client = TestClient(app) - return EndpointClient(http_client, signing_key) + return EndpointClient(http_client(app), signing_key) class TestFastAPI(unittest.TestCase): diff --git a/tests/test_http.py b/tests/test_http.py index 21e8b09..c5f0c0f 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -14,6 +14,7 @@ import httpx from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey +import dispatch.test.httpx from dispatch.experimental.durable.registry import clear_functions from dispatch.function import Arguments, Error, Function, Input, Output, Registry from dispatch.http import Dispatch @@ -87,7 +88,8 @@ def my_function(input: Input) -> Output: f"You told me: '{input.input}' ({len(input.input)} characters)" ) - client = EndpointClient.from_url(self.endpoint) + http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint)) + client = EndpointClient(http_client) pickled = pickle.dumps("Hello World!") input_any = google.protobuf.any_pb2.Any()