Skip to content

Commit

Permalink
Add support for CountQuery (#65)
Browse files Browse the repository at this point in the history
Adds support for a simple `CountQuery` which expects a
`FilterExpression` and allows users to check how many records match a
particular set of filters. Also adds documentation and examples for
multiple tag fields and count queries.
  • Loading branch information
tylerhutcherson authored Oct 16, 2023
1 parent cf630f0 commit 266de35
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 63 deletions.
41 changes: 40 additions & 1 deletion docs/user_guide/hybrid_queries_02.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,19 @@
"result_print(index.query(v))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# multiple tags\n",
"t = Tag(\"credit_score\") == [\"high\", \"medium\"]\n",
"\n",
"v.set_filter(t)\n",
"result_print(index.query(v))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -586,6 +599,32 @@
"result_print(results)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Count Queries\n",
"\n",
"In some cases, you may need to use a ``FilterExpression`` to execute a ``CountQuery`` that simply returns the count of the number of entities in the pertaining set. It is similar to the ``FilterQuery`` class but does not return the values of the underlying data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from redisvl.query import CountQuery\n",
"\n",
"has_low_credit = Tag(\"credit_score\") == \"low\"\n",
"\n",
"filter_query = CountQuery(filter_expression=has_low_credit)\n",
"\n",
"count = index.query(filter_query)\n",
"\n",
"print(f\"{count} records match the filter expression {str(has_low_credit)} for the given index.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -846,7 +885,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
21 changes: 14 additions & 7 deletions redisvl/index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
from uuid import uuid4

if TYPE_CHECKING:
Expand All @@ -10,6 +10,7 @@
import redis
from redis.commands.search.indexDefinition import IndexDefinition, IndexType

from redisvl.query.query import CountQuery
from redisvl.schema import SchemaModel, read_schema
from redisvl.utils.connection import (
check_connected,
Expand Down Expand Up @@ -52,17 +53,17 @@ def client(self) -> redis.Redis:
return self._redis_conn # type: ignore

@check_connected("_redis_conn")
def search(self, *args, **kwargs) -> List["Result"]:
def search(self, *args, **kwargs) -> Union["Result", Any]:
"""Perform a search on this index.
Wrapper around redis.search.Search that adds the index name
to the search query and passes along the rest of the arguments
to the redis-py ft.search() method.
Returns:
List[Result]: A list of search results
Union["Result", Any]: Search results.
"""
results: List["Result"] = self._redis_conn.ft(self._name).search( # type: ignore
results = self._redis_conn.ft(self._name).search( # type: ignore
*args, **kwargs
)
return results
Expand All @@ -82,6 +83,8 @@ def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
List[Result]: A list of search results.
"""
results = self.search(query.query, query_params=query.params)
if isinstance(query, CountQuery):
return results.total
return process_results(results)

@classmethod
Expand Down Expand Up @@ -522,17 +525,19 @@ async def _load(record: dict):
await asyncio.gather(*[_load(record) for record in data])

@check_connected("_redis_conn")
async def search(self, *args, **kwargs) -> List["Result"]:
async def search(self, *args, **kwargs) -> Union["Result", Any]:
"""Perform a search on this index.
Wrapper around redis.search.Search that adds the index name
to the search query and passes along the rest of the arguments
to the redis-py ft.search() method.
Returns:
List[Result]: A list of search results.
Union["Result", Any]: Search results.
"""
results: List["Result"] = await self._redis_conn.ft(self._name).search(*args, **kwargs) # type: ignore
results = await self._redis_conn.ft(self._name).search( # type: ignore
*args, **kwargs
)
return results

async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
Expand All @@ -549,6 +554,8 @@ async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
List[Result]: A list of search results.
"""
results = await self.search(query.query, query_params=query.params)
if isinstance(query, CountQuery):
return results.total
return process_results(results)

@check_connected("_redis_conn")
Expand Down
4 changes: 2 additions & 2 deletions redisvl/query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from redisvl.query.query import FilterQuery, RangeQuery, VectorQuery
from redisvl.query.query import CountQuery, FilterQuery, RangeQuery, VectorQuery

__all__ = ["VectorQuery", "FilterQuery", "RangeQuery"]
__all__ = ["VectorQuery", "FilterQuery", "RangeQuery", "CountQuery"]
128 changes: 77 additions & 51 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import numpy as np
from redis.commands.search.query import Query
Expand All @@ -12,6 +12,32 @@ def __init__(self, return_fields: List[str] = [], num_results: int = 10):
self._return_fields = return_fields
self._num_results = num_results

def __str__(self) -> str:
return " ".join([str(x) for x in self.query.get_args()])

def set_filter(self, filter_expression: FilterExpression):
"""Set the filter for the query.
Args:
filter_expression (FilterExpression): The filter to apply to the query.
Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
"""
if not isinstance(filter_expression, FilterExpression):
raise TypeError(
"filter_expression must be of type redisvl.query.FilterExpression"
)
self._filter = filter_expression

def get_filter(self) -> FilterExpression:
"""Get the filter for the query.
Returns:
FilterExpression: The filter for the query.
"""
return self._filter

@property
def query(self) -> "Query":
raise NotImplementedError
Expand All @@ -21,6 +47,54 @@ def params(self) -> Dict[str, Any]:
raise NotImplementedError


class CountQuery(BaseQuery):
def __init__(
self,
filter_expression: FilterExpression,
params: Optional[Dict[str, Any]] = None,
):
"""Query for a simple count operation on a filter expression.
Args:
filter_expression (FilterExpression): The filter expression to query for.
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
Examples:
>>> from redisvl.query import CountQuery
>>> from redisvl.query.filter import Tag
>>> t = Tag("brand") == "Nike"
>>> q = CountQuery(filter_expression=t)
>>> count = index.query(q)
"""
self.set_filter(filter_expression)
self._params = params

@property
def query(self) -> Query:
"""Return a Redis-Py Query object representing the query.
Returns:
redis.commands.search.query.Query: The query object.
"""
base_query = str(self._filter)
query = Query(base_query).no_content().dialect(2)
return query

@property
def params(self) -> Dict[str, Any]:
"""Return the parameters for the query.
Returns:
Dict[str, Any]: The parameters for the query.
"""
if not self._params:
self._params = {}
return self._params


class FilterQuery(BaseQuery):
def __init__(
self,
Expand Down Expand Up @@ -51,32 +125,6 @@ def __init__(
self.set_filter(filter_expression)
self._params = params

def __str__(self) -> str:
return " ".join([str(x) for x in self.query.get_args()])

def set_filter(self, filter_expression: FilterExpression):
"""Set the filter for the query.
Args:
filter_expression (FilterExpression): The filter to apply to the query.
Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
"""
if not isinstance(filter_expression, FilterExpression):
raise TypeError(
"filter_expression must be of type redisvl.query.FilterExpression"
)
self._filter = filter_expression

def get_filter(self) -> FilterExpression:
"""Get the filter for the query.
Returns:
FilterExpression: The filter for the query.
"""
return self._filter

@property
def query(self) -> Query:
"""Return a Redis-Py Query object representing the query.
Expand Down Expand Up @@ -127,36 +175,14 @@ def __init__(
self._vector = vector
self._field = vector_field_name
self._dtype = dtype.lower()
self._filter = filter_expression
self._filter = filter_expression # type: ignore

if filter_expression:
self.set_filter(filter_expression)

if return_score:
self._return_fields.append(self.DISTANCE_ID)

def set_filter(self, filter_expression: FilterExpression):
"""Set the filter for the query.
Args:
filter_expression (FilterExpression): The filter to apply to the query.
"""
if not isinstance(filter_expression, FilterExpression):
raise TypeError(
"filter_expression must be of type redisvl.query.FilterExpression"
)
self._filter = filter_expression

def get_filter(self) -> Optional[FilterExpression]:
"""Get the filter for the query.
Returns:
Optional[FilterExpression]: The filter for the query.
"""
return self._filter

def __str__(self):
return " ".join([str(x) for x in self.query.get_args()])


class VectorQuery(BaseVectorQuery):
def __init__(
Expand Down
14 changes: 12 additions & 2 deletions tests/integration/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from redis.commands.search.result import Result

from redisvl.index import SearchIndex
from redisvl.query import FilterQuery, RangeQuery, VectorQuery
from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text
from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text

data = [
{
Expand Down Expand Up @@ -171,6 +171,16 @@ def test_range_query(index):
assert len(results) == 2


def test_count_query(index):
c = CountQuery(FilterExpression("*"))
results = index.query(c)
assert results == len(data)

c = CountQuery(Tag("credit_score") == "high")
results = index.query(c)
assert results == 4


vector_query = VectorQuery(
vector=[0.1, 0.1, 0.5],
vector_field_name="user_embedding",
Expand Down

0 comments on commit 266de35

Please sign in to comment.