From 5e34512448bf037d1d002ebe162e415e707054b5 Mon Sep 17 00:00:00 2001 From: Andrew Smith Date: Thu, 26 Sep 2024 11:33:44 +0000 Subject: [PATCH] fix: async set_auth for realtime in auth event listener (#930) --- Makefile | 1 + poetry.lock | 20 ++++++++++++++++++- pyproject.toml | 4 ++++ supabase/_async/client.py | 5 ++--- supabase/_sync/client.py | 3 --- tests/_async/test_client.py | 38 +++++++++++++++++++++++++++++++++++++ tests/test_client.py | 2 -- 7 files changed, 64 insertions(+), 9 deletions(-) create mode 100644 tests/_async/test_client.py diff --git a/Makefile b/Makefile index a45e19b2..448d6a62 100644 --- a/Makefile +++ b/Makefile @@ -17,3 +17,4 @@ tests_only: build_sync: poetry run unasync supabase tests + sed -i 's/asyncio.create_task(self.realtime.set_auth(access_token))//g' supabase/_sync/client.py diff --git a/poetry.lock b/poetry.lock index 021bbb18..bc815a85 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1336,6 +1336,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.24.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, + {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "5.0.0" @@ -1905,4 +1923,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "8e4ee6045fed0d07c502f1539a3067d2af6486089dfa7a2df0bc58006e867a32" +content-hash = "cea80a29f0f6d0c9c447763bd4e339fd0be65a0be0c9089c3ce62eda7939523e" diff --git a/pyproject.toml b/pyproject.toml index ece89058..610970e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,10 @@ tests = 'poetry_scripts:run_tests' [tool.poetry.group.dev.dependencies] unasync-cli = { git = "https://github.com/supabase-community/unasync-cli.git", branch = "main" } +pytest-asyncio = "^0.24.0" + +[tool.pytest.ini_options] +asyncio_mode = "auto" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/supabase/_async/client.py b/supabase/_async/client.py index 48ba994b..ef2cefd7 100644 --- a/supabase/_async/client.py +++ b/supabase/_async/client.py @@ -1,3 +1,4 @@ +import asyncio import re from typing import Any, Dict, List, Optional, Union @@ -296,9 +297,7 @@ def _listen_to_auth_events( access_token = session.access_token if session else self.supabase_key self.options.headers["Authorization"] = self._create_auth_header(access_token) - - # set_auth is a coroutine, how to handle this? - self.realtime.set_auth(access_token) + asyncio.create_task(self.realtime.set_auth(access_token)) async def create_client( diff --git a/supabase/_sync/client.py b/supabase/_sync/client.py index 4ec1ee5e..c9d0d8e9 100644 --- a/supabase/_sync/client.py +++ b/supabase/_sync/client.py @@ -297,9 +297,6 @@ def _listen_to_auth_events( self.options.headers["Authorization"] = self._create_auth_header(access_token) - # set_auth is a coroutine, how to handle this? - self.realtime.set_auth(access_token) - def create_client( supabase_url: str, diff --git a/tests/_async/test_client.py b/tests/_async/test_client.py new file mode 100644 index 00000000..6dc23b9e --- /dev/null +++ b/tests/_async/test_client.py @@ -0,0 +1,38 @@ +import os +from unittest.mock import AsyncMock, MagicMock + +from supabase import create_async_client + + +async def test_updates_the_authorization_header_on_auth_events() -> None: + url = os.environ.get("SUPABASE_TEST_URL") + key = os.environ.get("SUPABASE_TEST_KEY") + + client = await create_async_client(url, key) + + assert client.options.headers.get("apiKey") == key + assert client.options.headers.get("Authorization") == f"Bearer {key}" + + mock_session = MagicMock(access_token="secretuserjwt") + realtime_mock = AsyncMock() + client.realtime = realtime_mock + + client._listen_to_auth_events("SIGNED_IN", mock_session) + + updated_authorization = f"Bearer {mock_session.access_token}" + + assert client.options.headers.get("apiKey") == key + assert client.options.headers.get("Authorization") == updated_authorization + + assert client.postgrest.session.headers.get("apiKey") == key + assert ( + client.postgrest.session.headers.get("Authorization") == updated_authorization + ) + + assert client.auth._headers.get("apiKey") == key + assert client.auth._headers.get("Authorization") == updated_authorization + + assert client.storage.session.headers.get("apiKey") == key + assert client.storage.session.headers.get("Authorization") == updated_authorization + + realtime_mock.set_auth.assert_called_once_with(mock_session.access_token) diff --git a/tests/test_client.py b/tests/test_client.py index d5320b8e..e4c1369b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -92,5 +92,3 @@ def test_updates_the_authorization_header_on_auth_events() -> None: assert client.storage.session.headers.get("apiKey") == key assert client.storage.session.headers.get("Authorization") == updated_authorization - - realtime_mock.set_auth.assert_called_once_with(mock_session.access_token)