diff --git a/.gitignore b/.gitignore index 8cce4a54..e1dea320 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ *.egg-info __pycache__ .proto +.coverage +.coverage-html diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d8c1e075..c692499d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,7 +12,17 @@ make dev make test ``` -## Code style +## Coverage + +``` +make coverage +``` + +In addition to displaying the summary in the terminal, this command generates an +HTML report with line-by-line coverage. `open .coverage-html/index.html` and +click around. You can refresh your browser after each `make coverage` run. + +## Style Formatting is done with `black`. Run `make fmt`. diff --git a/Makefile b/Makefile index c6beea34..1daae54d 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: install test typecheck unittest dev fmt generate clean update-proto +.PHONY: install test typecheck unittest dev fmt generate clean update-proto coverage PYTHON := python @@ -19,6 +19,11 @@ typecheck: unittest: $(PYTHON) -m unittest discover +coverage: typecheck + coverage run -m unittest discover + coverage html -d .coverage-html + coverage report + test: typecheck unittest .proto: @@ -42,5 +47,7 @@ generate: .proto/ring .proto/dispatch clean: rm -rf .proto + rm -rf .coverage + rm -rf .coverage-html find . -type f -name '*.pyc' -delete find . -type d -name '__pycache__' -exec rm -r {} \; diff --git a/pyproject.toml b/pyproject.toml index 0f86dc6d..00ec8744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,5 +19,14 @@ dev = [ "black==24.1.0", "mypy==1.8.0", "fastapi==0.109.0", - "httpx==0.26.0" + "httpx==0.26.0", + "coverage==7.4.1" ] + + +[tool.coverage.run] +omit = [ + "*_pb2_grpc.py", + "*_pb2.py", + "tests/*" +] \ No newline at end of file diff --git a/src/dispatch/coroutine.py b/src/dispatch/coroutine.py index b360e5ab..50d19883 100644 --- a/src/dispatch/coroutine.py +++ b/src/dispatch/coroutine.py @@ -133,10 +133,7 @@ def __init__(self, req: coroutine_pb2.ExecuteRequest): input_pb = google.protobuf.wrappers_pb2.BytesValue() req.input.Unpack(input_pb) input_bytes = input_pb.value - if len(input_bytes) > 0: - self._input = pickle.loads(input_bytes) - else: - self._input = None + self._input = pickle.loads(input_bytes) else: state_bytes = req.poll_response.state if len(state_bytes) > 0: diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index f0929973..59278257 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -132,28 +132,20 @@ async def execute(request: fastapi.Request): coroutine = app._coroutines.get(uri, None) if coroutine is None: # TODO: integrate with logging - print("Coroutine not found:") - print(" uri:", uri) - print("Available coroutines:") - for k in app._coroutines: - print(" ", k) - raise KeyError(f"coroutine '{uri}' not available on this system") + raise fastapi.HTTPException( + status_code=404, detail=f"Coroutine URI '{uri}' does not exist" + ) coro_input = dispatch.coroutine.Input(req) try: output = coroutine(coro_input) except Exception as ex: - # TODO: distinguish unaught exceptions from exceptions returned by + # TODO: distinguish uncaught exceptions from exceptions returned by # coroutine? err = dispatch.coroutine.Error.from_exception(ex) output = dispatch.coroutine.Output.error(err) - if not isinstance(output, dispatch.coroutine.Output): - raise ValueError( - f"coroutine output should be an instance of {dispatch.coroutine.Output}, not {type(output)}" - ) - resp = output._message resp.coroutine_uri = req.coroutine_uri resp.coroutine_version = req.coroutine_version diff --git a/tests/task_service.py b/tests/task_service.py index f83779ed..b8558b35 100644 --- a/tests/task_service.py +++ b/tests/task_service.py @@ -32,7 +32,7 @@ def _validate_authentication(self, context: grpc.ServicerContext): return context.abort( grpc.StatusCode.UNAUTHENTICATED, - f"Invalid authorization header. Expected '{expected}', got '{value!r}'", + f"Invalid authorization header. Expected '{expected}', got {value!r}", ) context.abort(grpc.StatusCode.UNAUTHENTICATED, "Missing authorization header") @@ -85,7 +85,7 @@ def __init__(self): self.thread_pool = concurrent.futures.thread.ThreadPoolExecutor() self.server = grpc.server(self.thread_pool) - port = self.server.add_insecure_port("127.0.0.1:0") + self.port = self.server.add_insecure_port("127.0.0.1:0") self.servicer = FakeRing() @@ -93,7 +93,7 @@ def __init__(self): self.server.start() self.client = Client( - api_key=_test_auth_token, api_url=f"http://127.0.0.1:{port}" + api_key=_test_auth_token, api_url=f"http://127.0.0.1:{self.port}" ) def stop(self): diff --git a/tests/test_client.py b/tests/test_client.py index 23e915b4..a27be3c8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,8 @@ +import os import unittest +from unittest import mock +import grpc from google.protobuf import wrappers_pb2, any_pb2 from dispatch import Client, TaskInput, TaskID @@ -17,6 +20,31 @@ def setUp(self): def tearDown(self): self.server.stop() + @mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "WHATEVER"}) + def test_api_key_from_env(self): + client = Client(api_url=f"http://127.0.0.1:{self.server.port}") + + with self.assertRaises(grpc._channel._InactiveRpcError) as mc: + client.create_tasks( + [TaskInput(coroutine_uri="my-cool-coroutine", input=42)] + ) + self.assertTrue("got 'Bearer WHATEVER'" in str(mc.exception)) + + def test_api_key_missing(self): + with self.assertRaises(ValueError) as mc: + client = Client() + self.assertEqual(str(mc.exception), "api_key is required") + + def test_url_bad_scheme(self): + with self.assertRaises(ValueError) as mc: + client = Client(api_url="ftp://example.com", api_key="foo") + self.assertEqual(str(mc.exception), "Invalid API scheme: 'ftp'") + + def test_can_be_constructed_on_https(self): + # Goal is to not raise an exception here. We don't have an HTTPS server + # around to actually test this. + Client(api_url="https://example.com", api_key="foo") + def test_create_one_task_pickle(self): results = self.client.create_tasks( [TaskInput(coroutine_uri="my-cool-coroutine", input=42)] @@ -53,7 +81,7 @@ def test_create_one_task_proto_any(self): proto_any = any_pb2.Any() proto_any.Pack(proto) results = self.client.create_tasks( - [TaskInput(coroutine_uri="my-cool-coroutine", input=proto)] + [TaskInput(coroutine_uri="my-cool-coroutine", input=proto_any)] ) id = results[0] created_tasks = self.servicer.created_tasks @@ -61,3 +89,13 @@ def test_create_one_task_proto_any(self): task = entry["task"] # proto any has not been modified self.assertEqual(task.input, proto_any) + + +class TestTaskID(unittest.TestCase): + def test_string(self): + t = TaskID(partition_number=1, block_id=2, record_offset=3, record_size=4) + self.assertEqual(str(t), "00000001000000020000000300000004") + + def test_repr(self): + t = TaskID(partition_number=1, block_id=2, record_offset=3, record_size=4) + self.assertEqual(repr(t), "TaskID(00000001000000020000000300000004)") diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index eaf221d6..65d04b14 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,12 +1,15 @@ import pickle import unittest from typing import Any -import dispatch.coroutine -from dispatch.coroutine import Input, Output, Error, Status -import dispatch.fastapi + +import httpx import fastapi from fastapi.testclient import TestClient import google.protobuf.wrappers_pb2 + +import dispatch.fastapi +import dispatch.coroutine +from dispatch.coroutine import Input, Output, Error, Status from ring.coroutine.v1 import coroutine_pb2 from . import executor_service @@ -52,6 +55,11 @@ def test_configure_no_public_url(self): with self.assertRaises(ValueError): dispatch.fastapi.configure(app, api_key="test", public_url="") + def test_configure_public_url_no_scheme(self): + app = fastapi.FastAPI() + with self.assertRaises(ValueError): + dispatch.fastapi.configure(app, api_key="test", public_url="127.0.0.1:9999") + def test_fastapi_simple_request(self): app = fastapi.FastAPI() dispatch.fastapi.configure( @@ -159,6 +167,16 @@ def my_cool_coroutine(input: Input) -> Output: out = response_output(resp) self.assertEqual(out, "Hello World!") + def test_missing_coroutine(self): + req = coroutine_pb2.ExecuteRequest( + coroutine_uri="does-not-exist", + coroutine_version="1", + ) + + with self.assertRaises(httpx.HTTPStatusError) as cm: + self.client.Execute(req) + self.assertEqual(cm.exception.response.status_code, 404) + def test_string_input(self): @self.app.dispatch_coroutine() def my_cool_coroutine(input: Input) -> Output: @@ -168,6 +186,45 @@ def my_cool_coroutine(input: Input) -> Output: out = response_output(resp) self.assertEqual(out, "You sent 'cool stuff'") + def test_error_on_access_state_in_first_call(self): + @self.app.dispatch_coroutine() + def my_cool_coroutine(input: Input) -> Output: + print(input.state) + return Output.value("not reached") + + resp = self.execute(my_cool_coroutine, input="cool stuff") + self.assertEqual("ValueError", resp.exit.result.error.type) + self.assertEqual( + "This input is for a first coroutine call", resp.exit.result.error.message + ) + + def test_error_on_access_input_in_second_call(self): + @self.app.dispatch_coroutine() + def my_cool_coroutine(input: Input) -> Output: + if input.is_first_call: + return Output.callback(state=42) + print(input.input) + return Output.value("not reached") + + resp = self.execute(my_cool_coroutine, input="cool stuff") + resp = self.execute(my_cool_coroutine, state=resp.poll.state) + + self.assertEqual("ValueError", resp.exit.result.error.type) + self.assertEqual( + "This input is for a resumed coroutine", resp.exit.result.error.message + ) + + def test_duplicate_coro(self): + @self.app.dispatch_coroutine() + def my_cool_coroutine(input: Input) -> Output: + return Output.value("Do one thing") + + with self.assertRaises(ValueError): + + @self.app.dispatch_coroutine() + def my_cool_coroutine(input: Input) -> Output: + return Output.value("Do something else") + def test_two_simple_coroutines(self): @self.app.dispatch_coroutine() def echoroutine(input: Input) -> Output: @@ -260,6 +317,43 @@ def coroutine_main(input: Input) -> Output: out = response_output(resp) self.assertEqual("length=10 text='cool stuff'", out) + def test_coroutine_poll_error(self): + @self.app.dispatch_coroutine() + def coro_compute_len(input: Input) -> Output: + return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead")) + + @self.app.dispatch_coroutine() + def coroutine_main(input: Input) -> Output: + if input.is_first_call: + text: str = input.input + return Output.callback( + state=text, calls=[coro_compute_len.call_with(text)] + ) + msg = input.calls[0].error.message + type = input.calls[0].error.type + return Output.value(f"msg={msg} type='{type}'") + + resp = self.execute(coroutine_main, input="cool stuff") + + # main saved some state + state = resp.poll.state + self.assertTrue(len(state) > 0) + # main asks for 1 call to compute_len + self.assertEqual(len(resp.poll.calls), 1) + call = resp.poll.calls[0] + self.assertEqual(call.coroutine_uri, coro_compute_len.uri) + self.assertEqual(dispatch.coroutine._any_unpickle(call.input), "cool stuff") + + # make the requested compute_len + resp2 = self.call(call) + + # resume main with the result + resp = self.execute(coroutine_main, state=state, calls=[resp2]) + # validate the final result + self.assertTrue(len(resp.poll.state) == 0) + out = response_output(resp) + self.assertEqual(out, "msg=Dead type='type'") + def test_coroutine_error(self): @self.app.dispatch_coroutine() def mycoro(input: Input) -> Output: @@ -318,3 +412,21 @@ def mycoro(input: Input) -> Output: self.assertEqual( 42, dispatch.coroutine._any_unpickle(resp.exit.tail_call.input) ) + + +class TestError(unittest.TestCase): + def test_missing_type_and_message(self): + with self.assertRaises(ValueError): + Error(Status.TEMPORARY_ERROR, type=None, message=None) + + def test_error_with_ok_status(self): + with self.assertRaises(ValueError): + Error(Status.OK, type="type", message="yep") + + def test_from_exception_timeout(self): + err = Error.from_exception(TimeoutError()) + self.assertEqual(Status.TIMEOUT, err.status) + + def test_from_exception_syntax_error(self): + err = Error.from_exception(SyntaxError()) + self.assertEqual(Status.PERMANENT_ERROR, err.status)