diff --git a/supabase/_async/client.py b/supabase/_async/client.py index fed0d1ce..f21cd762 100644 --- a/supabase/_async/client.py +++ b/supabase/_async/client.py @@ -16,7 +16,7 @@ from storage3.constants import DEFAULT_TIMEOUT as DEFAULT_STORAGE_CLIENT_TIMEOUT from supafunc import AsyncFunctionsClient -from supabase.lib.helpers import is_jwt +from supabase.lib.helpers import is_valid_jwt from ..lib.client_options import AsyncClientOptions as ClientOptions from .auth_client import AsyncSupabaseAuthClient @@ -280,7 +280,7 @@ def _create_auth_header(self, token: str): def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, str]: if authorization is None: - if is_jwt(self.supabase_key): + if is_valid_jwt(self.supabase_key): authorization = self.options.headers.get( "Authorization", self._create_auth_header(self.supabase_key) ) @@ -294,7 +294,9 @@ def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, st def _listen_to_auth_events( self, event: AuthChangeEvent, session: Optional[Session] ): - default_access_token = self.supabase_key if is_jwt(self.supabase_key) else None + default_access_token = ( + self.supabase_key if is_valid_jwt(self.supabase_key) else None + ) access_token = default_access_token if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]: # reset postgrest and storage instance on event change diff --git a/supabase/_sync/client.py b/supabase/_sync/client.py index 4ba21393..c1f9cdd5 100644 --- a/supabase/_sync/client.py +++ b/supabase/_sync/client.py @@ -15,7 +15,7 @@ from storage3.constants import DEFAULT_TIMEOUT as DEFAULT_STORAGE_CLIENT_TIMEOUT from supafunc import SyncFunctionsClient -from supabase.lib.helpers import is_jwt +from supabase.lib.helpers import is_valid_jwt from ..lib.client_options import SyncClientOptions as ClientOptions from .auth_client import SyncSupabaseAuthClient @@ -279,7 +279,7 @@ def _create_auth_header(self, token: str): def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, str]: if authorization is None: - if is_jwt(self.supabase_key): + if is_valid_jwt(self.supabase_key): authorization = self.options.headers.get( "Authorization", self._create_auth_header(self.supabase_key) ) @@ -293,7 +293,9 @@ def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, st def _listen_to_auth_events( self, event: AuthChangeEvent, session: Optional[Session] ): - default_access_token = self.supabase_key if is_jwt(self.supabase_key) else None + default_access_token = ( + self.supabase_key if is_valid_jwt(self.supabase_key) else None + ) access_token = default_access_token if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]: # reset postgrest and storage instance on event change diff --git a/supabase/lib/helpers.py b/supabase/lib/helpers.py index 4fb573ed..0b8b308b 100644 --- a/supabase/lib/helpers.py +++ b/supabase/lib/helpers.py @@ -1,27 +1,41 @@ import re +from typing import Dict BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$" -def is_jwt(value: str) -> bool: - if value.startswith("Bearer "): - value = value.replace("Bearer ", "") +def is_valid_jwt(value: str) -> bool: + """Checks if value looks like a JWT, does not do any extra parsing.""" + if not isinstance(value, str): + return False + # Remove trailing whitespaces if any. value = value.strip() - if not value: - return False - parts = value.split(".") - if len(parts) != 3: + # Remove "Bearer " prefix if any. + if value.startswith("Bearer "): + value = value[7:] + + # Valid JWT must have 2 dots (Header.Paylod.Signature) + if value.count(".") != 2: return False - # loop through the parts and test against regex - for part in parts: - if len(part) < 4 or not re.search(BASE64URL_REGEX, part, re.IGNORECASE): + for part in value.split("."): + if not re.search(BASE64URL_REGEX, part, re.IGNORECASE): return False return True -def check_authorization_header(headers): +def check_authorization_header(headers: Dict[str, str]): + authorization = headers.get("Authorization") + if not authorization: + return + + if authorization.startswith("Bearer "): + if not is_valid_jwt(authorization): + raise ValueError( + "create_client called with global Authorization header that does not contain a JWT" + ) + return True