Skip to content

Commit

Permalink
refactor connection check modules
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Feb 7, 2024
1 parent d1442e3 commit e3e107e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
2 changes: 1 addition & 1 deletion redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def store(
key = cache.store(
prompt="What is the captial city of France?",
response="Paris",
metadata={"city": "Paris", "country": "Fance"}
metadata={"city": "Paris", "country": "France"}
)
"""
# Vectorize prompt if necessary and create cache payload
Expand Down
28 changes: 21 additions & 7 deletions redisvl/redis/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Any, Dict, List, Optional

from redis import ConnectionPool, Redis
from redis.asyncio import Redis as AsyncRedis
Expand Down Expand Up @@ -101,7 +101,10 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi
return AsyncRedis.from_url(get_address_from_env(), **kwargs)

@staticmethod
def validate_redis_modules(client: Redis) -> None:
def validate_redis_modules(
client: Redis,
redis_required_modules: Optional[List[Dict[str, Any]]] = None
) -> None:
"""Validates if the required Redis modules are installed.
Args:
Expand All @@ -111,11 +114,14 @@ def validate_redis_modules(client: Redis) -> None:
ValueError: If required Redis modules are not installed.
"""
RedisConnectionFactory._validate_redis_modules(
convert_bytes(client.module_list())
convert_bytes(client.module_list()), redis_required_modules
)

@staticmethod
def validate_async_redis_modules(client: AsyncRedis) -> None:
def validate_async_redis_modules(
client: AsyncRedis,
redis_required_modules: Optional[List[Dict[str, Any]]] = None
) -> None:
"""
Validates if the required Redis modules are installed.
Expand All @@ -128,21 +134,29 @@ def validate_async_redis_modules(client: AsyncRedis) -> None:
temp_client = Redis(
connection_pool=ConnectionPool(**client.connection_pool.connection_kwargs)
)
RedisConnectionFactory.validate_redis_modules(temp_client)
RedisConnectionFactory.validate_redis_modules(
temp_client, redis_required_modules
)

@staticmethod
def _validate_redis_modules(installed_modules) -> None:
def _validate_redis_modules(
installed_modules,
redis_required_modules: Optional[List[Dict[str, Any]]] = None
) -> None:
"""
Validates if required Redis modules are installed.
Args:
installed_modules: List of installed modules.
redis_required_modules: List of required modules.
Raises:
ValueError: If required Redis modules are not installed.
"""
installed_modules = {module["name"]: module for module in installed_modules}
for required_module in REDIS_REQUIRED_MODULES:
redis_required_modules = redis_required_modules or REDIS_REQUIRED_MODULES

for required_module in redis_required_modules:
if required_module["name"] in installed_modules:
installed_version = installed_modules[required_module["name"]]["ver"]
if int(installed_version) >= int(required_module["ver"]): # type: ignore
Expand Down

0 comments on commit e3e107e

Please sign in to comment.