Skip to content

Commit

Permalink
fix PrimitiveFunction.__call__, the method cannot be async
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
  • Loading branch information
achille-roussel committed Jun 14, 2024
1 parent d185b38 commit f54119c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 34 deletions.
29 changes: 21 additions & 8 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,30 @@ def __init__(
async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
return await dispatch.coroutine.call(self.build_call(*args, **kwargs))

async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
async def _call_dispatch(self, *args: P.args, **kwargs: P.kwargs) -> T:
call = self.build_call(*args, **kwargs)
client = self.registry.client
[dispatch_id] = await client.dispatch([call])
return await client.wait(dispatch_id)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
"""Call the function asynchronously (through Dispatch), and return a
coroutine that can be awaited to retrieve the call result."""
# Note: this method cannot be made `async`, otherwise Python creates
# ont additional wrapping layer of native coroutine that cannot be
# pickled and breaks serialization.
#
# The durable coroutine returned by calling _func_indirect must be
# returned as is.
#
# For cases where this method is called outside the context of a
# dispatch function, it still returns a native coroutine object,
# but that doesn't matter since there is no state serialization in
# that case.
if in_function_call():
return await self._func_indirect(*args, **kwargs)

call = self.build_call(*args, **kwargs)

[dispatch_id] = await self.registry.client.dispatch([call])

return await self.registry.client.wait(dispatch_id)
return self._func_indirect(*args, **kwargs)
else:
return self._call_dispatch(*args, **kwargs)

def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
"""Dispatch an asynchronous call to the function without
Expand Down
1 change: 1 addition & 0 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ async def _run(self, input: Input) -> Output:

# Serialize coroutines and scheduler state.
logger.debug("serializing state")
print("state", state)
try:
serialized_state = pickle.dumps(state)
except pickle.PickleError as e:
Expand Down
37 changes: 11 additions & 26 deletions src/dispatch/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def echo(name: str) -> str:


@_registry.function
async def echo2(name: str) -> str:
async def echo_nested(name: str) -> str:
return await echo(name)


Expand All @@ -362,6 +362,11 @@ def broken() -> str:
raise ValueError("something went wrong!")


@_registry.function
async def broken_nested(name: str) -> str:
return await broken()


set_default_registry(_registry)


Expand Down Expand Up @@ -481,34 +486,14 @@ async def test_call_two_functions(self):
self.assertEqual(await echo("hello"), "hello")
self.assertEqual(await length("hello"), 5)

# TODO:
#
# The declaration of nested functions in these tests causes CPython to
# generate cell objects since the local variables are referenced by multiple
# scopes.
#
# Maybe time to revisit https://github.com/dispatchrun/dispatch-py/pull/121
#
# Alternatively, we could rewrite the test suite to use a global registry
# where we register each function once in the globla scope, so no cells need
# to be created.

@aiotest
async def test_call_nested_function_with_result(self):
self.assertEqual(await echo2("hello"), "hello")
self.assertEqual(await echo_nested("hello"), "hello")

# @aiotest
# async def test_call_nested_function_with_error(self):
# @self.dispatch.function
# def broken_function(name: str) -> str:
# raise ValueError("something went wrong!")

# @self.dispatch.function
# async def working_function(name: str) -> str:
# return await broken_function(name)

# with self.assertRaises(ValueError) as e:
# await working_function("hello")
@aiotest
async def test_call_nested_function_with_error(self):
with self.assertRaises(ValueError) as e:
await broken_nested("hello")


class ClientRequestContentLengthMissing(aiohttp.ClientRequest):
Expand Down

0 comments on commit f54119c

Please sign in to comment.