Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code coverage #16

Merged
merged 15 commits into from
Feb 1, 2024
Merged
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
*.egg-info
__pycache__
.proto
.coverage
.coverage-html
12 changes: 11 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,17 @@ make dev
make test
```

## Code style
## Coverage

```
make coverage
```

In addition to displaying the summary in the terminal, this command generates an
HTML report with line-by-line coverage. `open .coverage-html/index.html` and
click around. You can refresh your browser after each `make coverage` run.

## Style

Formatting is done with `black`. Run `make fmt`.

Expand Down
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: install test typecheck unittest dev fmt generate clean update-proto
.PHONY: install test typecheck unittest dev fmt generate clean update-proto coverage

PYTHON := python

Expand All @@ -19,6 +19,11 @@ typecheck:
unittest:
$(PYTHON) -m unittest discover

coverage: typecheck
coverage run -m unittest discover
coverage html -d .coverage-html
coverage report

test: typecheck unittest

.proto:
Expand All @@ -42,5 +47,7 @@ generate: .proto/ring .proto/dispatch

clean:
rm -rf .proto
rm -rf .coverage
rm -rf .coverage-html
find . -type f -name '*.pyc' -delete
find . -type d -name '__pycache__' -exec rm -r {} \;
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,14 @@ dev = [
"black==24.1.0",
"mypy==1.8.0",
"fastapi==0.109.0",
"httpx==0.26.0"
"httpx==0.26.0",
"coverage==7.4.1"
]


[tool.coverage.run]
omit = [
"*_pb2_grpc.py",
"*_pb2.py",
"tests/*"
]
5 changes: 1 addition & 4 deletions src/dispatch/coroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,7 @@ def __init__(self, req: coroutine_pb2.ExecuteRequest):
input_pb = google.protobuf.wrappers_pb2.BytesValue()
req.input.Unpack(input_pb)
input_bytes = input_pb.value
if len(input_bytes) > 0:
self._input = pickle.loads(input_bytes)
else:
self._input = None
self._input = pickle.loads(input_bytes)
else:
state_bytes = req.poll_response.state
if len(state_bytes) > 0:
Expand Down
16 changes: 4 additions & 12 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,28 +132,20 @@ async def execute(request: fastapi.Request):
coroutine = app._coroutines.get(uri, None)
if coroutine is None:
# TODO: integrate with logging
print("Coroutine not found:")
print(" uri:", uri)
print("Available coroutines:")
for k in app._coroutines:
print(" ", k)
raise KeyError(f"coroutine '{uri}' not available on this system")
raise fastapi.HTTPException(
status_code=404, detail=f"Coroutine URI '{uri}' does not exist"
)

coro_input = dispatch.coroutine.Input(req)

try:
output = coroutine(coro_input)
except Exception as ex:
# TODO: distinguish unaught exceptions from exceptions returned by
# TODO: distinguish uncaught exceptions from exceptions returned by
# coroutine?
err = dispatch.coroutine.Error.from_exception(ex)
output = dispatch.coroutine.Output.error(err)

if not isinstance(output, dispatch.coroutine.Output):
raise ValueError(
f"coroutine output should be an instance of {dispatch.coroutine.Output}, not {type(output)}"
)

resp = output._message
resp.coroutine_uri = req.coroutine_uri
resp.coroutine_version = req.coroutine_version
Expand Down
6 changes: 3 additions & 3 deletions tests/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _validate_authentication(self, context: grpc.ServicerContext):
return
context.abort(
grpc.StatusCode.UNAUTHENTICATED,
f"Invalid authorization header. Expected '{expected}', got '{value!r}'",
f"Invalid authorization header. Expected '{expected}', got {value!r}",
)
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Missing authorization header")

Expand Down Expand Up @@ -85,15 +85,15 @@ def __init__(self):
self.thread_pool = concurrent.futures.thread.ThreadPoolExecutor()
self.server = grpc.server(self.thread_pool)

port = self.server.add_insecure_port("127.0.0.1:0")
self.port = self.server.add_insecure_port("127.0.0.1:0")

self.servicer = FakeRing()

service_grpc.add_ServiceServicer_to_server(self.servicer, self.server)
self.server.start()

self.client = Client(
api_key=_test_auth_token, api_url=f"http://127.0.0.1:{port}"
api_key=_test_auth_token, api_url=f"http://127.0.0.1:{self.port}"
)

def stop(self):
Expand Down
40 changes: 39 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import unittest
from unittest import mock

import grpc
from google.protobuf import wrappers_pb2, any_pb2

from dispatch import Client, TaskInput, TaskID
Expand All @@ -17,6 +20,31 @@ def setUp(self):
def tearDown(self):
self.server.stop()

@mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "WHATEVER"})
def test_api_key_from_env(self):
client = Client(api_url=f"http://127.0.0.1:{self.server.port}")

with self.assertRaises(grpc._channel._InactiveRpcError) as mc:
client.create_tasks(
[TaskInput(coroutine_uri="my-cool-coroutine", input=42)]
)
self.assertTrue("got 'Bearer WHATEVER'" in str(mc.exception))

def test_api_key_missing(self):
with self.assertRaises(ValueError) as mc:
client = Client()
self.assertEqual(str(mc.exception), "api_key is required")

def test_url_bad_scheme(self):
with self.assertRaises(ValueError) as mc:
client = Client(api_url="ftp://example.com", api_key="foo")
self.assertEqual(str(mc.exception), "Invalid API scheme: 'ftp'")

def test_can_be_constructed_on_https(self):
# Goal is to not raise an exception here. We don't have an HTTPS server
# around to actually test this.
Client(api_url="https://example.com", api_key="foo")

def test_create_one_task_pickle(self):
results = self.client.create_tasks(
[TaskInput(coroutine_uri="my-cool-coroutine", input=42)]
Expand Down Expand Up @@ -53,11 +81,21 @@ def test_create_one_task_proto_any(self):
proto_any = any_pb2.Any()
proto_any.Pack(proto)
results = self.client.create_tasks(
[TaskInput(coroutine_uri="my-cool-coroutine", input=proto)]
[TaskInput(coroutine_uri="my-cool-coroutine", input=proto_any)]
)
id = results[0]
created_tasks = self.servicer.created_tasks
entry = created_tasks[0]
task = entry["task"]
# proto any has not been modified
self.assertEqual(task.input, proto_any)


class TestTaskID(unittest.TestCase):
def test_string(self):
t = TaskID(partition_number=1, block_id=2, record_offset=3, record_size=4)
self.assertEqual(str(t), "00000001000000020000000300000004")

def test_repr(self):
t = TaskID(partition_number=1, block_id=2, record_offset=3, record_size=4)
self.assertEqual(repr(t), "TaskID(00000001000000020000000300000004)")
118 changes: 115 additions & 3 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import pickle
import unittest
from typing import Any
import dispatch.coroutine
from dispatch.coroutine import Input, Output, Error, Status
import dispatch.fastapi

import httpx
import fastapi
from fastapi.testclient import TestClient
import google.protobuf.wrappers_pb2

import dispatch.fastapi
import dispatch.coroutine
from dispatch.coroutine import Input, Output, Error, Status
from ring.coroutine.v1 import coroutine_pb2
from . import executor_service

Expand Down Expand Up @@ -52,6 +55,11 @@ def test_configure_no_public_url(self):
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key="test", public_url="")

def test_configure_public_url_no_scheme(self):
app = fastapi.FastAPI()
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key="test", public_url="127.0.0.1:9999")

def test_fastapi_simple_request(self):
app = fastapi.FastAPI()
dispatch.fastapi.configure(
Expand Down Expand Up @@ -159,6 +167,16 @@ def my_cool_coroutine(input: Input) -> Output:
out = response_output(resp)
self.assertEqual(out, "Hello World!")

def test_missing_coroutine(self):
req = coroutine_pb2.ExecuteRequest(
coroutine_uri="does-not-exist",
coroutine_version="1",
)

with self.assertRaises(httpx.HTTPStatusError) as cm:
self.client.Execute(req)
self.assertEqual(cm.exception.response.status_code, 404)

def test_string_input(self):
@self.app.dispatch_coroutine()
def my_cool_coroutine(input: Input) -> Output:
Expand All @@ -168,6 +186,45 @@ def my_cool_coroutine(input: Input) -> Output:
out = response_output(resp)
self.assertEqual(out, "You sent 'cool stuff'")

def test_error_on_access_state_in_first_call(self):
@self.app.dispatch_coroutine()
def my_cool_coroutine(input: Input) -> Output:
print(input.state)
return Output.value("not reached")

resp = self.execute(my_cool_coroutine, input="cool stuff")
self.assertEqual("ValueError", resp.exit.result.error.type)
self.assertEqual(
"This input is for a first coroutine call", resp.exit.result.error.message
)

def test_error_on_access_input_in_second_call(self):
@self.app.dispatch_coroutine()
def my_cool_coroutine(input: Input) -> Output:
if input.is_first_call:
return Output.callback(state=42)
print(input.input)
return Output.value("not reached")

resp = self.execute(my_cool_coroutine, input="cool stuff")
resp = self.execute(my_cool_coroutine, state=resp.poll.state)

self.assertEqual("ValueError", resp.exit.result.error.type)
self.assertEqual(
"This input is for a resumed coroutine", resp.exit.result.error.message
)

def test_duplicate_coro(self):
@self.app.dispatch_coroutine()
def my_cool_coroutine(input: Input) -> Output:
return Output.value("Do one thing")

with self.assertRaises(ValueError):

@self.app.dispatch_coroutine()
def my_cool_coroutine(input: Input) -> Output:
return Output.value("Do something else")

def test_two_simple_coroutines(self):
@self.app.dispatch_coroutine()
def echoroutine(input: Input) -> Output:
Expand Down Expand Up @@ -260,6 +317,43 @@ def coroutine_main(input: Input) -> Output:
out = response_output(resp)
self.assertEqual("length=10 text='cool stuff'", out)

def test_coroutine_poll_error(self):
@self.app.dispatch_coroutine()
def coro_compute_len(input: Input) -> Output:
return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead"))

@self.app.dispatch_coroutine()
def coroutine_main(input: Input) -> Output:
if input.is_first_call:
text: str = input.input
return Output.callback(
state=text, calls=[coro_compute_len.call_with(text)]
)
msg = input.calls[0].error.message
type = input.calls[0].error.type
return Output.value(f"msg={msg} type='{type}'")

resp = self.execute(coroutine_main, input="cool stuff")

# main saved some state
state = resp.poll.state
self.assertTrue(len(state) > 0)
# main asks for 1 call to compute_len
self.assertEqual(len(resp.poll.calls), 1)
call = resp.poll.calls[0]
self.assertEqual(call.coroutine_uri, coro_compute_len.uri)
self.assertEqual(dispatch.coroutine._any_unpickle(call.input), "cool stuff")

# make the requested compute_len
resp2 = self.call(call)

# resume main with the result
resp = self.execute(coroutine_main, state=state, calls=[resp2])
# validate the final result
self.assertTrue(len(resp.poll.state) == 0)
out = response_output(resp)
self.assertEqual(out, "msg=Dead type='type'")

def test_coroutine_error(self):
@self.app.dispatch_coroutine()
def mycoro(input: Input) -> Output:
Expand Down Expand Up @@ -318,3 +412,21 @@ def mycoro(input: Input) -> Output:
self.assertEqual(
42, dispatch.coroutine._any_unpickle(resp.exit.tail_call.input)
)


class TestError(unittest.TestCase):
def test_missing_type_and_message(self):
with self.assertRaises(ValueError):
Error(Status.TEMPORARY_ERROR, type=None, message=None)

def test_error_with_ok_status(self):
with self.assertRaises(ValueError):
Error(Status.OK, type="type", message="yep")

def test_from_exception_timeout(self):
err = Error.from_exception(TimeoutError())
self.assertEqual(Status.TIMEOUT, err.status)

def test_from_exception_syntax_error(self):
err = Error.from_exception(SyntaxError())
self.assertEqual(Status.PERMANENT_ERROR, err.status)
Loading