Skip to content

Commit

Permalink
Improve decorater error handling (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
janbjorge authored Feb 20, 2024
1 parent 99d1f58 commit 5b49846
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 41 deletions.
68 changes: 29 additions & 39 deletions src/pgcachewatch/decorators.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,65 @@
import asyncio
import contextlib
import logging
import typing
from typing import Awaitable, Callable, Hashable, Literal, TypeVar

import typing_extensions
from typing_extensions import ParamSpec

from pgcachewatch import strategies, utils

P = typing_extensions.ParamSpec("P")
T = typing.TypeVar("T")
P = ParamSpec("P")
T = TypeVar("T")


def cache(
strategy: strategies.Strategy,
statistics_callback: typing.Callable[[typing.Literal["hit", "miss"]], None]
| None = None,
) -> typing.Callable[
[typing.Callable[P, typing.Awaitable[T]]],
typing.Callable[P, typing.Awaitable[T]],
]:
def outer(
fn: typing.Callable[P, typing.Awaitable[T]],
) -> typing.Callable[P, typing.Awaitable[T]]:
cached = dict[typing.Hashable, asyncio.Future[T]]()
statistics_callback: Callable[[Literal["hit", "miss"]], None] | None = None,
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
def outer(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
cached = dict[Hashable, asyncio.Future[T]]()

async def inner(*args: P.args, **kw: P.kwargs) -> T:
async def inner(*args: P.args, **kwargs: P.kwargs) -> T:
# If db-conn is down, disable cache.
if not strategy.pg_connection_healthy():
logging.critical("Database connection is closed, caching disabled.")
return await fn(*args, **kw)
return await fn(*args, **kwargs)

# Clear cache if we have a event from
# the database the instructs us to clear.
if strategy.clear():
logging.debug("Cache clear")
cached.clear()

# Check for cache hit
key = utils.make_key(args, kw)
with contextlib.suppress(KeyError):
# OBS: Will only await if the cache key hits.
result = await cached[key]
key = utils.make_key(args, kwargs)

try:
waiter = cached[key]
except KeyError:
# Cache miss
...
else:
# Cache hit
logging.debug("Cache hit")
if statistics_callback:
statistics_callback("hit")
return result
return await waiter

# Below deals with a cache miss.
logging.debug("Cache miss")
if statistics_callback:
statistics_callback("miss")

# By using a future as placeholder we avoid
# cache stampeded. Note that on the "miss" branch/path, controll
# is never given to the eventloopscheduler before the future
# is create.
# Initialize Future to prevent cache stampedes.
cached[key] = waiter = asyncio.Future[T]()

try:
result = await fn(*args, **kw)
# # Attempt to compute result and set for waiter
waiter.set_result(await fn(*args, **kwargs))
except Exception as e:
cached.pop(
key, None
) # Next try should not result in a repeating exception
waiter.set_exception(
e
) # Propegate exception to other callers who are waiting.
raise e from None # Propegate exception to first caller.
else:
waiter.set_result(result)
# Remove key from cache on failure.
cached.pop(key, None)
# Propagate exception to all awaiting the future.
waiter.set_exception(e)

return result
return await waiter

return inner

Expand Down
54 changes: 52 additions & 2 deletions tests/test_decoraters.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import asyncio
import collections
import datetime
from typing import NoReturn

import asyncpg
import pytest
from pgcachewatch import decorators, listeners, models, strategies


@pytest.mark.parametrize("N", (4, 16, 64, 512))
@pytest.mark.parametrize("N", (1, 2, 4, 16, 64))
async def test_gready_cache_decorator(N: int, pgconn: asyncpg.Connection) -> None:
statistics = collections.Counter[str]()
listener = listeners.PGEventQueue()
Expand All @@ -20,6 +21,55 @@ async def test_gready_cache_decorator(N: int, pgconn: asyncpg.Connection) -> Non
async def now() -> datetime.datetime:
return datetime.datetime.now()

await asyncio.gather(*[now() for _ in range(N)])
nows = set(await asyncio.gather(*[now() for _ in range(N)]))
assert len(nows) == 1

assert statistics["hit"] == N - 1
assert statistics["miss"] == 1


@pytest.mark.parametrize("N", (1, 2, 4, 16, 64))
async def test_gready_cache_decorator_connection_closed(
N: int,
pgconn: asyncpg.Connection,
) -> None:
listener = listeners.PGEventQueue()
await listener.connect(
pgconn,
models.PGChannel("test_gready_cache_decorator_connection_closed"),
)
await pgconn.close()

@decorators.cache(strategy=strategies.Gready(listener=listener))
async def now() -> datetime.datetime:
return datetime.datetime.now()

nows = await asyncio.gather(*[now() for _ in range(N)])
assert len(set(nows)) == N


@pytest.mark.parametrize("N", (1, 2, 4, 16, 64))
async def test_gready_cache_decorator_exceptions(
N: int,
pgconn: asyncpg.Connection,
) -> None:
listener = listeners.PGEventQueue()
await listener.connect(
pgconn,
models.PGChannel("test_gready_cache_decorator_exceptions"),
)

@decorators.cache(strategy=strategies.Gready(listener=listener))
async def raise_runtime_error() -> NoReturn:
raise RuntimeError

for _ in range(N):
with pytest.raises(RuntimeError):
await raise_runtime_error()

exceptions = await asyncio.gather(
*[raise_runtime_error() for _ in range(N)],
return_exceptions=True,
)
assert len(exceptions) == N
assert all(isinstance(exc, RuntimeError) for exc in exceptions)

0 comments on commit 5b49846

Please sign in to comment.