Skip to content

Commit

Permalink
Merge pull request #100 from stealthrocket/paramspec
Browse files Browse the repository at this point in the history
Improve type safety with ParamSpec
  • Loading branch information
chriso authored Mar 3, 2024
2 parents 0a518bc + 51f0023 commit 78c0ced
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 73 deletions.
Empty file added src/buf/validate/py.typed
Empty file.
Empty file.
94 changes: 44 additions & 50 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import inspect
import logging
from functools import wraps
from types import FunctionType
from typing import Any, Callable, Dict, TypeAlias
from types import CoroutineType
from typing import Any, Callable, Dict, Generic, ParamSpec, TypeAlias, TypeVar

import dispatch.coroutine
from dispatch.client import Client
Expand All @@ -23,29 +23,11 @@
"""


# https://stackoverflow.com/questions/653368/how-to-create-a-decorator-that-can-be-used-either-with-or-without-parameters
def decorator(f):
"""This decorator is intended to declare decorators that can be used with
or without parameters. If the decorated function is called with a single
callable argument, it is assumed to be a function and the decorator is
applied to it. Otherwise, the decorator is called with the arguments
provided and the result is returned.
"""

@wraps(f)
def method(self, *args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return f(self, args[0])

def wrapper(func):
return f(self, func, *args, **kwargs)

return wrapper
P = ParamSpec("P")
T = TypeVar("T")

return method


class Function:
class Function(Generic[P, T]):
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK.
"""
Expand All @@ -58,18 +40,23 @@ def __init__(
client: Client,
name: str,
primitive_func: PrimitiveFunctionType,
func: Callable,
func: Callable[P, T] | None,
coroutine: bool = False,
):
self._endpoint = endpoint
self._client = client
self._name = name
self._primitive_func = primitive_func
# FIXME: is there a way to decorate the function at the definition
# without making it a class method?
self._func = durable(self._call_async) if coroutine else func
if func:
self._func: Callable[P, T] | None = (
durable(self._call_async) if coroutine else func
)
else:
self._func = None

def __call__(self, *args, **kwargs):
def __call__(self, *args: P.args, **kwargs: P.kwargs):
if self._func is None:
raise ValueError("cannot call a primitive function directly")
return self._func(*args, **kwargs)

def _primitive_call(self, input: Input) -> Output:
Expand All @@ -83,7 +70,7 @@ def endpoint(self) -> str:
def name(self) -> str:
return self._name

def dispatch(self, *args: Any, **kwargs: Any) -> DispatchID:
def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
"""Dispatch a call to the function.
The Registry this function was registered with must be initialized
Expand All @@ -105,14 +92,14 @@ def _primitive_dispatch(self, input: Any = None) -> DispatchID:
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
return dispatch_id

async def _call_async(self, *args, **kwargs) -> Any:
async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
"""Asynchronously call the function from a @dispatch.function."""
return await dispatch.coroutine.call(
self.build_call(*args, **kwargs, correlation_id=None)
)

def build_call(
self, *args: Any, correlation_id: int | None = None, **kwargs: Any
self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs
) -> Call:
"""Create a Call for this function with the provided input. Useful to
generate calls when using the Client.
Expand Down Expand Up @@ -158,24 +145,21 @@ def __init__(self, endpoint: str, client: Client):
self._endpoint = endpoint
self._client = client

@decorator
def function(self, func: Callable) -> Function:
"""Returns a decorator that registers functions."""
def function(self, func: Callable[P, T]) -> Function[P, T]:
"""Decorator that registers functions."""
if inspect.iscoroutinefunction(func):
return self._register_coroutine(func)
return self._register_function(func)

@decorator
def primitive_function(self, func: Callable) -> Function:
"""Returns a decorator that registers primitive functions."""
def primitive_function(self, func: PrimitiveFunctionType) -> Function:
"""Decorator that registers primitive functions."""
return self._register_primitive_function(func)

def _register_function(self, func: Callable) -> Function:
if inspect.iscoroutinefunction(func):
return self._register_coroutine(func)

def _register_function(self, func: Callable[P, T]) -> Function[P, T]:
logger.info("registering function: %s", func.__qualname__)

# Register the function with the experimental.durable package, in case
# it's referenced from a @dispatch.coroutine.
# it's referenced from a coroutine.
func = durable(func)

@wraps(func)
Expand All @@ -199,7 +183,9 @@ def primitive_func(input: Input) -> Output:

return self._register(primitive_func, func, coroutine=False)

def _register_coroutine(self, func: Callable) -> Function:
def _register_coroutine(
self, func: Callable[P, CoroutineType[Any, Any, T]]
) -> Function[P, T]:
logger.info("registering coroutine: %s", func.__qualname__)

func = durable(func)
Expand All @@ -213,19 +199,27 @@ def primitive_func(input: Input) -> Output:

return self._register(primitive_func, func, coroutine=True)

def _register_primitive_function(self, func: PrimitiveFunctionType) -> Function:
logger.info("registering primitive function: %s", func.__qualname__)
return self._register(func, func, coroutine=inspect.iscoroutinefunction(func))
def _register_primitive_function(
self, primitive_func: PrimitiveFunctionType
) -> Function[P, T]:
logger.info("registering primitive function: %s", primitive_func.__qualname__)
return self._register(primitive_func, func=None, coroutine=False)

def _register(
self, primitive_func: PrimitiveFunctionType, func: Callable, coroutine: bool
) -> Function:
name = func.__qualname__
self,
primitive_func: PrimitiveFunctionType,
func: Callable[P, T] | None,
coroutine: bool,
) -> Function[P, T]:
if func:
name = func.__qualname__
else:
name = primitive_func.__qualname__
if name in self._functions:
raise ValueError(
f"function or coroutine already registered with name '{name}'"
)
wrapped_func = Function(
wrapped_func = Function[P, T](
self._endpoint, self._client, name, primitive_func, func, coroutine
)
self._functions[name] = wrapped_func
Expand Down
Empty file.
Empty file added src/dispatch/py.typed
Empty file.
Empty file added src/dispatch/sdk/v1/py.typed
Empty file.
2 changes: 1 addition & 1 deletion tests/dispatch/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def setUp(self):
self.dispatch = Registry(endpoint="http://example.com", client=self.client)

def test_serializable(self):
@self.dispatch.function()
@self.dispatch.function
def my_function():
pass

Expand Down
44 changes: 22 additions & 22 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_fastapi_simple_request(self):
app = fastapi.FastAPI()
dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/")

@dispatch.primitive_function()
@dispatch.primitive_function
def my_function(input: Input) -> Output:
return Output.value(
f"You told me: '{input.input}' ({len(input.input)} characters)"
Expand Down Expand Up @@ -159,7 +159,7 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult:
return resp.exit.result

def test_no_input(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def my_function(input: Input) -> Output:
return Output.value("Hello World!")

Expand All @@ -178,7 +178,7 @@ def test_missing_coroutine(self):
self.assertEqual(cm.exception.response.status_code, 404)

def test_string_input(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def my_function(input: Input) -> Output:
return Output.value(f"You sent '{input.input}'")

Expand All @@ -187,7 +187,7 @@ def my_function(input: Input) -> Output:
self.assertEqual(out, "You sent 'cool stuff'")

def test_error_on_access_state_in_first_call(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def my_function(input: Input) -> Output:
try:
print(input.coroutine_state)
Expand All @@ -206,7 +206,7 @@ def my_function(input: Input) -> Output:
)

def test_error_on_access_input_in_second_call(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def my_function(input: Input) -> Output:
if input.is_first_call:
return Output.poll(state=42)
Expand All @@ -230,22 +230,22 @@ def my_function(input: Input) -> Output:
)

def test_duplicate_coro(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def my_function(input: Input) -> Output:
return Output.value("Do one thing")

with self.assertRaises(ValueError):

@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def my_function(input: Input) -> Output:
return Output.value("Do something else")

def test_two_simple_coroutines(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def echoroutine(input: Input) -> Output:
return Output.value(f"Echo: '{input.input}'")

@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def len_coroutine(input: Input) -> Output:
return Output.value(f"Length: {len(input.input)}")

Expand All @@ -259,7 +259,7 @@ def len_coroutine(input: Input) -> Output:
self.assertEqual(out, "Length: 10")

def test_coroutine_with_state(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def coroutine3(input: Input) -> Output:
if input.is_first_call:
counter = input.input
Expand Down Expand Up @@ -293,11 +293,11 @@ def coroutine3(input: Input) -> Output:
self.assertEqual(out, "done")

def test_coroutine_poll(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def coro_compute_len(input: Input) -> Output:
return Output.value(len(input.input))

@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def coroutine_main(input: Input) -> Output:
if input.is_first_call:
text: str = input.input
Expand Down Expand Up @@ -333,11 +333,11 @@ def coroutine_main(input: Input) -> Output:
self.assertEqual("length=10 text='cool stuff'", out)

def test_coroutine_poll_error(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def coro_compute_len(input: Input) -> Output:
return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead"))

@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def coroutine_main(input: Input) -> Output:
if input.is_first_call:
text: str = input.input
Expand Down Expand Up @@ -372,7 +372,7 @@ def coroutine_main(input: Input) -> Output:
self.assertEqual(out, "msg=Dead type='type'")

def test_coroutine_error(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def mycoro(input: Input) -> Output:
return Output.error(Error(Status.PERMANENT_ERROR, "sometype", "dead"))

Expand All @@ -381,7 +381,7 @@ def mycoro(input: Input) -> Output:
self.assertEqual("dead", resp.exit.result.error.message)

def test_coroutine_expected_exception(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def mycoro(input: Input) -> Output:
try:
1 / 0
Expand All @@ -395,7 +395,7 @@ def mycoro(input: Input) -> Output:
self.assertEqual(Status.PERMANENT_ERROR, resp.status)

def test_coroutine_unexpected_exception(self):
@self.dispatch.function()
@self.dispatch.function
def mycoro():
1 / 0
self.fail("should not reach here")
Expand All @@ -406,7 +406,7 @@ def mycoro():
self.assertEqual(Status.PERMANENT_ERROR, resp.status)

def test_specific_status(self):
@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def mycoro(input: Input) -> Output:
return Output.error(Error(Status.THROTTLED, "foo", "bar"))

Expand All @@ -416,11 +416,11 @@ def mycoro(input: Input) -> Output:
self.assertEqual(Status.THROTTLED, resp.status)

def test_tailcall(self):
@self.dispatch.function()
@self.dispatch.function
def other_coroutine(value: Any) -> str:
return f"Hello {value}"

@self.dispatch.primitive_function()
@self.dispatch.primitive_function
def mycoro(input: Input) -> Output:
return Output.tail_call(other_coroutine._build_primitive_call(42))

Expand All @@ -429,7 +429,7 @@ def mycoro(input: Input) -> Output:
self.assertEqual(42, any_unpickle(resp.exit.tail_call.input))

def test_library_error_categorization(self):
@self.dispatch.function()
@self.dispatch.function
def get(path: str) -> httpx.Response:
http_response = self.http_client.get(path)
http_response.raise_for_status()
Expand All @@ -445,7 +445,7 @@ def get(path: str) -> httpx.Response:
self.assertEqual(Status.NOT_FOUND, Status(resp.status))

def test_library_output_categorization(self):
@self.dispatch.function()
@self.dispatch.function
def get(path: str) -> httpx.Response:
http_response = self.http_client.get(path)
http_response.status_code = 429
Expand Down

0 comments on commit 78c0ced

Please sign in to comment.