From fc945ad8c9e8ed7ffd201e0ff787b329ae29e855 Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Fri, 6 Sep 2024 11:47:22 -0700 Subject: [PATCH] feature: support Redis Sentinel URL scheme --- redisvl/redis/connection.py | 88 +++++++++++++++++++++-- tests/integration/test_session_manager.py | 3 +- tests/unit/test_sentinel_url.py | 84 ++++++++++++++++++++++ 3 files changed, 166 insertions(+), 9 deletions(-) create mode 100644 tests/unit/test_sentinel_url.py diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 9ccc87c8..89b4fef6 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,9 +1,11 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TypeVar, Union, overload +from urllib.parse import urlparse from redis import Redis from redis.asyncio import Redis as AsyncRedis from redis.exceptions import ResponseError +from redis.sentinel import Sentinel from redisvl.redis.constants import DEFAULT_REQUIRED_MODULES from redisvl.redis.utils import convert_bytes @@ -145,6 +147,9 @@ def validate_modules( ) +T = TypeVar("T", Redis, AsyncRedis) + + class RedisConnectionFactory: """Builds connections to a Redis database, supporting both synchronous and asynchronous clients. @@ -200,9 +205,15 @@ def get_redis_connection(url: Optional[str] = None, **kwargs) -> Redis: variable is not set. """ if url: - return Redis.from_url(url, **kwargs) - # fallback to env var REDIS_URL - return Redis.from_url(get_address_from_env(), **kwargs) + if url.startswith("redis+sentinel"): + return RedisConnectionFactory._redis_sentinel_client( + url, Redis, **kwargs + ) + else: + return Redis.from_url(url, **kwargs) + else: + # fallback to env var REDIS_URL + return Redis.from_url(get_address_from_env(), **kwargs) @staticmethod def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedis: @@ -222,9 +233,15 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi variable is not set. """ if url: - return AsyncRedis.from_url(url, **kwargs) - # fallback to env var REDIS_URL - return AsyncRedis.from_url(get_address_from_env(), **kwargs) + if url.startswith("redis+sentinel"): + return RedisConnectionFactory._redis_sentinel_client( + url, AsyncRedis, **kwargs + ) + else: + return AsyncRedis.from_url(url, **kwargs) + else: + # fallback to env var REDIS_URL + return AsyncRedis.from_url(get_address_from_env(), **kwargs) @staticmethod def get_modules(client: Redis) -> Dict[str, Any]: @@ -275,3 +292,60 @@ async def validate_async_redis( # Validate available modules validate_modules(installed_modules, required_modules) + + @staticmethod + @overload + def _redis_sentinel_client( + redis_url: str, redis_class: type[Redis], **kwargs: Any + ) -> Redis: ... + + @staticmethod + @overload + def _redis_sentinel_client( + redis_url: str, redis_class: type[AsyncRedis], **kwargs: Any + ) -> AsyncRedis: ... + + @staticmethod + def _redis_sentinel_client( + redis_url: str, redis_class: Union[type[Redis], type[AsyncRedis]], **kwargs: Any + ) -> Union[Redis, AsyncRedis]: + sentinel_list, service_name, db, username, password = ( + RedisConnectionFactory._parse_sentinel_url(redis_url) + ) + + sentinel_kwargs = {} + if username: + sentinel_kwargs["username"] = username + kwargs["username"] = username + if password: + sentinel_kwargs["password"] = password + kwargs["password"] = password + if db: + kwargs["db"] = db + + sentinel = Sentinel(sentinel_list, sentinel_kwargs=sentinel_kwargs, **kwargs) + return sentinel.master_for(service_name, redis_class=redis_class, **kwargs) + + @staticmethod + def _parse_sentinel_url(url: str) -> tuple: + parsed_url = urlparse(url) + hosts_part = parsed_url.netloc.split("@")[-1] + sentinel_hosts = hosts_part.split(",") + + sentinel_list = [] + for host in sentinel_hosts: + host_parts = host.split(":") + if len(host_parts) == 2: + sentinel_list.append((host_parts[0], int(host_parts[1]))) + else: + sentinel_list.append((host_parts[0], 26379)) + + service_name = "mymaster" + db = None + if parsed_url.path: + path_parts = parsed_url.path.split("/") + service_name = path_parts[1] or "mymaster" + if len(path_parts) > 2: + db = path_parts[2] + + return sentinel_list, service_name, db, parsed_url.username, parsed_url.password diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 56943447..20c2955d 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -464,8 +464,7 @@ def test_semantic_add_and_get_relevant(semantic_session): default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system assert default_context == semantic_session.get_relevant( - "list of fruits and vegetables", - distance_threshold=0.5 + "list of fruits and vegetables", distance_threshold=0.5 ) # test tool calls can also be returned diff --git a/tests/unit/test_sentinel_url.py b/tests/unit/test_sentinel_url.py new file mode 100644 index 00000000..dd583253 --- /dev/null +++ b/tests/unit/test_sentinel_url.py @@ -0,0 +1,84 @@ +from unittest.mock import MagicMock, patch + +import pytest +from redis.exceptions import ConnectionError + +from redisvl.redis.connection import RedisConnectionFactory + + +@pytest.mark.parametrize("use_async", [False, True]) +def test_sentinel_url_connection(use_async): + sentinel_url = ( + "redis+sentinel://username:password@host1:26379,host2:26380/mymaster/0" + ) + + with patch("redisvl.redis.connection.Sentinel") as mock_sentinel: + mock_master = MagicMock() + mock_sentinel.return_value.master_for.return_value = mock_master + + if use_async: + client = RedisConnectionFactory.get_async_redis_connection(sentinel_url) + else: + client = RedisConnectionFactory.get_redis_connection(sentinel_url) + + mock_sentinel.assert_called_once() + call_args = mock_sentinel.call_args + assert call_args[0][0] == [("host1", 26379), ("host2", 26380)] + assert call_args[1]["sentinel_kwargs"] == { + "username": "username", + "password": "password", + } + + mock_sentinel.return_value.master_for.assert_called_once() + master_for_args = mock_sentinel.return_value.master_for.call_args + assert master_for_args[0][0] == "mymaster" + assert master_for_args[1]["db"] == "0" + + assert client == mock_master + + +@pytest.mark.parametrize("use_async", [False, True]) +def test_sentinel_url_connection_no_auth_no_db(use_async): + sentinel_url = "redis+sentinel://host1:26379,host2:26380/mymaster" + + with patch("redisvl.redis.connection.Sentinel") as mock_sentinel: + mock_master = MagicMock() + mock_sentinel.return_value.master_for.return_value = mock_master + + if use_async: + client = RedisConnectionFactory.get_async_redis_connection(sentinel_url) + else: + client = RedisConnectionFactory.get_redis_connection(sentinel_url) + + mock_sentinel.assert_called_once() + call_args = mock_sentinel.call_args + assert call_args[0][0] == [("host1", 26379), ("host2", 26380)] + assert ( + "sentinel_kwargs" not in call_args[1] + or call_args[1]["sentinel_kwargs"] == {} + ) + + mock_sentinel.return_value.master_for.assert_called_once() + master_for_args = mock_sentinel.return_value.master_for.call_args + assert master_for_args[0][0] == "mymaster" + assert "db" not in master_for_args[1] + + assert client == mock_master + + +@pytest.mark.parametrize("use_async", [False, True]) +def test_sentinel_url_connection_error(use_async): + sentinel_url = "redis+sentinel://host1:26379,host2:26380/mymaster" + + with patch("redisvl.redis.connection.Sentinel") as mock_sentinel: + mock_sentinel.return_value.master_for.side_effect = ConnectionError( + "Test connection error" + ) + + with pytest.raises(ConnectionError): + if use_async: + RedisConnectionFactory.get_async_redis_connection(sentinel_url) + else: + RedisConnectionFactory.get_redis_connection(sentinel_url) + + mock_sentinel.assert_called_once()