Skip to content

Commit

Permalink
Merge branch 'main' into chayim-test-matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Partee authored Nov 1, 2023
2 parents 122841a + 4e3de2d commit 7515d9a
Show file tree
Hide file tree
Showing 17 changed files with 676 additions and 141 deletions.
7 changes: 7 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def skip_vectorizer():
def openai_key():
return os.getenv("OPENAI_API_KEY")

@pytest.fixture
def gcp_location():
return os.getenv("GCP_LOCATION")

@pytest.fixture
def gcp_project_id():
return os.getenv("GCP_PROJECT_ID")

@pytest.fixture(scope="session")
def event_loop():
Expand Down
41 changes: 40 additions & 1 deletion docs/user_guide/hybrid_queries_02.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,19 @@
"result_print(index.query(v))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# multiple tags\n",
"t = Tag(\"credit_score\") == [\"high\", \"medium\"]\n",
"\n",
"v.set_filter(t)\n",
"result_print(index.query(v))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -586,6 +599,32 @@
"result_print(results)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Count Queries\n",
"\n",
"In some cases, you may need to use a ``FilterExpression`` to execute a ``CountQuery`` that simply returns the count of entities matching the filter expression. It is similar to the ``FilterQuery`` class but does not return the values of the underlying data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from redisvl.query import CountQuery\n",
"\n",
"has_low_credit = Tag(\"credit_score\") == \"low\"\n",
"\n",
"filter_query = CountQuery(filter_expression=has_low_credit)\n",
"\n",
"count = index.query(filter_query)\n",
"\n",
"print(f\"{count} records match the filter expression {str(has_low_credit)} for the given index.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -846,7 +885,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
119 changes: 97 additions & 22 deletions redisvl/index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union
from uuid import uuid4

if TYPE_CHECKING:
Expand All @@ -10,6 +10,7 @@
import redis
from redis.commands.search.indexDefinition import IndexDefinition, IndexType

from redisvl.query.query import CountQuery
from redisvl.schema import SchemaModel, read_schema
from redisvl.utils.connection import (
check_connected,
Expand Down Expand Up @@ -52,17 +53,17 @@ def client(self) -> redis.Redis:
return self._redis_conn # type: ignore

@check_connected("_redis_conn")
def search(self, *args, **kwargs) -> List["Result"]:
def search(self, *args, **kwargs) -> Union["Result", Any]:
"""Perform a search on this index.
Wrapper around redis.search.Search that adds the index name
to the search query and passes along the rest of the arguments
to the redis-py ft.search() method.
Returns:
List[Result]: A list of search results
Union["Result", Any]: Search results.
"""
results: List["Result"] = self._redis_conn.ft(self._name).search( # type: ignore
results = self._redis_conn.ft(self._name).search( # type: ignore
*args, **kwargs
)
return results
Expand All @@ -82,6 +83,8 @@ def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
List[Result]: A list of search results.
"""
results = self.search(query.query, query_params=query.params)
if isinstance(query, CountQuery):
return results.total
return process_results(results)

@classmethod
Expand Down Expand Up @@ -148,7 +151,24 @@ def disconnect(self):
"""Disconnect from the Redis instance"""
self._redis_conn = None

def _get_key(self, record: Dict[str, Any], key_field: Optional[str] = None) -> str:
def key(self, key_value: str) -> str:
"""
Create a redis key as a combination of an index key prefix (optional) and specified key value.
The key value is typically a unique identifier, created at random, or derived from
some specified metadata.
Args:
key_value (str): The specified unique identifier for a particular document
indexed in Redis.
Returns:
str: The full Redis key including key prefix and value as a string.
"""
return f"{self._prefix}:{key_value}" if self._prefix else key_value

def _create_key(
self, record: Dict[str, Any], key_field: Optional[str] = None
) -> str:
"""Construct the Redis HASH top level key.
Args:
Expand All @@ -163,13 +183,13 @@ def _get_key(self, record: Dict[str, Any], key_field: Optional[str] = None) -> s
ValueError: If the key field is not found in the record.
"""
if key_field is None:
key = uuid4().hex
key_value = uuid4().hex
else:
try:
key = record[key_field] # type: ignore
key_value = record[key_field] # type: ignore
except KeyError:
raise ValueError(f"Key field {key_field} not found in record {record}")
return f"{self._prefix}:{key}" if self._prefix else key
return self.key(key_value)

@check_connected("_redis_conn")
def info(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -197,13 +217,17 @@ def delete(self, drop: bool = True):
Args:
drop (bool, optional): Delete the documents in the index. Defaults to True.
raises:
Raises:
redis.exceptions.ResponseError: If the index does not exist.
"""
raise NotImplementedError

def load(
self, data: Iterable[Dict[str, Any]], key_field: Optional[str] = None, **kwargs
self,
data: Iterable[Dict[str, Any]],
key_field: Optional[str] = None,
preprocess: Optional[Callable] = None,
**kwargs,
):
"""Load data into Redis and index using this SearchIndex object.
Expand All @@ -212,8 +236,10 @@ def load(
containing the data to be indexed.
key_field (Optional[str], optional): A field within the record
to use in the Redis hash key.
preprocess (Optional[Callable], optional): An optional preprocessor function
that mutates the individual record before writing to redis.
raises:
Raises:
redis.exceptions.ResponseError: If the index does not exist.
"""
raise NotImplementedError
Expand Down Expand Up @@ -337,7 +363,11 @@ def delete(self, drop: bool = True):

@check_connected("_redis_conn")
def load(
self, data: Iterable[Dict[str, Any]], key_field: Optional[str] = None, **kwargs
self,
data: Iterable[Dict[str, Any]],
key_field: Optional[str] = None,
preprocess: Optional[Callable] = None,
**kwargs,
):
"""Load data into Redis and index using this SearchIndex object.
Expand All @@ -346,9 +376,16 @@ def load(
containing the data to be indexed.
key_field (Optional[str], optional): A field within the record to
use in the Redis hash key.
preprocess (Optional[Callable], optional): An optional preprocessor function
that mutates the individual record before writing to redis.
Raises:
redis.exceptions.ResponseError: If the index does not exist.
Example:
>>> data = [{"foo": "bar"}, {"test": "values"}]
>>> def func(record: dict): record["new"]="value";return record
>>> index.load(data, preprocess=func)
"""
# TODO -- should we return a count of the upserts? or some kind of metadata?
if data:
Expand All @@ -360,7 +397,20 @@ def load(
ttl = kwargs.get("ttl")
with self._redis_conn.pipeline(transaction=False) as pipe: # type: ignore
for record in data:
key = self._get_key(record, key_field)
key = self._create_key(record, key_field)
# Optionally preprocess the record and validate type
if preprocess:
try:
record = preprocess(record)
except Exception as e:
raise RuntimeError(
"Error while preprocessing records on load"
) from e
if not isinstance(record, dict):
raise TypeError(
f"Individual records must be of type dict, got type {type(record)}"
)
# Write the record to Redis
pipe.hset(key, mapping=record) # type: ignore
if ttl:
pipe.expire(key, ttl)
Expand All @@ -386,8 +436,8 @@ class AsyncSearchIndex(SearchIndexBase):
Example:
>>> from redisvl.index import AsyncSearchIndex
>>> index = AsyncSearchIndex.from_yaml("schema.yaml")
>>> index.create(overwrite=True)
>>> index.load(data) # data is an iterable of dictionaries
>>> await index.create(overwrite=True)
>>> await index.load(data) # data is an iterable of dictionaries
"""

def __init__(
Expand Down Expand Up @@ -482,7 +532,7 @@ async def delete(self, drop: bool = True):
Args:
drop (bool, optional): Delete the documents in the index. Defaults to True.
raises:
Raises:
redis.exceptions.ResponseError: If the index does not exist.
"""
# Delete the search index
Expand All @@ -494,6 +544,7 @@ async def load(
data: Iterable[Dict[str, Any]],
concurrency: int = 10,
key_field: Optional[str] = None,
preprocess: Optional[Callable] = None,
**kwargs,
):
"""Load data into Redis and index using this SearchIndex object.
Expand All @@ -504,35 +555,57 @@ async def load(
concurrency (int, optional): Number of concurrent tasks to run. Defaults to 10.
key_field (Optional[str], optional): A field within the record to
use in the Redis hash key.
preprocess (Optional[Callable], optional): An optional preprocessor function
that mutates the individual record before writing to redis.
raises:
Raises:
redis.exceptions.ResponseError: If the index does not exist.
Example:
>>> data = [{"foo": "bar"}, {"test": "values"}]
>>> def func(record: dict): record["new"]="value";return record
>>> await index.load(data, preprocess=func)
"""
ttl = kwargs.get("ttl")
semaphore = asyncio.Semaphore(concurrency)

async def _load(record: dict):
async with semaphore:
key = self._get_key(record, key_field)
key = self._create_key(record, key_field)
# Optionally preprocess the record and validate type
if preprocess:
try:
record = preprocess(record)
except Exception as e:
raise RuntimeError(
"Error while preprocessing records on load"
) from e
if not isinstance(record, dict):
raise TypeError(
f"Individual records must be of type dict, got type {type(record)}"
)
# Write the record to Redis
await self._redis_conn.hset(key, mapping=record) # type: ignore
if ttl:
await self._redis_conn.expire(key, ttl) # type: ignore

# gather with concurrency
# Gather with concurrency
await asyncio.gather(*[_load(record) for record in data])

@check_connected("_redis_conn")
async def search(self, *args, **kwargs) -> List["Result"]:
async def search(self, *args, **kwargs) -> Union["Result", Any]:
"""Perform a search on this index.
Wrapper around redis.search.Search that adds the index name
to the search query and passes along the rest of the arguments
to the redis-py ft.search() method.
Returns:
List[Result]: A list of search results.
Union["Result", Any]: Search results.
"""
results: List["Result"] = await self._redis_conn.ft(self._name).search(*args, **kwargs) # type: ignore
results = await self._redis_conn.ft(self._name).search( # type: ignore
*args, **kwargs
)
return results

async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
Expand All @@ -549,6 +622,8 @@ async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
List[Result]: A list of search results.
"""
results = await self.search(query.query, query_params=query.params)
if isinstance(query, CountQuery):
return results.total
return process_results(results)

@check_connected("_redis_conn")
Expand Down
4 changes: 2 additions & 2 deletions redisvl/query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from redisvl.query.query import FilterQuery, RangeQuery, VectorQuery
from redisvl.query.query import CountQuery, FilterQuery, RangeQuery, VectorQuery

__all__ = ["VectorQuery", "FilterQuery", "RangeQuery"]
__all__ = ["VectorQuery", "FilterQuery", "RangeQuery", "CountQuery"]
2 changes: 1 addition & 1 deletion redisvl/query/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union

from redisvl.utils.utils import TokenEscaper
from redisvl.utils.token_escaper import TokenEscaper

# disable mypy error for dunder method overrides
# mypy: disable-error-code="override"
Expand Down
Loading

0 comments on commit 7515d9a

Please sign in to comment.