diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index d0185760..7d841f32 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -194,6 +194,19 @@ "result_print(index.query(v))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# multiple tags\n", + "t = Tag(\"credit_score\") == [\"high\", \"medium\"]\n", + "\n", + "v.set_filter(t)\n", + "result_print(index.query(v))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -586,6 +599,32 @@ "result_print(results)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Count Queries\n", + "\n", + "In some cases, you may only need the number of records that match a ``FilterExpression``. A ``CountQuery`` returns just that count. It is similar to the ``FilterQuery`` class but does not return the values of the underlying data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.query import CountQuery\n", + "\n", + "has_low_credit = Tag(\"credit_score\") == \"low\"\n", + "\n", + "filter_query = CountQuery(filter_expression=has_low_credit)\n", + "\n", + "count = index.query(filter_query)\n", + "\n", + "print(f\"{count} records match the filter expression {str(has_low_credit)} for the given index.\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -846,7 +885,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.9.12" }, "orig_nbformat": 4, "vscode": { diff --git a/redisvl/index.py b/redisvl/index.py index 9e4b8722..fbeb53ef 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union from uuid 
import uuid4 if TYPE_CHECKING: @@ -10,6 +10,7 @@ import redis from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redisvl.query.query import CountQuery from redisvl.schema import SchemaModel, read_schema from redisvl.utils.connection import ( check_connected, @@ -52,7 +53,7 @@ def client(self) -> redis.Redis: return self._redis_conn # type: ignore @check_connected("_redis_conn") - def search(self, *args, **kwargs) -> List["Result"]: + def search(self, *args, **kwargs) -> Union["Result", Any]: """Perform a search on this index. Wrapper around redis.search.Search that adds the index name @@ -60,9 +61,9 @@ def search(self, *args, **kwargs) -> List["Result"]: to the redis-py ft.search() method. Returns: - List[Result]: A list of search results + Union["Result", Any]: Search results. """ - results: List["Result"] = self._redis_conn.ft(self._name).search( # type: ignore + results = self._redis_conn.ft(self._name).search( # type: ignore *args, **kwargs ) return results @@ -82,6 +83,8 @@ def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: List[Result]: A list of search results. """ results = self.search(query.query, query_params=query.params) + if isinstance(query, CountQuery): + return results.total return process_results(results) @classmethod @@ -522,7 +525,7 @@ async def _load(record: dict): await asyncio.gather(*[_load(record) for record in data]) @check_connected("_redis_conn") - async def search(self, *args, **kwargs) -> List["Result"]: + async def search(self, *args, **kwargs) -> Union["Result", Any]: """Perform a search on this index. Wrapper around redis.search.Search that adds the index name @@ -530,9 +533,11 @@ async def search(self, *args, **kwargs) -> List["Result"]: to the redis-py ft.search() method. Returns: - List[Result]: A list of search results. + Union["Result", Any]: Search results. 
""" - results: List["Result"] = await self._redis_conn.ft(self._name).search(*args, **kwargs) # type: ignore + results = await self._redis_conn.ft(self._name).search( # type: ignore + *args, **kwargs + ) return results async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: @@ -549,6 +554,8 @@ async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: List[Result]: A list of search results. """ results = await self.search(query.query, query_params=query.params) + if isinstance(query, CountQuery): + return results.total return process_results(results) @check_connected("_redis_conn") diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py index f5bce6eb..16227f22 100644 --- a/redisvl/query/__init__.py +++ b/redisvl/query/__init__.py @@ -1,3 +1,3 @@ -from redisvl.query.query import FilterQuery, RangeQuery, VectorQuery +from redisvl.query.query import CountQuery, FilterQuery, RangeQuery, VectorQuery -__all__ = ["VectorQuery", "FilterQuery", "RangeQuery"] +__all__ = ["VectorQuery", "FilterQuery", "RangeQuery", "CountQuery"] diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 304f228f..59d309f5 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional import numpy as np from redis.commands.search.query import Query @@ -12,6 +12,32 @@ def __init__(self, return_fields: List[str] = [], num_results: int = 10): self._return_fields = return_fields self._num_results = num_results + def __str__(self) -> str: + return " ".join([str(x) for x in self.query.get_args()]) + + def set_filter(self, filter_expression: FilterExpression): + """Set the filter for the query. + + Args: + filter_expression (FilterExpression): The filter to apply to the query. 
+ + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + """ + if not isinstance(filter_expression, FilterExpression): + raise TypeError( + "filter_expression must be of type redisvl.query.FilterExpression" + ) + self._filter = filter_expression + + def get_filter(self) -> FilterExpression: + """Get the filter for the query. + + Returns: + FilterExpression: The filter for the query. + """ + return self._filter + @property def query(self) -> "Query": raise NotImplementedError @@ -21,6 +47,54 @@ def params(self) -> Dict[str, Any]: raise NotImplementedError +class CountQuery(BaseQuery): + def __init__( + self, + filter_expression: FilterExpression, + params: Optional[Dict[str, Any]] = None, + ): + """Query for a simple count operation on a filter expression. + + Args: + filter_expression (FilterExpression): The filter expression to query for. + params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. + + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + + Examples: + >>> from redisvl.query import CountQuery + >>> from redisvl.query.filter import Tag + >>> t = Tag("brand") == "Nike" + >>> q = CountQuery(filter_expression=t) + >>> count = index.query(q) + """ + self.set_filter(filter_expression) + self._params = params + + @property + def query(self) -> Query: + """Return a Redis-Py Query object representing the query. + + Returns: + redis.commands.search.query.Query: The query object. + """ + base_query = str(self._filter) + query = Query(base_query).no_content().dialect(2) + return query + + @property + def params(self) -> Dict[str, Any]: + """Return the parameters for the query. + + Returns: + Dict[str, Any]: The parameters for the query. 
+ """ + if not self._params: + self._params = {} + return self._params + + class FilterQuery(BaseQuery): def __init__( self, @@ -51,32 +125,6 @@ def __init__( self.set_filter(filter_expression) self._params = params - def __str__(self) -> str: - return " ".join([str(x) for x in self.query.get_args()]) - - def set_filter(self, filter_expression: FilterExpression): - """Set the filter for the query. - - Args: - filter_expression (FilterExpression): The filter to apply to the query. - - Raises: - TypeError: If filter_expression is not of type redisvl.query.FilterExpression - """ - if not isinstance(filter_expression, FilterExpression): - raise TypeError( - "filter_expression must be of type redisvl.query.FilterExpression" - ) - self._filter = filter_expression - - def get_filter(self) -> FilterExpression: - """Get the filter for the query. - - Returns: - FilterExpression: The filter for the query. - """ - return self._filter - @property def query(self) -> Query: """Return a Redis-Py Query object representing the query. @@ -127,36 +175,14 @@ def __init__( self._vector = vector self._field = vector_field_name self._dtype = dtype.lower() - self._filter = filter_expression + self._filter = filter_expression # type: ignore + if filter_expression: self.set_filter(filter_expression) if return_score: self._return_fields.append(self.DISTANCE_ID) - def set_filter(self, filter_expression: FilterExpression): - """Set the filter for the query. - - Args: - filter_expression (FilterExpression): The filter to apply to the query. - """ - if not isinstance(filter_expression, FilterExpression): - raise TypeError( - "filter_expression must be of type redisvl.query.FilterExpression" - ) - self._filter = filter_expression - - def get_filter(self) -> Optional[FilterExpression]: - """Get the filter for the query. - - Returns: - Optional[FilterExpression]: The filter for the query. 
- """ - return self._filter - - def __str__(self): - return " ".join([str(x) for x in self.query.get_args()]) - class VectorQuery(BaseVectorQuery): def __init__( diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index e94718ef..ca59f2fb 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -5,8 +5,8 @@ from redis.commands.search.result import Result from redisvl.index import SearchIndex -from redisvl.query import FilterQuery, RangeQuery, VectorQuery -from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text +from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery +from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text data = [ { @@ -171,6 +171,16 @@ def test_range_query(index): assert len(results) == 2 +def test_count_query(index): + c = CountQuery(FilterExpression("*")) + results = index.query(c) + assert results == len(data) + + c = CountQuery(Tag("credit_score") == "high") + results = index.query(c) + assert results == 4 + + vector_query = VectorQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding",