From 266de35ce9a3706628f266c884e1d13c908c1ea3 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 16 Oct 2023 14:25:18 -0400 Subject: [PATCH 1/5] Add support for CountQuery (#65) Adds support for a simple `CountQuery` which expects a `FilterExpression` and allows users to check how many records match a particular set of filters. Also adds documentation and examples for multiple tag fields and count queries. --- docs/user_guide/hybrid_queries_02.ipynb | 41 +++++++- redisvl/index.py | 21 ++-- redisvl/query/__init__.py | 4 +- redisvl/query/query.py | 128 ++++++++++++++---------- tests/integration/test_query.py | 14 ++- 5 files changed, 145 insertions(+), 63 deletions(-) diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index d0185760..7d841f32 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -194,6 +194,19 @@ "result_print(index.query(v))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# multiple tags\n", + "t = Tag(\"credit_score\") == [\"high\", \"medium\"]\n", + "\n", + "v.set_filter(t)\n", + "result_print(index.query(v))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -586,6 +599,32 @@ "result_print(results)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Count Queries\n", + "\n", + "In some cases, you may need to use a ``FilterExpression`` to execute a ``CountQuery`` that simply returns the count of the number of entities in the pertaining set. It is similar to the ``FilterQuery`` class but does not return the values of the underlying data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.query import CountQuery\n", + "\n", + "has_low_credit = Tag(\"credit_score\") == \"low\"\n", + "\n", + "filter_query = CountQuery(filter_expression=has_low_credit)\n", + "\n", + "count = index.query(filter_query)\n", + "\n", + "print(f\"{count} records match the filter expression {str(has_low_credit)} for the given index.\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -846,7 +885,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.9.12" }, "orig_nbformat": 4, "vscode": { diff --git a/redisvl/index.py b/redisvl/index.py index 9e4b8722..fbeb53ef 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union from uuid import uuid4 if TYPE_CHECKING: @@ -10,6 +10,7 @@ import redis from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redisvl.query.query import CountQuery from redisvl.schema import SchemaModel, read_schema from redisvl.utils.connection import ( check_connected, @@ -52,7 +53,7 @@ def client(self) -> redis.Redis: return self._redis_conn # type: ignore @check_connected("_redis_conn") - def search(self, *args, **kwargs) -> List["Result"]: + def search(self, *args, **kwargs) -> Union["Result", Any]: """Perform a search on this index. Wrapper around redis.search.Search that adds the index name @@ -60,9 +61,9 @@ def search(self, *args, **kwargs) -> List["Result"]: to the redis-py ft.search() method. Returns: - List[Result]: A list of search results + Union["Result", Any]: Search results. 
""" - results: List["Result"] = self._redis_conn.ft(self._name).search( # type: ignore + results = self._redis_conn.ft(self._name).search( # type: ignore *args, **kwargs ) return results @@ -82,6 +83,8 @@ def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: List[Result]: A list of search results. """ results = self.search(query.query, query_params=query.params) + if isinstance(query, CountQuery): + return results.total return process_results(results) @classmethod @@ -522,7 +525,7 @@ async def _load(record: dict): await asyncio.gather(*[_load(record) for record in data]) @check_connected("_redis_conn") - async def search(self, *args, **kwargs) -> List["Result"]: + async def search(self, *args, **kwargs) -> Union["Result", Any]: """Perform a search on this index. Wrapper around redis.search.Search that adds the index name @@ -530,9 +533,11 @@ async def search(self, *args, **kwargs) -> List["Result"]: to the redis-py ft.search() method. Returns: - List[Result]: A list of search results. + Union["Result", Any]: Search results. """ - results: List["Result"] = await self._redis_conn.ft(self._name).search(*args, **kwargs) # type: ignore + results = await self._redis_conn.ft(self._name).search( # type: ignore + *args, **kwargs + ) return results async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: @@ -549,6 +554,8 @@ async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: List[Result]: A list of search results. 
""" results = await self.search(query.query, query_params=query.params) + if isinstance(query, CountQuery): + return results.total return process_results(results) @check_connected("_redis_conn") diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py index f5bce6eb..16227f22 100644 --- a/redisvl/query/__init__.py +++ b/redisvl/query/__init__.py @@ -1,3 +1,3 @@ -from redisvl.query.query import FilterQuery, RangeQuery, VectorQuery +from redisvl.query.query import CountQuery, FilterQuery, RangeQuery, VectorQuery -__all__ = ["VectorQuery", "FilterQuery", "RangeQuery"] +__all__ = ["VectorQuery", "FilterQuery", "RangeQuery", "CountQuery"] diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 304f228f..59d309f5 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional import numpy as np from redis.commands.search.query import Query @@ -12,6 +12,32 @@ def __init__(self, return_fields: List[str] = [], num_results: int = 10): self._return_fields = return_fields self._num_results = num_results + def __str__(self) -> str: + return " ".join([str(x) for x in self.query.get_args()]) + + def set_filter(self, filter_expression: FilterExpression): + """Set the filter for the query. + + Args: + filter_expression (FilterExpression): The filter to apply to the query. + + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + """ + if not isinstance(filter_expression, FilterExpression): + raise TypeError( + "filter_expression must be of type redisvl.query.FilterExpression" + ) + self._filter = filter_expression + + def get_filter(self) -> FilterExpression: + """Get the filter for the query. + + Returns: + FilterExpression: The filter for the query. 
+ """ + return self._filter + @property def query(self) -> "Query": raise NotImplementedError @@ -21,6 +47,54 @@ def params(self) -> Dict[str, Any]: raise NotImplementedError +class CountQuery(BaseQuery): + def __init__( + self, + filter_expression: FilterExpression, + params: Optional[Dict[str, Any]] = None, + ): + """Query for a simple count operation on a filter expression. + + Args: + filter_expression (FilterExpression): The filter expression to query for. + params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. + + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + + Examples: + >>> from redisvl.query import CountQuery + >>> from redisvl.query.filter import Tag + >>> t = Tag("brand") == "Nike" + >>> q = CountQuery(filter_expression=t) + >>> count = index.query(q) + """ + self.set_filter(filter_expression) + self._params = params + + @property + def query(self) -> Query: + """Return a Redis-Py Query object representing the query. + + Returns: + redis.commands.search.query.Query: The query object. + """ + base_query = str(self._filter) + query = Query(base_query).no_content().dialect(2) + return query + + @property + def params(self) -> Dict[str, Any]: + """Return the parameters for the query. + + Returns: + Dict[str, Any]: The parameters for the query. + """ + if not self._params: + self._params = {} + return self._params + + class FilterQuery(BaseQuery): def __init__( self, @@ -51,32 +125,6 @@ def __init__( self.set_filter(filter_expression) self._params = params - def __str__(self) -> str: - return " ".join([str(x) for x in self.query.get_args()]) - - def set_filter(self, filter_expression: FilterExpression): - """Set the filter for the query. - - Args: - filter_expression (FilterExpression): The filter to apply to the query. 
- - Raises: - TypeError: If filter_expression is not of type redisvl.query.FilterExpression - """ - if not isinstance(filter_expression, FilterExpression): - raise TypeError( - "filter_expression must be of type redisvl.query.FilterExpression" - ) - self._filter = filter_expression - - def get_filter(self) -> FilterExpression: - """Get the filter for the query. - - Returns: - FilterExpression: The filter for the query. - """ - return self._filter - @property def query(self) -> Query: """Return a Redis-Py Query object representing the query. @@ -127,36 +175,14 @@ def __init__( self._vector = vector self._field = vector_field_name self._dtype = dtype.lower() - self._filter = filter_expression + self._filter = filter_expression # type: ignore + if filter_expression: self.set_filter(filter_expression) if return_score: self._return_fields.append(self.DISTANCE_ID) - def set_filter(self, filter_expression: FilterExpression): - """Set the filter for the query. - - Args: - filter_expression (FilterExpression): The filter to apply to the query. - """ - if not isinstance(filter_expression, FilterExpression): - raise TypeError( - "filter_expression must be of type redisvl.query.FilterExpression" - ) - self._filter = filter_expression - - def get_filter(self) -> Optional[FilterExpression]: - """Get the filter for the query. - - Returns: - Optional[FilterExpression]: The filter for the query. 
- """ - return self._filter - - def __str__(self): - return " ".join([str(x) for x in self.query.get_args()]) - class VectorQuery(BaseVectorQuery): def __init__( diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index e94718ef..ca59f2fb 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -5,8 +5,8 @@ from redis.commands.search.result import Result from redisvl.index import SearchIndex -from redisvl.query import FilterQuery, RangeQuery, VectorQuery -from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text +from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery +from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text data = [ { @@ -171,6 +171,16 @@ def test_range_query(index): assert len(results) == 2 +def test_count_query(index): + c = CountQuery(FilterExpression("*")) + results = index.query(c) + assert results == len(data) + + c = CountQuery(Tag("credit_score") == "high") + results = index.query(c) + assert results == 4 + + vector_query = VectorQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", From 4aab8b79e127fcb805e38aab424b8ed7bce3c71d Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 18 Oct 2023 10:58:18 -0400 Subject: [PATCH 2/5] Update key prefix and access method (#66) Adds an explicit `key` method help function that combines the index key prefix + in addition to the records identifier to create the key name. 
--- redisvl/index.py | 29 +++++++++++++++++++++++------ tests/test_index.py | 14 +++++++++++--- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/redisvl/index.py b/redisvl/index.py index fbeb53ef..760a0b76 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -151,7 +151,24 @@ def disconnect(self): """Disconnect from the Redis instance""" self._redis_conn = None - def _get_key(self, record: Dict[str, Any], key_field: Optional[str] = None) -> str: + def key(self, key_value: str) -> str: + """ + Create a redis key as a combination of an index key prefix (optional) and specified key value. + The key value is typically a unique identifier, created at random, or derived from + some specified metadata. + + Args: + key_value (str): The specified unique identifier for a particular document + indexed in Redis. + + Returns: + str: The full Redis key including key prefix and value as a string. + """ + return f"{self._prefix}:{key_value}" if self._prefix else key_value + + def _create_key( + self, record: Dict[str, Any], key_field: Optional[str] = None + ) -> str: """Construct the Redis HASH top level key. Args: @@ -166,13 +183,13 @@ def _get_key(self, record: Dict[str, Any], key_field: Optional[str] = None) -> s ValueError: If the key field is not found in the record. 
""" if key_field is None: - key = uuid4().hex + key_value = uuid4().hex else: try: - key = record[key_field] # type: ignore + key_value = record[key_field] # type: ignore except KeyError: raise ValueError(f"Key field {key_field} not found in record {record}") - return f"{self._prefix}:{key}" if self._prefix else key + return self.key(key_value) @check_connected("_redis_conn") def info(self) -> Dict[str, Any]: @@ -363,7 +380,7 @@ def load( ttl = kwargs.get("ttl") with self._redis_conn.pipeline(transaction=False) as pipe: # type: ignore for record in data: - key = self._get_key(record, key_field) + key = self._create_key(record, key_field) pipe.hset(key, mapping=record) # type: ignore if ttl: pipe.expire(key, ttl) @@ -516,7 +533,7 @@ async def load( async def _load(record: dict): async with semaphore: - key = self._get_key(record, key_field) + key = self._create_key(record, key_field) await self._redis_conn.hset(key, mapping=record) # type: ignore if ttl: await self._redis_conn.expire(key, ttl) # type: ignore diff --git a/tests/test_index.py b/tests/test_index.py index bf71dde4..d31543c4 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -10,14 +10,22 @@ def test_search_index_get_key(): si = SearchIndex("my_index", fields=fields) - key = si._get_key({"id": "foo"}, "id") + key = si.key("foo") assert key.startswith(si._prefix) assert "foo" in key - key = si._get_key({"id": "foo"}) + key = si._create_key({"id": "foo"}) assert key.startswith(si._prefix) assert "foo" not in key +def test_search_index_no_prefix(): + # specify None as the prefix... 
+ si = SearchIndex("my_index", prefix=None, fields=fields) + key = si.key("foo") + assert not si._prefix + assert key == "foo" + + def test_search_index_client(client): si = SearchIndex("my_index", fields=fields) si.set_client(client) @@ -29,6 +37,7 @@ def test_search_index_create(client, redis_url): si = SearchIndex("my_index", fields=fields) si.set_client(client) si.create(overwrite=True) + assert si.exists() assert "my_index" in convert_bytes(si.client.execute_command("FT._LIST")) s1_2 = SearchIndex.from_existing("my_index", url=redis_url) @@ -40,7 +49,6 @@ def test_search_index_delete(client): si.set_client(client) si.create(overwrite=True) si.delete() - assert "my_index" not in convert_bytes(si.client.execute_command("FT._LIST")) From 5b9b0c754d5ec2a85b978143baa72512e6970bfe Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 20 Oct 2023 19:21:41 -0400 Subject: [PATCH 3/5] Improve VectorField schema default args and tests (#68) By default, the `VectorField`'s in Redis do NOT need to have the block size or initial cap args set. This change allows for those params to be set, and only included in the field args if so. Otherwise, they are ignored. Also include a small refactor on the schema classes for vectors as well as new schema unit tests. 
--- conftest.py | 7 ++ redisvl/schema.py | 46 ++++--- redisvl/vectorize/text/openai.py | 2 +- redisvl/vectorize/text/vertexai.py | 2 +- tests/integration/test_vectorizers.py | 6 +- tests/test_schema.py | 169 ++++++++++++++++++++++++++ 6 files changed, 207 insertions(+), 25 deletions(-) create mode 100644 tests/test_schema.py diff --git a/conftest.py b/conftest.py index 17d16669..144ef9ec 100644 --- a/conftest.py +++ b/conftest.py @@ -28,6 +28,13 @@ def client(): def openai_key(): return os.getenv("OPENAI_API_KEY") +@pytest.fixture +def gcp_location(): + return os.getenv("GCP_LOCATION") + +@pytest.fixture +def gcp_project_id(): + return os.getenv("GCP_PROJECT_ID") @pytest.fixture(scope="session") def event_loop(): diff --git a/redisvl/schema.py b/redisvl/schema.py index 96aced75..c3725fcc 100644 --- a/redisvl/schema.py +++ b/redisvl/schema.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union from uuid import uuid4 import yaml @@ -64,48 +64,54 @@ class BaseVectorField(BaseModel): algorithm: object = Field(...) 
datatype: str = Field(default="FLOAT32") distance_metric: str = Field(default="COSINE") + initial_cap: Optional[int] = None @validator("algorithm", "datatype", "distance_metric", pre=True, each_item=True) def uppercase_strings(cls, v): return v.upper() + def as_field(self) -> Dict[str, Any]: + field_data = { + "TYPE": self.datatype, + "DIM": self.dims, + "DISTANCE_METRIC": self.distance_metric, + } + if self.initial_cap is not None: # Only include it if it's set + field_data["INITIAL_CAP"] = self.initial_cap + return field_data + class FlatVectorField(BaseVectorField): - algorithm: object = Literal["FLAT"] + algorithm: Literal["FLAT"] = "FLAT" + block_size: Optional[int] = None def as_field(self): - return VectorField( - self.name, - self.algorithm, - { - "TYPE": self.datatype, - "DIM": self.dims, - "DISTANCE_METRIC": self.distance_metric, - }, - ) + # grab base field params and augment with flat-specific fields + field_data = super().as_field() + if self.block_size is not None: + field_data["BLOCK_SIZE"] = self.block_size + return VectorField(self.name, self.algorithm, field_data) class HNSWVectorField(BaseVectorField): - algorithm: object = Literal["HNSW"] + algorithm: Literal["HNSW"] = "HNSW" m: int = Field(default=16) ef_construction: int = Field(default=200) ef_runtime: int = Field(default=10) - epsilon: float = Field(default=0.8) + epsilon: float = Field(default=0.01) def as_field(self): - return VectorField( - self.name, - self.algorithm, + # grab base field params and augment with hnsw-specific fields + field_data = super().as_field() + field_data.update( { - "TYPE": self.datatype, - "DIM": self.dims, - "DISTANCE_METRIC": self.distance_metric, "M": self.m, "EF_CONSTRUCTION": self.ef_construction, "EF_RUNTIME": self.ef_runtime, "EPSILON": self.epsilon, - }, + } ) + return VectorField(self.name, self.algorithm, field_data) class IndexModel(BaseModel): diff --git a/redisvl/vectorize/text/openai.py b/redisvl/vectorize/text/openai.py index 1e01b077..b9a83162 
100644 --- a/redisvl/vectorize/text/openai.py +++ b/redisvl/vectorize/text/openai.py @@ -36,7 +36,7 @@ def __init__(self, model: str, api_config: Optional[Dict] = None): import openai except ImportError: raise ImportError( - "OpenAI vectorizer requires the openai library. Please install with pip install openai" + "OpenAI vectorizer requires the openai library. Please install with `pip install openai`" ) if not api_config or "api_key" not in api_config: diff --git a/redisvl/vectorize/text/vertexai.py b/redisvl/vectorize/text/vertexai.py index 9298fe9a..a96d3aa9 100644 --- a/redisvl/vectorize/text/vertexai.py +++ b/redisvl/vectorize/text/vertexai.py @@ -48,7 +48,7 @@ def __init__( except ImportError: raise ImportError( "VertexAI vectorizer requires the google-cloud-aiplatform library." - "Please install with pip install google-cloud-aiplatform>=1.26" + "Please install with `pip install google-cloud-aiplatform>=1.26`" ) self._model_client = TextEmbeddingModel.from_pretrained(model) diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index d1af3ef7..e5f7d9b2 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -10,7 +10,7 @@ @pytest.fixture(params=[HFTextVectorizer, OpenAITextVectorizer, VertexAITextVectorizer]) -def vectorizer(request, openai_key): +def vectorizer(request, openai_key, gcp_location, gcp_project_id): # Here we use actual models for integration test if request.param == HFTextVectorizer: return request.param(model="sentence-transformers/all-mpnet-base-v2") @@ -23,8 +23,8 @@ def vectorizer(request, openai_key): return request.param( model="textembedding-gecko", api_config={ - "location": os.environ["GCP_LOCATION"], - "project_id": os.environ["GCP_PROJECT_ID"], + "location": gcp_location, + "project_id": gcp_project_id, }, ) diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 00000000..b8831c78 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 
+1,169 @@ +import pytest +from pydantic import ValidationError +from redis.commands.search.field import ( + GeoField, + NumericField, + TagField, + TextField, + VectorField, +) + +from redisvl.schema import ( + FlatVectorField, + GeoFieldSchema, + HNSWVectorField, + NumericFieldSchema, + SchemaModel, + TagFieldSchema, + TextFieldSchema, + read_schema, +) + + +# Utility functions to create schema instances with default values +def create_text_field_schema(**kwargs): + defaults = {"name": "example_textfield", "sortable": False, "weight": 1.0} + defaults.update(kwargs) + return TextFieldSchema(**defaults) + + +def create_tag_field_schema(**kwargs): + defaults = {"name": "example_tagfield", "sortable": False, "separator": ","} + defaults.update(kwargs) + return TagFieldSchema(**defaults) + + +def create_numeric_field_schema(**kwargs): + defaults = {"name": "example_numericfield", "sortable": False} + defaults.update(kwargs) + return NumericFieldSchema(**defaults) + + +def create_geo_field_schema(**kwargs): + defaults = {"name": "example_geofield", "sortable": False} + defaults.update(kwargs) + return GeoFieldSchema(**defaults) + + +def create_flat_vector_field(**kwargs): + defaults = {"name": "example_flatvectorfield", "dims": 128, "algorithm": "FLAT"} + defaults.update(kwargs) + return FlatVectorField(**defaults) + + +def create_hnsw_vector_field(**kwargs): + defaults = { + "name": "example_hnswvectorfield", + "dims": 128, + "algorithm": "HNSW", + "m": 16, + "ef_construction": 200, + "ef_runtime": 10, + "epsilon": 0.01, + } + defaults.update(kwargs) + return HNSWVectorField(**defaults) + + +# Tests for field schema creation and validation +@pytest.mark.parametrize( + "schema_func,field_class", + [ + (create_text_field_schema, TextField), + (create_tag_field_schema, TagField), + (create_numeric_field_schema, NumericField), + (create_geo_field_schema, GeoField), + ], +) +def test_field_schema_as_field(schema_func, field_class): + schema = schema_func() + field = 
schema.as_field() + assert isinstance(field, field_class) + assert field.name == f"example_{field_class.__name__.lower()}" + + +def test_vector_fields_as_field(): + flat_vector_schema = create_flat_vector_field() + flat_vector_field = flat_vector_schema.as_field() + assert isinstance(flat_vector_field, VectorField) + assert flat_vector_field.name == "example_flatvectorfield" + + hnsw_vector_schema = create_hnsw_vector_field() + hnsw_vector_field = hnsw_vector_schema.as_field() + assert isinstance(hnsw_vector_field, VectorField) + assert hnsw_vector_field.name == "example_hnswvectorfield" + + +@pytest.mark.parametrize( + "vector_schema_func,extra_params", + [ + (create_flat_vector_field, {"block_size": 100}), + (create_hnsw_vector_field, {"m": 24, "ef_construction": 300}), + ], +) +def test_vector_fields_with_optional_params(vector_schema_func, extra_params): + # Create a vector schema with additional parameters set. + vector_schema = vector_schema_func(**extra_params) + vector_field = vector_schema.as_field() + + # Assert that the field is correctly created and the optional parameters are set. 
+ assert isinstance(vector_field, VectorField) + for param, value in extra_params.items(): + assert param.upper() in vector_field.args + i = vector_field.args.index(param.upper()) + assert vector_field.args[i + 1] == value + + +def test_hnsw_vector_field_optional_params_not_set(): + # Create HNSW vector field without setting optional params + hnsw_field = HNSWVectorField(name="example_vector", dims=128, algorithm="HNSW") + + assert hnsw_field.m == 16 # default value + assert hnsw_field.ef_construction == 200 # default value + assert hnsw_field.ef_runtime == 10 # default value + assert hnsw_field.epsilon == 0.01 # default value + + field_exported = hnsw_field.as_field() + + # Check the default values are correctly applied in the exported object + assert field_exported.args[field_exported.args.index("M") + 1] == 16 + assert field_exported.args[field_exported.args.index("EF_CONSTRUCTION") + 1] == 200 + assert field_exported.args[field_exported.args.index("EF_RUNTIME") + 1] == 10 + assert field_exported.args[field_exported.args.index("EPSILON") + 1] == 0.01 + + +def test_flat_vector_field_block_size_not_set(): + # Create Flat vector field without setting block_size + flat_field = FlatVectorField(name="example_vector", dims=128, algorithm="FLAT") + field_exported = flat_field.as_field() + + # block_size and initial_cap should not be in the exported field if it was not set + assert "BLOCK_SIZE" not in field_exported.args + assert "INITIAL_CAP" not in field_exported.args + + +# Test for schema model validation +def test_schema_model_validation_success(): + valid_index = {"name": "test_index", "storage_type": "hash"} + valid_fields = {"text": [create_text_field_schema()]} + schema_model = SchemaModel(index=valid_index, fields=valid_fields) + + assert schema_model.index.name == "test_index" + assert schema_model.index.storage_type == "hash" + assert len(schema_model.fields.text) == 1 + + +def test_schema_model_validation_failures(): + # Invalid storage type + with 
pytest.raises(ValueError): + invalid_index = {"name": "test_index", "storage_type": "unsupported"} + SchemaModel(index=invalid_index, fields={}) + + # Missing required field + with pytest.raises(ValidationError): + SchemaModel(index={}, fields={}) + + +def test_read_schema_file_not_found(): + with pytest.raises(FileNotFoundError): + read_schema("non_existent_file.yaml") From 6345cc1d3ef1b511cb78c2647035354e27bbf66f Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 20 Oct 2023 19:24:26 -0400 Subject: [PATCH 4/5] Add optional data load preprocessor hook (#67) Another finding from working the arxiv example.. often you need to unpack or edit a record before writing to the source. If you do this before invoking redisvl, you add an additional loop, one that is unnecessary in the end. So this allows devs to optionally add a preprocessor method on load to call against each record. --- redisvl/index.py | 71 ++++++++++++++++++++++++++++++++++++++------- tests/test_index.py | 20 +++++++++++++ 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/redisvl/index.py b/redisvl/index.py index 760a0b76..471a184e 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union from uuid import uuid4 if TYPE_CHECKING: @@ -217,13 +217,17 @@ def delete(self, drop: bool = True): Args: drop (bool, optional): Delete the documents in the index. Defaults to True. - raises: + Raises: redis.exceptions.ResponseError: If the index does not exist. """ raise NotImplementedError def load( - self, data: Iterable[Dict[str, Any]], key_field: Optional[str] = None, **kwargs + self, + data: Iterable[Dict[str, Any]], + key_field: Optional[str] = None, + preprocess: Optional[Callable] = None, + **kwargs, ): """Load data into Redis and index using this SearchIndex object. 
@@ -232,8 +236,10 @@ def load( containing the data to be indexed. key_field (Optional[str], optional): A field within the record to use in the Redis hash key. + preprocess (Optional[Callable], optional): An optional preprocessor function + that mutates the individual record before writing to redis. - raises: + Raises: redis.exceptions.ResponseError: If the index does not exist. """ raise NotImplementedError @@ -357,7 +363,11 @@ def delete(self, drop: bool = True): @check_connected("_redis_conn") def load( - self, data: Iterable[Dict[str, Any]], key_field: Optional[str] = None, **kwargs + self, + data: Iterable[Dict[str, Any]], + key_field: Optional[str] = None, + preprocess: Optional[Callable] = None, + **kwargs, ): """Load data into Redis and index using this SearchIndex object. @@ -366,9 +376,16 @@ def load( containing the data to be indexed. key_field (Optional[str], optional): A field within the record to use in the Redis hash key. + preprocess (Optional[Callable], optional): An optional preprocessor function + that mutates the individual record before writing to redis. raises: redis.exceptions.ResponseError: If the index does not exist. + + Example: + >>> data = [{"foo": "bar"}, {"test": "values"}] + >>> def func(record: dict): record["new"]="value";return record + >>> index.load(data, preprocess=func) """ # TODO -- should we return a count of the upserts? or some kind of metadata? 
if data: @@ -381,6 +398,19 @@ def load( with self._redis_conn.pipeline(transaction=False) as pipe: # type: ignore for record in data: key = self._create_key(record, key_field) + # Optionally preprocess the record and validate type + if preprocess: + try: + record = preprocess(record) + except Exception as e: + raise RuntimeError( + "Error while preprocessing records on load" + ) from e + if not isinstance(record, dict): + raise TypeError( + f"Individual records must be of type dict, got type {type(record)}" + ) + # Write the record to Redis pipe.hset(key, mapping=record) # type: ignore if ttl: pipe.expire(key, ttl) @@ -406,8 +436,8 @@ class AsyncSearchIndex(SearchIndexBase): Example: >>> from redisvl.index import AsyncSearchIndex >>> index = AsyncSearchIndex.from_yaml("schema.yaml") - >>> index.create(overwrite=True) - >>> index.load(data) # data is an iterable of dictionaries + >>> await index.create(overwrite=True) + >>> await index.load(data) # data is an iterable of dictionaries """ def __init__( @@ -502,7 +532,7 @@ async def delete(self, drop: bool = True): Args: drop (bool, optional): Delete the documents in the index. Defaults to True. - raises: + Raises: redis.exceptions.ResponseError: If the index does not exist. """ # Delete the search index @@ -514,6 +544,7 @@ async def load( data: Iterable[Dict[str, Any]], concurrency: int = 10, key_field: Optional[str] = None, + preprocess: Optional[Callable] = None, **kwargs, ): """Load data into Redis and index using this SearchIndex object. @@ -524,9 +555,16 @@ async def load( concurrency (int, optional): Number of concurrent tasks to run. Defaults to 10. key_field (Optional[str], optional): A field within the record to use in the Redis hash key. + preprocess (Optional[Callable], optional): An optional preprocessor function + that mutates the individual record before writing to redis. - raises: + Raises: redis.exceptions.ResponseError: If the index does not exist. 
+ + Example: + >>> data = [{"foo": "bar"}, {"test": "values"}] + >>> def func(record: dict): record["new"]="value";return record + >>> await index.load(data, preprocess=func) """ ttl = kwargs.get("ttl") semaphore = asyncio.Semaphore(concurrency) @@ -534,11 +572,24 @@ async def load( async def _load(record: dict): async with semaphore: key = self._create_key(record, key_field) + # Optionally preprocess the record and validate type + if preprocess: + try: + record = preprocess(record) + except Exception as e: + raise RuntimeError( + "Error while preprocessing records on load" + ) from e + if not isinstance(record, dict): + raise TypeError( + f"Individual records must be of type dict, got type {type(record)}" + ) + # Write the record to Redis await self._redis_conn.hset(key, mapping=record) # type: ignore if ttl: await self._redis_conn.expire(key, ttl) # type: ignore - # gather with concurrency + # Gather with concurrency await asyncio.gather(*[_load(record) for record in data]) @check_connected("_redis_conn") diff --git a/tests/test_index.py b/tests/test_index.py index d31543c4..733a2835 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -62,6 +62,26 @@ def test_search_index_load(client): assert convert_bytes(client.hget("rvl:1", "value")) == "test" +def test_search_index_load_preprocess(client): + si = SearchIndex("my_index", fields=fields) + si.set_client(client) + si.create(overwrite=True) + data = [{"id": "1", "value": "test"}] + + def preprocess(record): + record["test"] = "foo" + return record + + si.load(data, key_field="id", preprocess=preprocess) + assert convert_bytes(client.hget("rvl:1", "test")) == "foo" + + def bad_preprocess(record): + return 1 + + with pytest.raises(TypeError): + si.load(data, key_field="id", preprocess=bad_preprocess) + + @pytest.mark.asyncio async def test_async_search_index_creation(async_client): asi = AsyncSearchIndex("my_index", fields=fields) From 4e3de2d2dfe3d138f6597a708b1fd97409e95fbb Mon Sep 17 00:00:00 2001 From: 
Tyler Hutcherson Date: Fri, 20 Oct 2023 19:58:55 -0400 Subject: [PATCH 5/5] Add tests for token escaper class (#69) Adds a set of unit tests on both the underlying token escaping class as well as the `Tag` filterable fields that utilize it. --------- Co-authored-by: Sam Partee --- redisvl/query/filter.py | 2 +- redisvl/utils/token_escaper.py | 30 ++++++++ redisvl/utils/utils.py | 26 +------ tests/test_filter.py | 57 ++++++++++++--- tests/test_token_escaper.py | 130 +++++++++++++++++++++++++++++++++ 5 files changed, 210 insertions(+), 35 deletions(-) create mode 100644 redisvl/utils/token_escaper.py create mode 100644 tests/test_token_escaper.py diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 8a8c43c5..6e005f4d 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -2,7 +2,7 @@ from functools import wraps from typing import Any, Callable, Dict, List, Optional, Union -from redisvl.utils.utils import TokenEscaper +from redisvl.utils.token_escaper import TokenEscaper # disable mypy error for dunder method overrides # mypy: disable-error-code="override" diff --git a/redisvl/utils/token_escaper.py b/redisvl/utils/token_escaper.py new file mode 100644 index 00000000..10260866 --- /dev/null +++ b/redisvl/utils/token_escaper.py @@ -0,0 +1,30 @@ +import re +from typing import Optional, Pattern + + +class TokenEscaper: + """ + Escape punctuation within an input string. Adapted from RedisOM Python. + """ + + # Characters that RediSearch requires us to escape during queries. 
+ # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization + DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]" + + def __init__(self, escape_chars_re: Optional[Pattern] = None): + if escape_chars_re: + self.escaped_chars_re = escape_chars_re + else: + self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS) + + def escape(self, value: str) -> str: + if not isinstance(value, str): + raise TypeError( + f"Value must be a string object for token escaping, got type {type(value)}" + ) + + def escape_symbol(match): + value = match.group(0) + return f"\\{value}" + + return self.escaped_chars_re.sub(escape_symbol, value) diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index e4e3adba..f9757f73 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -1,5 +1,4 @@ -import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern +from typing import TYPE_CHECKING, Any, Dict, List if TYPE_CHECKING: from redis.commands.search.result import Result @@ -8,29 +7,6 @@ import numpy as np -class TokenEscaper: - """ - Escape punctuation within an input string. Taken from RedisOM Python. - """ - - # Characters that RediSearch requires us to escape during queries. 
- # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization - DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]" - - def __init__(self, escape_chars_re: Optional[Pattern] = None): - if escape_chars_re: - self.escaped_chars_re = escape_chars_re - else: - self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS) - - def escape(self, value: str) -> str: - def escape_symbol(match): - value = match.group(0) - return f"\\{value}" - - return self.escaped_chars_re.sub(escape_symbol, value) - - def make_dict(values: List[Any]): # TODO make this a real function i = 0 diff --git a/tests/test_filter.py b/tests/test_filter.py index e5106366..7b47a6a0 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -3,15 +3,54 @@ from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text -def test_tag_filter(): - tf = Tag("tag_field") == ["tag1", "tag2"] - assert str(tf) == "@tag_field:{tag1|tag2}" - - tf = Tag("tag_field") == "tag1" - assert str(tf) == "@tag_field:{tag1}" - - tf = Tag("tag_field") != ["tag1", "tag2"] - assert str(tf) == "(-@tag_field:{tag1|tag2})" +# Test cases for various scenarios of tag usage, combinations, and their string representations. 
+@pytest.mark.parametrize( + "operation,tags,expected", + [ + # Testing single tags + ("==", "simpletag", "@tag_field:{simpletag}"), + ( + "==", + "tag with space", + "@tag_field:{tag\\ with\\ space}", + ), # Escaping spaces within quotes + ( + "==", + "special$char", + "@tag_field:{special\\$char}", + ), # Escaping a special character + ("!=", "negated", "(-@tag_field:{negated})"), + # Testing multiple tags + ("==", ["tag1", "tag2"], "@tag_field:{tag1|tag2}"), + ( + "==", + ["alpha", "beta with space", "gamma$special"], + "@tag_field:{alpha|beta\\ with\\ space|gamma\\$special}", + ), # Multiple tags with spaces and special chars + ("!=", ["tagA", "tagB"], "(-@tag_field:{tagA|tagB})"), + # Complex tag scenarios with special characters + ("==", "weird:tag", "@tag_field:{weird\\:tag}"), # Tags with colon + ("==", "tag&another", "@tag_field:{tag\\&another}"), # Tags with ampersand + # Escaping various special characters within tags + ("==", "tag/with/slashes", "@tag_field:{tag\\/with\\/slashes}"), + ( + "==", + ["hypen-tag", "under_score", "dot.tag"], + "@tag_field:{hypen\\-tag|under_score|dot\\.tag}", + ), + # ...additional unique cases as desired... 
+ ], +) +def test_tag_filter_varied(operation, tags, expected): + if operation == "==": + tf = Tag("tag_field") == tags + elif operation == "!=": + tf = Tag("tag_field") != tags + else: + raise ValueError(f"Unsupported operation: {operation}") + + # Verify the string representation matches the expected RediSearch query part + assert str(tf) == expected def test_numeric_filter(): diff --git a/tests/test_token_escaper.py b/tests/test_token_escaper.py new file mode 100644 index 00000000..def33cbe --- /dev/null +++ b/tests/test_token_escaper.py @@ -0,0 +1,130 @@ +import pytest + +from redisvl.utils.token_escaper import TokenEscaper + + +@pytest.fixture +def escaper(): + return TokenEscaper() + + +@pytest.mark.parametrize( + ("test_input,expected"), + [ + (r"a [big] test.", r"a\ \[big\]\ test\."), + (r"hello, world!", r"hello\,\ world\!"), + ( + r'special "quotes" (and parentheses)', + r"special\ \"quotes\"\ \(and\ parentheses\)", + ), + ( + r"& symbols, like * and ?", + r"\&\ symbols\,\ like\ \*\ and\ ?", + ), # TODO: question marks are not caught? 
+ # underscores are ignored + (r"-dashes_and_underscores-", r"\-dashes_and_underscores\-"), + ], + ids=[ + "brackets", + "commas", + "quotes", + "symbols", + "underscores" + ] +) +def test_escape_text_chars(escaper, test_input, expected): + assert escaper.escape(test_input) == expected + + +@pytest.mark.parametrize( + ("test_input,expected"), + [ + # Simple tags + ("user:name", r"user\:name"), + ("123#comment", r"123\#comment"), + ("hyphen-separated", r"hyphen\-separated"), + # Tags with special characters + ("price$", r"price\$"), + ("super*star", r"super\*star"), + ("tag&value", r"tag\&value"), + ("@username", r"\@username"), + # Space-containing tags often used in search scenarios + ("San Francisco", r"San\ Francisco"), + ("New Zealand", r"New\ Zealand"), + # Multi-special-character tags + ("complex/tag:value", r"complex\/tag\:value"), + ("$special$tag$", r"\$special\$tag\$"), + ("tag-with-hyphen", r"tag\-with\-hyphen"), + # Tags with less common, but legal characters + ("_underscore_", r"_underscore_"), + ("dot.tag", r"dot\.tag"), + # ("pipe|tag", r"pipe\|tag"), #TODO - pipes are not caught? + # More edge cases with special characters + ("(parentheses)", r"\(parentheses\)"), + ("[brackets]", r"\[brackets\]"), + ("{braces}", r"\{braces\}"), + # ("question?mark", r"question\?mark"), #TODO - question marks are not caught? + # Unicode characters in tags + ("你好", r"你好"), # Assuming non-Latin characters don't need escaping + ("emoji:😊", r"emoji\:😊"), + # ...other cases as needed... 
+ ], + ids=[ + ":", + "#", + "-", + "$", + "*", + "&", + "@", + "space", + "space-2", + "complex", + "special", + "hyphen", + "underscore", + "dot", + "parentheses", + "brackets", + "braces", + "non-latin", + "emoji" + ] +) +def test_escape_tag_like_values(escaper, test_input, expected): + assert escaper.escape(test_input) == expected + + +@pytest.mark.parametrize("test_input", [123, 45.67, None, [], {}]) +def test_escape_non_string_input(escaper, test_input): + with pytest.raises(TypeError): + escaper.escape(test_input) + + +@pytest.mark.parametrize( + "test_input,expected", + [ + # ('你好，世界！', r'你好\，世界\！'), # TODO - non latin chars? + ("😊 ❤️ 👍", r"😊\ ❤️\ 👍"), + # ...other cases as needed... + ], + ids=[ + "emoji" + ] +) +def test_escape_unicode_characters(escaper, test_input, expected): + assert escaper.escape(test_input) == expected + + +def test_escape_empty_string(escaper): + assert escaper.escape("") == "" + + +def test_escape_long_string(escaper): + # Construct a very long string + long_str = "a," * 1000 # This creates a string "a,a,a,a,...a," + expected = r"a\," * 1000 # Expected escaped string + + # Escape the long string and check that every comma was escaped + escaped = escaper.escape(long_str) + assert escaped == expected