From 6a2809b0d10e3b28b204ddb9a15e6103980bff21 Mon Sep 17 00:00:00 2001
From: Tyler Hutcherson
Date: Tue, 21 Nov 2023 18:50:22 -0500
Subject: [PATCH] Refactor Redis Storage and add JSON support (#78)

This PR introduces a new `BaseStorage` class, subclassed by `HashStorage` and
`JsonStorage`, to handle the underlying data structures (as well as I/O) for
Redis. Includes:

- New storage classes and type enums
- New `key_separator` param for better user customizability when constructing
  Redis keys
- Updated documentation and new user guide
- Updated docstrings and examples

---------

Co-authored-by: Sam Partee
---
 docs/_extension/gallery_directive.py          |  12 +-
 docs/api/searchindex.rst                      |  27 +-
 docs/examples/openai_qna.ipynb                |   2 +-
 docs/user_guide/getting_started_01.ipynb      |  22 +-
 docs/user_guide/hash_vs_json_05.ipynb         | 519 ++++++++++++
 docs/user_guide/hybrid_queries_02.ipynb       |  15 +-
 docs/user_guide/index.md                      |   3 +-
 docs/user_guide/llmcache_03.ipynb             |  60 +-
 docs/user_guide/schema.yaml                   |   2 +-
 ...torizers_03.ipynb => vectorizers_04.ipynb} |   2 +
 redisvl/cli/index.py                          |  24 +-
 redisvl/cli/log.py                            |   2 +-
 redisvl/cli/stats.py                          |   8 +-
 redisvl/index.py                              | 760 +++++++++++------
 redisvl/llmcache/base.py                      |   3 +-
 redisvl/llmcache/semantic.py                  |   7 +-
 redisvl/query/filter.py                       |  66 +-
 redisvl/query/query.py                        |   2 -
 redisvl/schema.py                             | 126 ++-
 redisvl/storage.py                            | 496 ++++++++++++
 redisvl/utils/connection.py                   |  18 +-
 redisvl/utils/token_escaper.py                |   5 +-
 redisvl/utils/utils.py                        |  39 +-
 redisvl/vectorize/text/huggingface.py         |   4 +-
 redisvl/vectorize/text/openai.py              |   7 +-
 redisvl/vectorize/text/vertexai.py            |   7 +-
 tests/integration/test_query.py               |   2 +-
 tests/integration/test_simple.py              |  57 +-
 tests/sample_hash_schema.yaml                 |  14 +
 tests/sample_json_schema.yaml                 |  16 +
 tests/unit/test_filter.py                     |  11 +-
 tests/unit/test_index.py                      |  16 +-
 tests/unit/test_query_types.py                |  56 ++
 tests/unit/test_schema.py                     |  70 +-
 tests/unit/test_storage.py                    |  80 ++
 35 files changed, 1999 insertions(+), 561 deletions(-)
 create mode 100644 docs/user_guide/hash_vs_json_05.ipynb
 rename docs/user_guide/{vectorizers_03.ipynb => vectorizers_04.ipynb} (99%)
 create mode 100644 redisvl/storage.py
 create mode 100644 tests/sample_hash_schema.yaml
 create mode 100644 tests/sample_json_schema.yaml
 create mode 100644 tests/unit/test_query_types.py
 create mode 100644 tests/unit/test_storage.py

diff --git a/docs/_extension/gallery_directive.py b/docs/_extension/gallery_directive.py
index 1e058970..54692158 100644
--- a/docs/_extension/gallery_directive.py
+++ b/docs/_extension/gallery_directive.py
@@ -1,12 +1,12 @@
 """A directive to generate a gallery of images from structured data.

-Generating a gallery of images that are all the same size is a common
-pattern in documentation, and this can be cumbersome if the gallery is
-generated programmatically. This directive wraps this particular use-case
-in a helper-directive to generate it with a single YAML configuration file.
+Generating a gallery of images that are all the same size is a common pattern in
+documentation, and this can be cumbersome if the gallery is generated
+programmatically. This directive wraps this particular use-case in a helper-
+directive to generate it with a single YAML configuration file.

-It currently exists for maintainers of the pydata-sphinx-theme,
-but might be abstracted into a standalone package if it proves useful.
+It currently exists for maintainers of the pydata-sphinx-theme, but might be
+abstracted into a standalone package if it proves useful.
""" from pathlib import Path from typing import Any, Dict, List diff --git a/docs/api/searchindex.rst b/docs/api/searchindex.rst index f889489d..972dd251 100644 --- a/docs/api/searchindex.rst +++ b/docs/api/searchindex.rst @@ -13,18 +13,22 @@ SearchIndex .. autosummary:: SearchIndex.__init__ + SearchIndex.client + SearchIndex.name + SearchIndex.prefix + SearchIndex.key_separator + SearchIndex.storage_type SearchIndex.from_yaml SearchIndex.from_dict SearchIndex.from_existing + SearchIndex.connect + SearchIndex.create + SearchIndex.load SearchIndex.search SearchIndex.query - SearchIndex.create SearchIndex.delete - SearchIndex.load - SearchIndex.client - SearchIndex.connect - SearchIndex.disconnect SearchIndex.info + SearchIndex.disconnect @@ -44,17 +48,22 @@ AsyncSearchIndex .. autosummary:: AsyncSearchIndex.__init__ + AsyncSearchIndex.client + AsyncSearchIndex.name + AsyncSearchIndex.prefix + AsyncSearchIndex.key_separator + AsyncSearchIndex.storage_type AsyncSearchIndex.from_yaml AsyncSearchIndex.from_dict AsyncSearchIndex.from_existing + AsyncSearchIndex.connect + AsyncSearchIndex.create + AsyncSearchIndex.load AsyncSearchIndex.search AsyncSearchIndex.query - AsyncSearchIndex.create AsyncSearchIndex.delete - AsyncSearchIndex.load - AsyncSearchIndex.connect - AsyncSearchIndex.disconnect AsyncSearchIndex.info + AsyncSearchIndex.disconnect diff --git a/docs/examples/openai_qna.ipynb b/docs/examples/openai_qna.ipynb index 51e86c21..10141ade 100644 --- a/docs/examples/openai_qna.ipynb +++ b/docs/examples/openai_qna.ipynb @@ -46,7 +46,7 @@ "source": [ "# first we need to install a few things\n", "\n", - "!pip install pandas wget tenacity tiktoken openai" + "!pip install pandas wget tenacity tiktoken openai==0.28.1" ] }, { diff --git a/docs/user_guide/getting_started_01.ipynb b/docs/user_guide/getting_started_01.ipynb index de670bfb..6cc0a47c 100644 --- a/docs/user_guide/getting_started_01.ipynb +++ b/docs/user_guide/getting_started_01.ipynb @@ -127,6 +127,8 @@ "index:\n", " name: user_index\n", " prefix: user\n", + " storage_type: hash\n", + " key_separator: ':'\n", "\n", "fields:\n", " # define tag fields\n", @@ -162,6 +164,8 @@ " \"index\": {\n", " \"name\": \"user_index\",\n", " \"prefix\": \"user\",\n", + " \"storage_type\": \"hash\",\n", + " \"key_separator\": \":\"\n", " },\n", " \"fields\": {\n", " \"tag\": [{\"name\": \"credit_score\"}],\n", @@ -217,8 +221,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m16:03:01\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m16:03:01\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n" + "\u001b[32m22:49:46\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m22:49:46\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. 
user_index\n" ] } ], @@ -465,7 +469,7 @@ ], "source": [ "# create a new SearchIndex instance from an existing index\n", - "existing_index = SearchIndex.from_existing(\"user_index\", \"redis://localhost:6379\")\n", + "existing_index = SearchIndex.from_existing(name=\"user_index\", redis_url=\"redis://localhost:6379\")\n", "\n", "# run the same query\n", "results = existing_index.query(query)\n", @@ -583,7 +587,10 @@ { "data": { "text/plain": [ - "{'index': {'name': 'user_index', 'prefix': 'user'},\n", + "{'index': {'name': 'user_index',\n", + " 'prefix': 'user',\n", + " 'storage_type': 'hash',\n", + " 'key_separator': ':'},\n", " 'fields': {'tag': [{'name': 'credit_score'}],\n", " 'text': [{'name': 'job'}],\n", " 'numeric': [{'name': 'age'}],\n", @@ -612,7 +619,10 @@ { "data": { "text/plain": [ - "{'index': {'name': 'user_index', 'prefix': 'user'},\n", + "{'index': {'name': 'user_index',\n", + " 'prefix': 'user',\n", + " 'storage_type': 'hash',\n", + " 'key_separator': ':'},\n", " 'fields': {'tag': [{'name': 'credit_score'}, {'name': 'job'}],\n", " 'text': [],\n", " 'numeric': [{'name': 'age'}],\n", @@ -725,7 +735,7 @@ "│ offsets_per_term_avg │ 0 │\n", "│ records_per_doc_avg │ 4 │\n", "│ sortable_values_size_mb │ 0 │\n", - "│ total_indexing_time │ 0.59 │\n", + "│ total_indexing_time │ 1.738 │\n", "│ total_inverted_index_blocks │ 7 │\n", "│ vector_index_sz_mb │ 0.235603 │\n", "╰─────────────────────────────┴─────────────╯\n" diff --git a/docs/user_guide/hash_vs_json_05.ipynb b/docs/user_guide/hash_vs_json_05.ipynb new file mode 100644 index 00000000..3e2daac3 --- /dev/null +++ b/docs/user_guide/hash_vs_json_05.ipynb @@ -0,0 +1,519 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Hash vs JSON Storage\n", + "\n", + "\n", + "Out of the box, Redis provides a [variety of data structures](https://redis.com/redis-enterprise/data-structures/) that can adapt to your domain specific applications and use cases.\n", + "In this notebook, we will demonstrate how to use RedisVL with both [Hash](https://redis.io/docs/data-types/hashes/) and [JSON](https://redis.io/docs/data-types/json/) data.\n", + "\n", + "\n", + "Before running this notebook, be sure to\n", + "1. Have installed ``redisvl`` and have that environment active for this notebook.\n", + "2. Have a running Redis Stack or Redis Enterprise instance with RediSearch > 2.4 activated.\n", + "\n", + "For example, you can run Redis Stack locally with Docker:\n", + "\n", + "```bash\n", + "docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest\n", + "```\n", + "\n", + "Or create a [FREE Redis Enterprise instance.](https://redis.com/try-free).\n", + "\n", + "This example will assume a local Redis is running on port 6379 and RedisInsight at 8001." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# import necessary modules\n", + "import pickle\n", + "from jupyterutils import table_print, result_print\n", + "from redisvl.index import SearchIndex\n", + "\n", + "\n", + "# load in the example data and printing utils\n", + "data = pickle.load(open(\"hybrid_example_data.pkl\", \"rb\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
user    | age | job           | credit_score | office_location   | user_embedding
john    | 18  | engineer      | high         | -122.4194,37.7749 | b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
derrick | 14  | doctor        | low          | -122.4194,37.7749 | b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
nancy   | 94  | doctor        | high         | -122.4194,37.7749 | b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
tyler   | 100 | engineer      | high         | -122.0839,37.3861 | b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'
tim     | 12  | dermatologist | high         | -122.0839,37.3861 | b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'
taimur  | 15  | CEO           | low          | -122.0839,37.3861 | b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
joe     | 35  | dentist       | medium       | -122.0839,37.3861 | b'fff?fff?\\xcd\\xcc\\xcc='
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "table_print(data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hash or JSON -- how to choose?\n", + "Both storage options offer a variety of features and tradeoffs. Below we will work through a dummy dataset to learn when and how to use both." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Working with Hashes\n", + "Hashes in Redis are simple collections of field-value pairs. Think of it like a mutable single-level dictionary contains multiple \"rows\":\n", + "\n", + "\n", + "```python\n", + "{\n", + " \"model\": \"Deimos\",\n", + " \"brand\": \"Ergonom\",\n", + " \"type\": \"Enduro bikes\",\n", + " \"price\": 4972,\n", + "}\n", + "```\n", + "\n", + "Hashes are best suited for use cases with the following characteristics:\n", + "- Performance (speed) and storage space (memory consumption) are top concerns\n", + "- Data can be easily normalized and modeled as a single-level dict\n", + "\n", + "> Hashes are typically the default recommendation." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# define the hash index schema\n", + "hash_schema = {\n", + " \"index\": {\n", + " \"name\": \"user-hashes\",\n", + " \"storage_type\": \"hash\", # default setting\n", + " \"prefix\": \"hash\",\n", + " \"key_separator\": \":\",\n", + " },\n", + " \"fields\": {\n", + " \"tag\": [{\"name\": \"credit_score\"}, {\"name\": \"user\"}],\n", + " \"text\": [{\"name\": \"job\"}],\n", + " \"numeric\": [{\"name\": \"age\"}],\n", + " \"geo\": [{\"name\": \"office_location\"}],\n", + " \"vector\": [{\n", + " \"name\": \"user_embedding\",\n", + " \"dims\": 3,\n", + " \"distance_metric\": \"cosine\",\n", + " \"algorithm\": \"flat\",\n", + " \"datatype\": \"float32\"}\n", + " ]\n", + " },\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# construct a search index from the hash schema\n", + "hindex = SearchIndex.from_dict(hash_schema)\n", + "\n", + "# connect to local redis instance\n", + "hindex.connect(\"redis://localhost:6379\")\n", + "\n", + "# create the index (no data yet)\n", + "hindex.create(overwrite=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# show the underlying storage type\n", + "hindex.storage_type" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Vectors as byte strings\n", + "One nuance when working with Hashes in Redis, is that all vectorized data must be passed as a byte string (for efficient storage, indexing, and processing). 
An example of that can be seen below:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'user': 'john',\n", + " 'age': 18,\n", + " 'job': 'engineer',\n", + " 'credit_score': 'high',\n", + " 'office_location': '-122.4194,37.7749',\n", + " 'user_embedding': b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# show a single entry from the data that will be loaded\n", + "data[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# load hash data\n", + "hindex.load(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Statistics:\n", + "╭─────────────────────────────┬─────────────╮\n", + "│ Stat Key │ Value │\n", + "├─────────────────────────────┼─────────────┤\n", + "│ num_docs │ 7 │\n", + "│ num_terms │ 6 │\n", + "│ max_doc_id │ 7 │\n", + "│ num_records │ 44 │\n", + "│ percent_indexed │ 1 │\n", + "│ hash_indexing_failures │ 0 │\n", + "│ number_of_uses │ 2 │\n", + "│ bytes_per_record_avg │ 3.40909 │\n", + "│ doc_table_size_mb │ 0.000700951 │\n", + "│ inverted_sz_mb │ 0.000143051 │\n", + "│ key_table_size_mb │ 0.000221252 │\n", + "│ offset_bits_per_record_avg │ 8 │\n", + "│ offset_vectors_sz_mb │ 8.58307e-06 │\n", + "│ offsets_per_term_avg │ 0.204545 │\n", + "│ records_per_doc_avg │ 6.28571 │\n", + "│ sortable_values_size_mb │ 0 │\n", + "│ total_indexing_time │ 0.335 │\n", + "│ total_inverted_index_blocks │ 18 │\n", + "│ vector_index_sz_mb │ 0.0202332 │\n", + "╰─────────────────────────────┴─────────────╯\n" + ] + } + ], + "source": [ + "!rvl stats -i user-hashes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Performing Queries\n", + "Once our index is created and data is loaded into the right format, we can run queries against the index with RedisVL:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
vector_distance | user  | credit_score | age | job      | office_location
0               | john  | high         | 18  | engineer | -122.4194,37.7749
0.109129190445  | tyler | high         | 100 | engineer | -122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query import VectorQuery\n", + "from redisvl.query.filter import Tag, Text, Num\n", + "\n", + "t = (Tag(\"credit_score\") == \"high\") & (Text(\"job\") % \"enginee*\") & (Num(\"age\") > 17)\n", + "\n", + "v = VectorQuery([0.1, 0.1, 0.5],\n", + " \"user_embedding\",\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", + " filter_expression=t)\n", + "\n", + "\n", + "results = hindex.query(v)\n", + "result_print(results)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Working with JSON\n", + "Redis also supports native **JSON** objects. These can be multi-level (nested) objects, with full JSONPath support for updating/retrieving sub elements:\n", + "\n", + "```python\n", + "{\n", + " \"name\": \"bike\",\n", + " \"metadata\": {\n", + " \"model\": \"Deimos\",\n", + " \"brand\": \"Ergonom\",\n", + " \"type\": \"Enduro bikes\",\n", + " \"price\": 4972,\n", + " }\n", + "}\n", + "```\n", + "\n", + "JSON is best suited for use cases with the following characteristics:\n", + "- Ease of use and data model flexibility are top concerns\n", + "- Application data is already native JSON\n", + "- Replacing another document storage/db solution" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Full JSON Path support\n", + "Because RedisJSON enables full path support, when creating an index schema, elements need to be indexed and selected by their path with the `name` param and aliased using the `as_name` param as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# define the json index schema\n", + "json_schema = {\n", + " \"index\": {\n", + " \"name\": \"user-json\",\n", + " \"storage_type\": \"json\", # updated storage_type option\n", + " \"prefix\": \"json\",\n", + " \"key_separator\": \":\",\n", + " },\n", + " \"fields\": {\n", + " \"tag\": [{\"name\": \"$.credit_score\", \"as_name\": \"credit_score\"}, {\"name\": \"$.user\", \"as_name\": \"user\"}],\n", + " \"text\": [{\"name\": \"$.job\", \"as_name\": \"job\"}],\n", + " \"numeric\": [{\"name\": \"$.age\", \"as_name\": \"age\"}],\n", + " \"geo\": [{\"name\": \"$.office_location\", \"as_name\": \"office_location\"}],\n", + " \"vector\": [{\n", + " \"name\": \"$.user_embedding\",\n", + " \"as_name\": \"user_embedding\",\n", + " \"dims\": 3,\n", + " \"distance_metric\": \"cosine\",\n", + " \"algorithm\": \"flat\",\n", + " \"datatype\": \"float32\"}\n", + " ]\n", + " },\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# construct a search index from the json schema\n", + "jindex = SearchIndex.from_dict(json_schema)\n", + "\n", + "# connect to local redis instance\n", + "jindex.connect(\"redis://localhost:6379\")\n", + "\n", + "# create the index (no data yet)\n", + "jindex.create(overwrite=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m22:50:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m22:50:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user-hashes\n", + "\u001b[32m22:50:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. 
user-json\n" + ] + } + ], + "source": [ + "# note the multiple indices in the same database\n", + "!rvl index listall" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Vectors as float arrays\n", + "Vectorized data stored in JSON must be stored as a pure array (python list) of floats. We will modify our sample data to account for this below:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "json_data = data.copy()\n", + "\n", + "for d in json_data:\n", + " d['user_embedding'] = np.frombuffer(d['user_embedding'], dtype=np.float32).tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'user': 'john',\n", + " 'age': 18,\n", + " 'job': 'engineer',\n", + " 'credit_score': 'high',\n", + " 'office_location': '-122.4194,37.7749',\n", + " 'user_embedding': [0.10000000149011612, 0.10000000149011612, 0.5]}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# inspect a single JSON record\n", + "json_data[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "jindex.load(json_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
vector_distance | user  | credit_score | age | job      | office_location
0               | john  | high         | 18  | engineer | -122.4194,37.7749
0.109129190445  | tyler | high         | 100 | engineer | -122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# we can now run the exact same query as above\n", + "result_print(jindex.query(v))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "hindex.delete()\n", + "jindex.delete()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.13 ('redisvl2')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index 9443cd53..3daa48ef 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -51,6 +51,8 @@ " \"index\": {\n", " \"name\": \"user_index\",\n", " \"prefix\": \"v1\",\n", + " \"storage_type\": \"hash\",\n", + " \"key_separator\": \":\"\n", " },\n", " \"fields\": {\n", " \"tag\": [{\"name\": \"credit_score\"}],\n", @@ -95,8 +97,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m16:03:26\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m16:03:26\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n" + "\u001b[32m22:51:08\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m22:51:08\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. 
user_index\n" ] } ], @@ -731,7 +733,6 @@ "metadata": {}, "outputs": [], "source": [ - "#\n", "def make_filter(age=None, credit=None, job=None):\n", " flexible_filter = (\n", " (Num(\"age\") > age) &\n", @@ -1107,10 +1108,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'id': 'v1:54b273392e4d4fa2af424caca095d2d4', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "{'id': 'v1:abdab0c48bed49bea9a79d9eb3f247fa', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "{'id': 'v1:81ea678467be4ca1bd8efaec27766d10', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "{'id': 'v1:44741013d4d5469dad4b95f70cedc0bb', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + "{'id': 'v1:13dbcb6b63e6416187a8c9ee1ab6eae7', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'v1:02d544f7543a40c780dee81116dd5610', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'v1:d521d5c1778842e98d8ad50d837a60a4', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'v1:2efe1220f62a4f8fb94055de526ff8f6', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" ] } ], diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index d5d88ce7..beb46718 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -13,7 +13,8 @@ myst: getting_started_01 hybrid_queries_02 -vectorizers_03 llmcache_03 +vectorizers_04 +hash_vs_json_05 ``` diff --git a/docs/user_guide/llmcache_03.ipynb b/docs/user_guide/llmcache_03.ipynb index 557b8d10..5d4d236b 100644 --- a/docs/user_guide/llmcache_03.ipynb +++ b/docs/user_guide/llmcache_03.ipynb @@ -20,15 +20,15 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import openai\n", "import getpass\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"False\"\n", "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"False\"\n", "\n", "api_key = os.getenv(\"OPENAI_API_KEY\") or getpass.getpass(\"Enter your OpenAI API key: \")\n", "\n", @@ -45,14 +45,14 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Paris\n" + "Paris.\n" ] } ], @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -80,12 +80,12 @@ "cache = SemanticCache(\n", " redis_url=\"redis://localhost:6379\",\n", " threshold=0.9, # semantic similarity threshold\n", - " )" + ")" ] }, { "cell_type": "code", - "execution_count": 19, + 
"execution_count": 4, "metadata": {}, "outputs": [ { @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -125,7 +125,7 @@ "[]" ] }, - "execution_count": 20, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -137,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -156,7 +156,7 @@ "['Paris']" ] }, - "execution_count": 22, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -168,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -177,7 +177,7 @@ "[]" ] }, - "execution_count": 23, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -189,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -198,7 +198,7 @@ "['Paris']" ] }, - "execution_count": 24, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -211,7 +211,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -220,7 +220,7 @@ "[]" ] }, - "execution_count": 25, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -232,7 +232,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -250,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -266,14 +266,14 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Time taken without cache 0.8732700347900391\n" + "Time taken without cache 0.574105978012085\n" ] } ], @@ -287,15 +287,15 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Time Taken with cache: 0.04746699333190918\n", - "Percentage of time saved: 94.56%\n" + "Time Taken with cache: 0.09868717193603516\n", + "Percentage of time saved: 82.81%\n" ] } ], @@ -309,7 +309,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -337,9 +337,9 @@ "│ offsets_per_term_avg │ 0 │\n", "│ records_per_doc_avg │ 2 │\n", "│ sortable_values_size_mb │ 0 │\n", - "│ total_indexing_time │ 0.211 │\n", + "│ total_indexing_time │ 0.087 │\n", "│ total_inverted_index_blocks │ 11 │\n", - "│ vector_index_sz_mb │ 3.00814 │\n", + "│ vector_index_sz_mb │ 3.0161 │\n", "╰─────────────────────────────┴─────────────╯\n" ] } @@ -351,7 +351,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -376,7 +376,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.9.12" }, "orig_nbformat": 4 }, diff --git a/docs/user_guide/schema.yaml b/docs/user_guide/schema.yaml index c96ac23f..ed6f9f9b 100644 --- a/docs/user_guide/schema.yaml +++ b/docs/user_guide/schema.yaml @@ -1,8 +1,8 @@ - index: name: providers prefix: rvl storage_type: hash + key_separator: ':' fields: text: diff --git 
a/docs/user_guide/vectorizers_03.ipynb b/docs/user_guide/vectorizers_04.ipynb similarity index 99% rename from docs/user_guide/vectorizers_03.ipynb rename to docs/user_guide/vectorizers_04.ipynb index 183cbc58..4caf1964 100644 --- a/docs/user_guide/vectorizers_03.ipynb +++ b/docs/user_guide/vectorizers_04.ipynb @@ -271,6 +271,8 @@ "index:\n", " name: providers\n", " prefix: rvl\n", + " storage_type: hash\n", + " key_separator: ':'\n", "\n", "fields:\n", " text:\n", diff --git a/redisvl/cli/index.py b/redisvl/cli/index.py index 3e795224..f1f5c16d 100644 --- a/redisvl/cli/index.py +++ b/redisvl/cli/index.py @@ -51,7 +51,7 @@ def __init__(self): exit(0) def create(self, args: Namespace): - """Create an index + """Create an index. Usage: rvl index create -i | -s @@ -59,13 +59,13 @@ def create(self, args: Namespace): if not args.schema: logger.error("Schema must be provided to create an index") index = SearchIndex.from_yaml(args.schema) - url = create_redis_url(args) - index.connect(url) + redis_url = create_redis_url(args) + index.connect(redis_url) index.create() logger.info("Index created successfully") def info(self, args: Namespace): - """Obtain information about an index + """Obtain information about an index. Usage: rvl index info -i | -s @@ -74,20 +74,20 @@ def info(self, args: Namespace): _display_in_table(index.info(), output_format=args.format) def listall(self, args: Namespace): - """List all indices + """List all indices. Usage: rvl index listall """ - url = create_redis_url(args) - conn = get_redis_connection(url) + redis_url = create_redis_url(args) + conn = get_redis_connection(redis_url) indices = convert_bytes(conn.execute_command("FT._LIST")) logger.info("Indices:") for i, index in enumerate(indices): logger.info(str(i + 1) + ". " + index) def delete(self, args: Namespace, drop=False): - """Delete an index + """Delete an index. Usage: rvl index delete -i | -s @@ -97,7 +97,7 @@ def delete(self, args: Namespace, drop=False): logger.info("Index deleted successfully") def destroy(self, args: Namespace): - """Delete an index and the documents within it + """Delete an index and the documents within it. Usage: rvl index destroy -i | -s @@ -107,8 +107,8 @@ def destroy(self, args: Namespace): def _connect_to_index(self, args: Namespace) -> SearchIndex: # connect to redis try: - url = create_redis_url(args) - conn = get_redis_connection(url=url) + redis_url = create_redis_url(args) + conn = get_redis_connection(url=redis_url) except ValueError: logger.error( "Must set REDIS_URL environment variable or provide host and port" @@ -116,7 +116,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex: exit(0) if args.index: - index = SearchIndex.from_existing(name=args.index, url=url) + index = SearchIndex.from_existing(name=args.index, redis_url=redis_url) elif args.schema: index = SearchIndex.from_yaml(args.schema) index.set_client(conn) diff --git a/redisvl/cli/log.py b/redisvl/cli/log.py index 145a76c5..41d5fcb1 100644 --- a/redisvl/cli/log.py +++ b/redisvl/cli/log.py @@ -9,7 +9,7 @@ def get_logger(name, log_level="info", fmt=None): - """Return a logger instance""" + """Return a logger instance.""" # Use file name if logger is in debug mode name = "RedisVL" if log_level == "debug" else name diff --git a/redisvl/cli/stats.py b/redisvl/cli/stats.py index 9861f9e3..27a2b9de 100644 --- a/redisvl/cli/stats.py +++ b/redisvl/cli/stats.py @@ -56,7 +56,7 @@ def __init__(self): exit(0) def stats(self, args: Namespace): - """Obtain stats about an index + """Obtain stats about an index. 
Usage: rvl stats -i | -s @@ -67,8 +67,8 @@ def stats(self, args: Namespace): def _connect_to_index(self, args: Namespace) -> SearchIndex: # connect to redis try: - url = create_redis_url(args) - conn = get_redis_connection(url=url) + redis_url = create_redis_url(args) + conn = get_redis_connection(url=redis_url) except ValueError: logger.error( "Must set REDIS_ADDRESS environment variable or provide host and port" @@ -76,7 +76,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex: exit(0) if args.index: - index = SearchIndex.from_existing(name=args.index, url=url) + index = SearchIndex.from_existing(name=args.index, redis_url=redis_url) elif args.schema: index = SearchIndex.from_yaml(args.schema) index.set_client(conn) diff --git a/redisvl/index.py b/redisvl/index.py index 61eef279..03109d01 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -1,91 +1,243 @@ -import asyncio +import json +from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union -from uuid import uuid4 if TYPE_CHECKING: from redis.commands.search.field import Field + from redis.commands.search.document import Document from redis.commands.search.result import Result from redisvl.query.query import BaseQuery import redis -from redis.commands.search.indexDefinition import IndexDefinition, IndexType - -from redisvl.query.query import CountQuery -from redisvl.schema import SchemaModel, read_schema -from redisvl.utils.connection import ( - check_connected, - get_async_redis_connection, - get_redis_connection, -) +from redis.commands.search.indexDefinition import IndexDefinition + +from redisvl.query.query import BaseQuery, CountQuery, FilterQuery +from redisvl.schema import SchemaModel, StorageType, read_schema +from redisvl.storage import HashStorage, JsonStorage +from redisvl.utils.connection import get_async_redis_connection, get_redis_connection from redisvl.utils.utils import ( + check_async_redis_modules_exist, check_redis_modules_exist, convert_bytes, make_dict, - process_results, ) +def process_results( + results: "Result", query: BaseQuery, storage_type: StorageType +) -> List[Dict[str, Any]]: + """Convert a list of search Result objects into a list of document + dictionaries. + + This function processes results from Redis, handling different storage + types and query types. For JSON storage with empty return fields, it + unpacks the JSON object while retaining the document ID. The 'payload' + field is also removed from all resulting documents for consistency. + + Args: + results (Result): The search results from Redis. + query (BaseQuery): The query object used for the search. + storage_type (StorageType): The storage type of the search + index (json or hash). + + Returns: + List[Dict[str, Any]]: A list of processed document dictionaries. 
+    """
+    # Handle count queries
+    if isinstance(query, CountQuery):
+        return results.total
+
+    # Determine if unpacking JSON is needed
+    unpack_json = (
+        (storage_type == StorageType.JSON)
+        and isinstance(query, FilterQuery)
+        and not query._return_fields
+    )
+
+    # Process records
+    def _process(doc: "Document") -> Dict[str, Any]:
+        doc_dict = doc.__dict__
+
+        # Unpack and Project JSON fields properly
+        if unpack_json and "json" in doc_dict:
+            json_data = doc_dict.get("json", {})
+            if isinstance(json_data, str):
+                json_data = json.loads(json_data)
+            if isinstance(json_data, dict):
+                return {"id": doc_dict.get("id"), **json_data}
+            raise ValueError(f"Unable to parse json data from Redis {json_data}")
+
+        # Remove 'payload' if present
+        doc_dict.pop("payload", None)
+
+        return doc_dict
+
+    return [_process(doc) for doc in results.docs]
+
+
+def check_modules_present(client_variable_name: str):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            client = getattr(self, client_variable_name)
+            check_redis_modules_exist(client)
+            return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def check_async_modules_present(client_variable_name: str):
+    def decorator(func):
+        @wraps(func)
+        async def wrapper(self, *args, **kwargs):
+            client = getattr(self, client_variable_name)
+            await check_async_redis_modules_exist(client)
+            return await func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def check_index_exists():
+    def decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if not self.exists():
+                raise ValueError(
+                    f"Index has not been created. Must be created before calling {func.__name__}"
+                )
+            return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def check_async_index_exists():
+    def decorator(func):
+        @wraps(func)
+        async def wrapper(self, *args, **kwargs):
+            if not await self.exists():
+                raise ValueError(
+                    f"Index has not been created. Must be created before calling {func.__name__}"
+                )
+            return await func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def check_connected(client_variable_name: str):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if getattr(self, client_variable_name) is None:
+                raise ValueError(
+                    f"SearchIndex.connect() must be called before calling {func.__name__}"
+                )
+            return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def check_async_connected(client_variable_name: str):
+    def decorator(func):
+        @wraps(func)
+        async def wrapper(self, *args, **kwargs):
+            if getattr(self, client_variable_name) is None:
+                raise ValueError(
+                    f"SearchIndex.connect() must be called before calling {func.__name__}"
+                )
+            return await func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 class SearchIndexBase:
+    STORAGE_MAP = {
+        StorageType.HASH: HashStorage,
+        StorageType.JSON: JsonStorage,
+    }
+
     def __init__(
         self,
         name: str,
         prefix: str = "rvl",
         storage_type: str = "hash",
+        key_separator: str = ":",
         fields: Optional[List["Field"]] = None,
+        **kwargs,
     ):
+        """Initialize the RedisVL search index class.
+
+        Args:
+            name (str): Index name.
+            prefix (str, optional): Key prefix associated with the index.
+                Defaults to "rvl".
+            storage_type (str, optional): Underlying Redis storage type (hash
+                or json). Defaults to "hash".
+            key_separator (str, optional): Separator character to combine
+                prefix and key value for constructing redis keys.
+                Defaults to ":".
+ fields (Optional[List[Field]], optional): List of Redis fields to + index. Defaults to None. + """ self._name = name self._prefix = prefix - self._storage = storage_type + self._key_separator = key_separator + self._storage_type = StorageType(storage_type) self._fields = fields + + # configure storage layer + self._storage = self.STORAGE_MAP[self._storage_type]( # type: ignore + self._prefix, self._key_separator + ) + + # init empty redis conn self._redis_conn: Optional[redis.Redis] = None + if "redis_url" in kwargs: + redis_url = kwargs.pop("redis_url") + self.connect(redis_url, **kwargs) - def set_client(self, client: redis.Redis): + def set_client(self, client: redis.Redis) -> None: + """Set the Redis client object for the search index.""" self._redis_conn = client @property - @check_connected("_redis_conn") - def client(self) -> redis.Redis: - """The redis-py client object. - - Returns: - redis.Redis: The redis-py client object - """ - return self._redis_conn # type: ignore + def name(self) -> str: + """The name of the Redis search index.""" + return self._name - @check_connected("_redis_conn") - def search(self, *args, **kwargs) -> Union["Result", Any]: - """Perform a search on this index. + @property + def prefix(self) -> str: + """The optional key prefix that comes before a unique key value in + forming a Redis key.""" + return self._prefix - Wrapper around redis.search.Search that adds the index name - to the search query and passes along the rest of the arguments - to the redis-py ft.search() method. + @property + def key_separator(self) -> str: + """The optional separator between a defined prefix and key value in + forming a Redis key.""" + return self._key_separator - Returns: - Union["Result", Any]: Search results. - """ - results = self._redis_conn.ft(self._name).search( # type: ignore - *args, **kwargs - ) - return results + @property + def storage_type(self) -> StorageType: + """The underlying storage type for the search index: hash or json.""" + return self._storage_type + @property @check_connected("_redis_conn") - def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: - """Run a query on this index. - - This is similar to the search method, but takes a BaseQuery - object directly (does not allow for the usage of a raw - redis query string) and post-processes results of the search. - - Args: - query (BaseQuery): The query to run. - - Returns: - List[Result]: A list of search results. - """ - results = self.search(query.query, query_params=query.params) - if isinstance(query, CountQuery): - return results.total - return process_results(results) + def client(self) -> redis.Redis: + """The underlying redis-py client object.""" + return self._redis_conn # type: ignore @classmethod def from_yaml(cls, schema_path: str): @@ -94,6 +246,12 @@ def from_yaml(cls, schema_path: str): Args: schema_path (str): Path to the YAML schema file. + Example: + >>> from redisvl.index import SearchIndex + >>> index = SearchIndex.from_yaml("schema.yaml") + >>> index.connect("redis://localhost:6379") + >>> index.create(overwrite=True) + Returns: SearchIndex: A SearchIndex object. """ @@ -107,6 +265,22 @@ def from_dict(cls, schema_dict: Dict[str, Any]): Args: schema_dict (Dict[str, Any]): A dictionary containing the schema. 
+ Example: + >>> from redisvl.index import SearchIndex + >>> index = SearchIndex.from_dict({ + >>> "index": { + >>> "name": "my-index", + >>> "prefix": "rvl", + >>> "storage_type": "hash", + >>> "key_separator": ":" + >>> }, + >>> "fields": { + >>> "tag": [{"name": "doc-id"}] + >>> } + >>> }) + >>> index.connect("redis://localhost:6379") + >>> index.create(overwrite=True) + Returns: SearchIndex: A SearchIndex object. """ @@ -117,132 +291,77 @@ def from_dict(cls, schema_dict: Dict[str, Any]): def from_existing( cls, name: str, - url: Optional[str] = None, + redis_url: Optional[str] = None, + key_separator: str = ":", fields: Optional[List["Field"]] = None, **kwargs, ): - """Create a SearchIndex from an existing index in Redis. - - Args: - name (str): Index name. - url (Optional[str], optional): Redis URL. REDIS_URL env var - is used if not provided. Defaults to None. - fields (Optional[List[Field]], optional): List of Redis search - fields to include in the schema. Defaults to None. - - Returns: - SearchIndex: A SearchIndex object. + raise NotImplementedError - Raises: - redis.exceptions.ResponseError: If the index does not exist. - ValueError: If the REDIS_URL env var is not set and url is not provided. - """ + @check_connected("_redis_conn") + @check_modules_present("_redis_conn") + @check_index_exists() + def search(self, *args, **kwargs) -> Union["Result", Any]: raise NotImplementedError - def connect(self, url: str, **kwargs): - """Connect to a Redis instance. + @check_connected("_redis_conn") + @check_modules_present("_redis_conn") + @check_index_exists() + def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: + raise NotImplementedError - Args: - url (str): Redis URL. REDIS_URL env var is used if not provided. - """ + def connect(self, redis_url: Optional[str] = None, **kwargs): + """Connect to a Redis instance.""" raise NotImplementedError def disconnect(self): - """Disconnect from the Redis instance""" + """Disconnect from the Redis instance.""" self._redis_conn = None return self def key(self, key_value: str) -> str: - """ - Create a redis key as a combination of an index key prefix (optional) and specified key value. - The key value is typically a unique identifier, created at random, or derived from - some specified metadata. + """Create a redis key as a combination of an index key prefix (optional) + and specified key value. The key value is typically a unique identifier, + created at random, or derived from some specified metadata. Args: - key_value (str): The specified unique identifier for a particular document - indexed in Redis. + key_value (str): The specified unique identifier for a particular + document indexed in Redis. Returns: str: The full Redis key including key prefix and value as a string. """ - return f"{self._prefix}:{key_value}" if self._prefix else key_value - - def _create_key( - self, record: Dict[str, Any], key_field: Optional[str] = None - ) -> str: - """Construct the Redis HASH top level key. - - Args: - record (Dict[str, Any]): A dictionary containing the record to be indexed. - key_field (Optional[str], optional): A field within the record - to use in the Redis hash key. - - Returns: - str: The key to be used for a given record in Redis. - - Raises: - ValueError: If the key field is not found in the record. 
- """ - if key_field is None: - key_value = uuid4().hex - else: - try: - key_value = record[key_field] # type: ignore - except KeyError: - raise ValueError(f"Key field {key_field} not found in record {record}") - return self.key(key_value) + return self._storage._key(key_value, self._prefix, self._key_separator) @check_connected("_redis_conn") + @check_modules_present("_redis_conn") + @check_index_exists() def info(self) -> Dict[str, Any]: - """Get information about the index. - - Returns: - dict: A dictionary containing the information about the index. - """ - return convert_bytes(self._redis_conn.ft(self._name).info()) # type: ignore - - def create(self, overwrite: Optional[bool] = False): - """Create an index in Redis from this SearchIndex object. - - Args: - overwrite (bool, optional): Overwrite the index if it already exists. Defaults to False. + raise NotImplementedError - Raises: - redis.exceptions.ResponseError: If the index already exists. - """ + @check_connected("_redis_conn") + @check_modules_present("_redis_conn") + def create(self, overwrite: bool = False): raise NotImplementedError + @check_connected("_redis_conn") + @check_modules_present("_redis_conn") + @check_index_exists() def delete(self, drop: bool = True): - """Delete the search index. - - Args: - drop (bool, optional): Delete the documents in the index. Defaults to True. - - Raises: - redis.exceptions.ResponseError: If the index does not exist. - """ raise NotImplementedError + @check_connected("_redis_conn") + @check_modules_present("_redis_conn") def load( self, - data: Iterable[Dict[str, Any]], + data: Iterable[Any], key_field: Optional[str] = None, + keys: Optional[Iterable[str]] = None, + ttl: Optional[int] = None, preprocess: Optional[Callable] = None, + concurrency: Optional[int] = None, **kwargs, ): - """Load data into Redis and index using this SearchIndex object. - - Args: - data (Iterable[Dict[str, Any]]): An iterable of dictionaries - containing the data to be indexed. - key_field (Optional[str], optional): A field within the record - to use in the Redis hash key. - preprocess (Optional[Callabl], optional): An optional preprocessor function - that mutates the individual record before writing to redis. - - Raises: - redis.exceptions.ResponseError: If the index does not exist. - """ raise NotImplementedError @@ -255,24 +374,17 @@ class SearchIndex(SearchIndexBase): Example: >>> from redisvl.index import SearchIndex >>> index = SearchIndex.from_yaml("schema.yaml") + >>> index.connect("redis://localhost:6379") >>> index.create(overwrite=True) >>> index.load(data) # data is an iterable of dictionaries """ - def __init__( - self, - name: str, - prefix: str = "rvl", - storage_type: str = "hash", - fields: Optional[List["Field"]] = None, - ): - super().__init__(name, prefix, storage_type, fields) - @classmethod def from_existing( cls, name: str, - url: Optional[str] = None, + redis_url: Optional[str] = None, + key_separator: str = ":", fields: Optional[List["Field"]] = None, **kwargs, ): @@ -280,8 +392,10 @@ def from_existing( Args: name (str): Index name. - url (Optional[str], optional): Redis URL. REDIS_URL env var + redis_url (Optional[str], optional): Redis URL. REDIS_URL env var is used if not provided. Defaults to None. + key_separator (str, optional): Separator char to combine prefix and + key value for constructing redis keys. Defaults to ":". fields (Optional[List[Field]], optional): List of Redis search fields to include in the schema. Defaults to None. 
@@ -290,10 +404,9 @@ def from_existing( Raises: redis.exceptions.ResponseError: If the index does not exist. - ValueError: If the REDIS_URL env var is not set and url is not provided. - + ValueError: If the redis url is not accessible. """ - client = get_redis_connection(url, **kwargs) + client = get_redis_connection(redis_url, **kwargs) info = convert_bytes(client.ft(name).info()) index_definition = make_dict(info["index_definition"]) storage_type = index_definition["key_type"].lower() @@ -302,36 +415,40 @@ def from_existing( name=name, storage_type=storage_type, prefix=prefix, + key_separator=key_separator, fields=fields, ) instance.set_client(client) return instance - def connect(self, url: Optional[str] = None, **kwargs): + def connect(self, redis_url: Optional[str] = None, **kwargs): """Connect to a Redis instance. Args: - url (str): Redis URL. REDIS_URL env var is used if not provided. + redis_url (Optional[str], optional): Redis URL. REDIS_URL env var is + used if not provided. Raises: redis.exceptions.ConnectionError: If the connection to Redis fails. - ValueError: If the REDIS_URL env var is not set and url is not provided. + ValueError: If the redis url is not accessible. """ - self._redis_conn = get_redis_connection(url, **kwargs) + self._redis_conn = get_redis_connection(redis_url, **kwargs) return self @check_connected("_redis_conn") - def create(self, overwrite: Optional[bool] = False): + @check_modules_present("_redis_conn") + def create(self, overwrite: bool = False) -> None: """Create an index in Redis from this SearchIndex object. Args: - overwrite (bool, optional): Overwrite the index if it already exists. Defaults to False. + overwrite (bool, optional): Whether to overwrite the index if it + already exists. Defaults to False. Raises: - redis.exceptions.ResponseError: If the index already exists. + RuntimeError: If the index already exists and 'overwrite' is False. + ValueError: If no fields are defined for the index. """ - check_redis_modules_exist(self._redis_conn) - + # Check that fields are defined. if not self._fields: raise ValueError("No fields defined for index") if not isinstance(overwrite, bool): @@ -344,25 +461,23 @@ def create(self, overwrite: Optional[bool] = False): print("Index already exists, overwriting.") self.delete() - # set storage_type, default to hash - storage_type = IndexType.HASH - # TODO - enable JSON support - # if self._storage.lower() == "json": - # storage_type = IndexType.JSON - - # Create Index - # will raise correct response error if index already exists + # Create the index with the specified fields and settings. self._redis_conn.ft(self._name).create_index( # type: ignore fields=self._fields, - definition=IndexDefinition(prefix=[self._prefix], index_type=storage_type), + definition=IndexDefinition( + prefix=[self._prefix], index_type=self._storage.type + ), ) @check_connected("_redis_conn") + @check_modules_present("_redis_conn") + @check_index_exists() def delete(self, drop: bool = True): """Delete the search index. Args: - drop (bool, optional): Delete the documents in the index. Defaults to True. + drop (bool, optional): Delete the documents in the index. + Defaults to True. raises: redis.exceptions.ResponseError: If the index does not exist. 
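Throughout these methods, the `prefix` and `key_separator` pair determines how record keys are assembled. A minimal sketch of that scheme, using a hypothetical standalone helper that mirrors the `key()` method above (not the library's internal storage code):

```python
def make_key(key_value: str, prefix: str = "rvl", key_separator: str = ":") -> str:
    # join the optional index prefix and the unique key value with the
    # configured separator; with no prefix, the raw key value is used alone
    return f"{prefix}{key_separator}{key_value}" if prefix else key_value

assert make_key("1234", prefix="user") == "user:1234"
assert make_key("1234", prefix="") == "1234"
```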
@@ -371,61 +486,94 @@ def delete(self, drop: bool = True): self._redis_conn.ft(self._name).dropindex(delete_documents=drop) # type: ignore @check_connected("_redis_conn") + @check_modules_present("_redis_conn") def load( self, - data: Iterable[Dict[str, Any]], + data: Iterable[Any], key_field: Optional[str] = None, + keys: Optional[Iterable[str]] = None, + ttl: Optional[int] = None, preprocess: Optional[Callable] = None, + batch_size: Optional[int] = None, **kwargs, ): - """Load data into Redis and index using this SearchIndex object. + """Load a batch of objects to Redis. Args: - data (Iterable[Dict[str, Any]]): An iterable of dictionaries - containing the data to be indexed. - key_field (Optional[str], optional): A field within the record to - use in the Redis hash key. - preprocess (Optional[Callable], optional): An optional preprocessor function - that mutates the individual record before writing to redis. + data (Iterable[Any]): An iterable of objects to store. + key_field (Optional[str], optional): Field used as the key for each + object. Defaults to None. + keys (Optional[Iterable[str]], optional): Optional iterable of keys. + Must match the length of objects if provided. Defaults to None. + ttl (Optional[int], optional): Time-to-live in seconds for each key. + Defaults to None. + preprocess (Optional[Callable], optional): A function to preprocess + objects before storage. Defaults to None. + batch_size (Optional[int], optional): Number of objects to write in + a single Redis pipeline execution. Defaults to class's + default batch size. - raises: - redis.exceptions.ResponseError: If the index does not exist. + Raises: + ValueError: If the length of provided keys does not match the length + of objects. Example: >>> data = [{"foo": "bar"}, {"test": "values"}] - >>> def func(record: dict): record["new"]="value";return record + >>> def func(record: dict): + >>> record["new"] = "value" + >>> return record >>> index.load(data, preprocess=func) """ - # TODO -- should we return a count of the upserts? or some kind of metadata? - if data: - if not isinstance(data, Iterable): - if not isinstance(data[0], dict): - raise TypeError("data must be an iterable of dictionaries") - - # Check if outer interface passes in TTL on load - ttl = kwargs.get("ttl") - with self._redis_conn.pipeline(transaction=False) as pipe: # type: ignore - for record in data: - key = self._create_key(record, key_field) - # Optionally preprocess the record and validate type - if preprocess: - try: - record = preprocess(record) - except Exception as e: - raise RuntimeError( - "Error while preprocessing records on load" - ) from e - if not isinstance(record, dict): - raise TypeError( - f"Individual records must be of type dict, got type {type(record)}" - ) - # Write the record to Redis - pipe.hset(key, mapping=record) # type: ignore - if ttl: - pipe.expire(key, ttl) - pipe.execute() + self._storage.write( + self.client, + objects=data, + key_field=key_field, + keys=keys, + ttl=ttl, + preprocess=preprocess, + batch_size=batch_size, + ) @check_connected("_redis_conn") + @check_modules_present("_redis_conn") + @check_index_exists() + def search(self, *args, **kwargs) -> Union["Result", Any]: + """Perform a search on this index. + + Wrapper around redis.search.Search that adds the index name + to the search query and passes along the rest of the arguments + to the redis-py ft.search() method. + + Returns: + Union["Result", Any]: Search results. 
+ """ + results = self._redis_conn.ft(self._name).search( # type: ignore + *args, **kwargs + ) + return results + + @check_connected("_redis_conn") + @check_modules_present("_redis_conn") + @check_index_exists() + def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: + """Run a query on this index. + + This is similar to the search method, but takes a BaseQuery + object directly (does not allow for the usage of a raw + redis query string) and post-processes results of the search. + + Args: + query (BaseQuery): The query to run. + + Returns: + List[Result]: A list of search results. + """ + results = self.search(query.query, query_params=query.params) + # post process the results + return process_results(results, query=query, storage_type=self._storage_type) + + @check_connected("_redis_conn") + @check_modules_present("_redis_conn") def exists(self) -> bool: """Check if the index exists in Redis. @@ -435,6 +583,17 @@ def exists(self) -> bool: indices = convert_bytes(self._redis_conn.execute_command("FT._LIST")) # type: ignore return self._name in indices + @check_connected("_redis_conn") + @check_modules_present("_redis_conn") + @check_index_exists() + def info(self) -> Dict[str, Any]: + """Get information about the index. + + Returns: + dict: A dictionary containing the information about the index. + """ + return convert_bytes(self._redis_conn.ft(self._name).info()) # type: ignore + class AsyncSearchIndex(SearchIndexBase): """A class for interacting with Redis as a vector database asynchronously. @@ -445,24 +604,17 @@ class AsyncSearchIndex(SearchIndexBase): Example: >>> from redisvl.index import AsyncSearchIndex >>> index = AsyncSearchIndex.from_yaml("schema.yaml") + >>> index.connect("redis://localhost:6379") >>> await index.create(overwrite=True) >>> await index.load(data) # data is an iterable of dictionaries """ - def __init__( - self, - name: str, - prefix: str = "rvl", - storage_type: str = "hash", - fields: Optional[List["Field"]] = None, - ): - super().__init__(name, prefix, storage_type, fields) - @classmethod async def from_existing( cls, name: str, - url: Optional[str] = None, + redis_url: Optional[str] = None, + key_separator: str = ":", fields: Optional[List["Field"]] = None, **kwargs, ): @@ -470,20 +622,21 @@ async def from_existing( Args: name (str): Index name. - url (Optional[str], optional): Redis URL. REDIS_URL env var + redis_url (Optional[str], optional): Redis URL. REDIS_URL env var is used if not provided. Defaults to None. + key_separator (str, optional): Separator char to combine prefix and + key value for constructing redis keys. Defaults to ":". fields (Optional[List[Field]], optional): List of Redis search fields to include in the schema. Defaults to None. Returns: - SearchIndex: A SearchIndex object. + AsyncSearchIndex: An AsyncSearchIndex object. Raises: redis.exceptions.ResponseError: If the index does not exist. - ValueError: If the REDIS_URL env var is not set and url is not provided. - + ValueError: If the Redis URL is not accessible. 
""" - client = get_async_redis_connection(url, **kwargs) + client = get_async_redis_connection(redis_url, **kwargs) info = convert_bytes(await client.ft(name).info()) index_definition = make_dict(info["index_definition"]) storage_type = index_definition["key_type"].lower() @@ -492,37 +645,38 @@ async def from_existing( name=name, storage_type=storage_type, prefix=prefix, + key_separator=key_separator, fields=fields, ) instance.set_client(client) return instance - def connect(self, url: Optional[str] = None, **kwargs): + def connect(self, redis_url: Optional[str] = None, **kwargs): """Connect to a Redis instance. Args: - url (str): Redis URL. REDIS_URL env var is used if not provided. + redis_url (Optional[str], optional): Redis URL. REDIS_URL env var is + used if not provided. Raises: redis.exceptions.ConnectionError: If the connection to Redis fails. - ValueError: If no Redis URL is provided and REDIS_URL env var is not set. + ValueError: If the Redis URL is not accessible. """ - self._redis_conn = get_async_redis_connection(url, **kwargs) + self._redis_conn = get_async_redis_connection(redis_url, **kwargs) return self - @check_connected("_redis_conn") - async def create(self, overwrite: Optional[bool] = False): - """Create an index in Redis from this SearchIndex object. + @check_async_connected("_redis_conn") + @check_async_modules_present("_redis_conn") + async def create(self, overwrite: bool = False) -> None: + """Asynchronously create an index in Redis from this SearchIndex object. Args: - overwrite (bool, optional): Overwrite the index if it already exists. Defaults to False. + overwrite (bool, optional): Whether to overwrite the index if it + already exists. Defaults to False. Raises: - redis.exceptions.ResponseError: If the index already exists. + RuntimeError: If the index already exists and 'overwrite' is False. """ - # TODO - enable async version of this - # check_redis_modules_exist(self._redis_conn) - if not self._fields: raise ValueError("No fields defined for index") if not isinstance(overwrite, bool): @@ -535,24 +689,23 @@ async def create(self, overwrite: Optional[bool] = False): print("Index already exists, overwriting.") await self.delete() - # set storage_type, default to hash - storage_type = IndexType.HASH - # TODO - enable JSON support - # if self._storage.lower() == "json": - # storage_type = IndexType.JSON - - # Create Index + # Create Index with proper IndexType await self._redis_conn.ft(self._name).create_index( # type: ignore fields=self._fields, - definition=IndexDefinition(prefix=[self._prefix], index_type=storage_type), + definition=IndexDefinition( + prefix=[self._prefix], index_type=self._storage.type + ), ) - @check_connected("_redis_conn") + @check_async_connected("_redis_conn") + @check_async_modules_present("_redis_conn") + @check_async_index_exists() async def delete(self, drop: bool = True): """Delete the search index. Args: - drop (bool, optional): Delete the documents in the index. Defaults to True. + drop (bool, optional): Delete the documents in the index. + Defaults to True. Raises: redis.exceptions.ResponseError: If the index does not exist. 
@@ -560,61 +713,58 @@ async def delete(self, drop: bool = True): # Delete the search index await self._redis_conn.ft(self._name).dropindex(delete_documents=drop) # type: ignore - @check_connected("_redis_conn") + @check_async_connected("_redis_conn") + @check_async_modules_present("_redis_conn") async def load( self, - data: Iterable[Dict[str, Any]], - concurrency: int = 10, + data: Iterable[Any], key_field: Optional[str] = None, + keys: Optional[Iterable[str]] = None, + ttl: Optional[int] = None, preprocess: Optional[Callable] = None, + concurrency: Optional[int] = None, **kwargs, ): - """Load data into Redis and index using this SearchIndex object. + """Asynchronously load objects to Redis with concurrency control. Args: - data (Iterable[Dict[str, Any]]): An iterable of dictionaries - containing the data to be indexed. - concurrency (int, optional): Number of concurrent tasks to run. Defaults to 10. - key_field (Optional[str], optional): A field within the record to - use in the Redis hash key. - preprocess (Optional[Callable], optional): An optional preprocessor function - that mutates the individual record before writing to redis. + data (Iterable[Any]): An iterable of objects to store. + key_field (Optional[str], optional): Field used as the key for each + object. Defaults to None. + keys (Optional[Iterable[str]], optional): Optional iterable of keys. + Must match the length of objects if provided. Defaults to None. + ttl (Optional[int], optional): Time-to-live in seconds for each key. + Defaults to None. + preprocess (Optional[Callable], optional): An async function to + preprocess objects before storage. Defaults to None. + concurrency (Optional[int], optional): The maximum number of + concurrent write operations. Defaults to class's default + concurrency level. Raises: - redis.exceptions.ResponseError: If the index does not exist. + ValueError: If the length of provided keys does not match the + length of objects. Example: >>> data = [{"foo": "bar"}, {"test": "values"}] - >>> def func(record: dict): record["new"]="value";return record + >>> async def func(record: dict): + >>> record["new"] = "value" + >>> return record >>> await index.load(data, preprocess=func) """ - ttl = kwargs.get("ttl") - semaphore = asyncio.Semaphore(concurrency) - - async def _load(record: dict): - async with semaphore: - key = self._create_key(record, key_field) - # Optionally preprocess the record and validate type - if preprocess: - try: - record = preprocess(record) - except Exception as e: - raise RuntimeError( - "Error while preprocessing records on load" - ) from e - if not isinstance(record, dict): - raise TypeError( - f"Individual records must be of type dict, got type {type(record)}" - ) - # Write the record to Redis - await self._redis_conn.hset(key, mapping=record) # type: ignore - if ttl: - await self._redis_conn.expire(key, ttl) # type: ignore - - # Gather with concurrency - await asyncio.gather(*[_load(record) for record in data]) + await self._storage.awrite( + self.client, + objects=data, + key_field=key_field, + keys=keys, + ttl=ttl, + preprocess=preprocess, + concurrency=concurrency, + ) - @check_connected("_redis_conn") + @check_async_connected("_redis_conn") + @check_async_modules_present("_redis_conn") + @check_async_index_exists() async def search(self, *args, **kwargs) -> Union["Result", Any]: """Perform a search on this index. @@ -625,11 +775,12 @@ async def search(self, *args, **kwargs) -> Union["Result", Any]: Returns: Union["Result", Any]: Search results. 
""" - results = await self._redis_conn.ft(self._name).search( # type: ignore - *args, **kwargs - ) + results = await self._redis_conn.ft(self._name).search(*args, **kwargs) # type: ignore return results + @check_async_connected("_redis_conn") + @check_async_modules_present("_redis_conn") + @check_async_index_exists() async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: """Run a query on this index. @@ -644,11 +795,11 @@ async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]: List[Result]: A list of search results. """ results = await self.search(query.query, query_params=query.params) - if isinstance(query, CountQuery): - return results.total - return process_results(results) + # post process the results + return process_results(results, query=query, storage_type=self._storage_type) - @check_connected("_redis_conn") + @check_async_connected("_redis_conn") + @check_async_modules_present("_redis_conn") async def exists(self) -> bool: """Check if the index exists in Redis. @@ -657,3 +808,16 @@ async def exists(self) -> bool: """ indices = await self._redis_conn.execute_command("FT._LIST") # type: ignore return self._name in convert_bytes(indices) + + @check_async_connected("_redis_conn") + @check_async_modules_present("_redis_conn") + @check_async_index_exists() + async def info(self) -> Dict[str, Any]: + """Get information about the index. + + Returns: + dict: A dictionary containing the information about the index. + """ + return convert_bytes( + await self._redis_conn.ft(self._name).info() # type: ignore + ) diff --git a/redisvl/llmcache/base.py b/redisvl/llmcache/base.py index fb6248c7..574c114f 100644 --- a/redisvl/llmcache/base.py +++ b/redisvl/llmcache/base.py @@ -26,7 +26,8 @@ def store( vector: Optional[List[float]] = None, metadata: Optional[dict] = {}, ) -> None: - """Stores the specified key-value pair in the cache along with metadata.""" + """Stores the specified key-value pair in the cache along with + metadata.""" raise NotImplementedError def _refresh_ttl(self, key: str): diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index 60f2452b..84b0e38e 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -63,7 +63,7 @@ def __init__( index = SearchIndex( name=index_name, prefix=prefix, fields=self._default_fields ) - index.connect(url=redis_url, **connection_args) + index.connect(redis_url=redis_url, **connection_args) else: raise ValueError( "Index name and prefix must be provided if not constructing from an existing index." @@ -135,11 +135,12 @@ def set_threshold(self, threshold: float): self._threshold = float(threshold) def clear(self): - """Clear the LLMCache of all keys in the index""" + """Clear the LLMCache of all keys in the index.""" client = self._index.client + prefix = self._index.prefix if client: with client.pipeline(transaction=False) as pipe: - for key in client.scan_iter(match=f"{self._index._prefix}:*"): + for key in client.scan_iter(match=f"{prefix}:*"): pipe.delete(key) pipe.execute() else: diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index e2d46834..aec14454 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -95,7 +95,7 @@ class Tag(FilterField): SUPPORTED_VAL_TYPES = (list, set, tuple, str, type(None)) def __init__(self, field: str): - """Create a Tag FilterField + """Create a Tag FilterField. 
Args: field (str): The name of the tag field in the index to be queried against @@ -121,7 +121,7 @@ def _set_tag_value( @check_operator_misuse def __eq__(self, other: Union[List[str], str]) -> "FilterExpression": - """Create a Tag equality filter expression + """Create a Tag equality filter expression. Args: other (Union[List[str], str]): The tag(s) to filter on. @@ -135,7 +135,7 @@ def __eq__(self, other: Union[List[str], str]) -> "FilterExpression": @check_operator_misuse def __ne__(self, other) -> "FilterExpression": - """Create a Tag inequality filter expression + """Create a Tag inequality filter expression. Args: other (Union[List[str], str]): The tag(s) to filter on. @@ -152,7 +152,7 @@ def _formatted_tag_value(self) -> str: return "|".join([self.escaper.escape(tag) for tag in self._value]) def __str__(self) -> str: - """Return the Redis Query syntax for a Tag filter expression""" + """Return the Redis Query syntax for a Tag filter expression.""" if not self._value: return "*" @@ -175,7 +175,7 @@ def __init__(self, longitude: float, latitude: float, unit: str = "km"): class GeoRadius(GeoSpec): - """A GeoRadius is a GeoSpec representing a geographic radius""" + """A GeoRadius is a GeoSpec representing a geographic radius.""" def __init__( self, @@ -194,7 +194,6 @@ def __init__( Raises: ValueError: If the unit is not one of "m", "km", "mi", or "ft". - """ super().__init__(longitude, latitude, unit) self._radius = radius @@ -204,24 +203,22 @@ def get_args(self) -> List[Union[float, int, str]]: class Geo(FilterField): - """A Geo is a FilterField representing a geographic (lat/lon) - field in a Redis index. - - """ + """A Geo is a FilterField representing a geographic (lat/lon) field in a + Redis index.""" OPERATORS: Dict[FilterOperator, str] = { FilterOperator.EQ: "==", FilterOperator.NE: "!=", } OPERATOR_MAP: Dict[FilterOperator, str] = { - FilterOperator.EQ: "@%s:[%f %f %i %s]", - FilterOperator.NE: "(-@%s:[%f %f %i %s])", + FilterOperator.EQ: "@%s:[%s %s %i %s]", + FilterOperator.NE: "(-@%s:[%s %s %i %s])", } SUPPORTED_VAL_TYPES = (GeoSpec, type(None)) @check_operator_misuse def __eq__(self, other) -> "FilterExpression": - """Create a Geographic equality filter expression + """Create a Geographic equality filter expression. Args: other (GeoSpec): The geographic spec to filter on. @@ -235,7 +232,7 @@ def __eq__(self, other) -> "FilterExpression": @check_operator_misuse def __ne__(self, other) -> "FilterExpression": - """Create a Geographic inequality filter expression + """Create a Geographic inequality filter expression. Args: other (GeoSpec): The geographic spec to filter on. 
@@ -248,7 +245,7 @@ def __ne__(self, other) -> "FilterExpression": return FilterExpression(str(self)) def __str__(self) -> str: - """Return the Redis Query syntax for a Geographic filter expression""" + """Return the Redis Query syntax for a Geographic filter expression.""" if not self._value: return "*" @@ -270,17 +267,17 @@ class Num(FilterField): FilterOperator.GE: ">=", } OPERATOR_MAP: Dict[FilterOperator, str] = { - FilterOperator.EQ: "@%s:[%i %i]", - FilterOperator.NE: "(-@%s:[%i %i])", - FilterOperator.GT: "@%s:[(%i +inf]", - FilterOperator.LT: "@%s:[-inf (%i]", - FilterOperator.GE: "@%s:[%i +inf]", - FilterOperator.LE: "@%s:[-inf %i]", + FilterOperator.EQ: "@%s:[%s %s]", + FilterOperator.NE: "(-@%s:[%s %s])", + FilterOperator.GT: "@%s:[(%s +inf]", + FilterOperator.LT: "@%s:[-inf (%s]", + FilterOperator.GE: "@%s:[%s +inf]", + FilterOperator.LE: "@%s:[-inf %s]", } SUPPORTED_VAL_TYPES = (int, float, type(None)) def __eq__(self, other: int) -> "FilterExpression": - """Create a Numeric equality filter expression + """Create a Numeric equality filter expression. Args: other (int): The value to filter on. @@ -293,7 +290,7 @@ def __eq__(self, other: int) -> "FilterExpression": return FilterExpression(str(self)) def __ne__(self, other: int) -> "FilterExpression": - """Create a Numeric inequality filter expression + """Create a Numeric inequality filter expression. Args: other (int): The value to filter on. @@ -306,7 +303,7 @@ def __ne__(self, other: int) -> "FilterExpression": return FilterExpression(str(self)) def __gt__(self, other: int) -> "FilterExpression": - """Create a Numeric greater than filter expression + """Create a Numeric greater than filter expression. Args: other (int): The value to filter on. @@ -319,7 +316,7 @@ def __gt__(self, other: int) -> "FilterExpression": return FilterExpression(str(self)) def __lt__(self, other: int) -> "FilterExpression": - """Create a Numeric less than filter expression + """Create a Numeric less than filter expression. Args: other (int): The value to filter on. @@ -332,7 +329,7 @@ def __lt__(self, other: int) -> "FilterExpression": return FilterExpression(str(self)) def __ge__(self, other: int) -> "FilterExpression": - """Create a Numeric greater than or equal to filter expression + """Create a Numeric greater than or equal to filter expression. Args: other (int): The value to filter on. @@ -345,7 +342,7 @@ def __ge__(self, other: int) -> "FilterExpression": return FilterExpression(str(self)) def __le__(self, other: int) -> "FilterExpression": - """Create a Numeric less than or equal to filter expression + """Create a Numeric less than or equal to filter expression. Args: other (int): The value to filter on. @@ -358,7 +355,7 @@ def __le__(self, other: int) -> "FilterExpression": return FilterExpression(str(self)) def __str__(self) -> str: - """Return the Redis Query syntax for a Numeric filter expression""" + """Return the Redis Query syntax for a Numeric filter expression.""" if not self._value: return "*" @@ -389,8 +386,8 @@ class Text(FilterField): @check_operator_misuse def __eq__(self, other: str) -> "FilterExpression": - """Create a Text equality filter expression. These expressions - yield filters that enforce an exact match on the supplied term(s). + """Create a Text equality filter expression. These expressions yield + filters that enforce an exact match on the supplied term(s). Args: other (str): The text value to filter on. 
@@ -405,8 +402,8 @@ def __eq__(self, other: str) -> "FilterExpression": @check_operator_misuse def __ne__(self, other: str) -> "FilterExpression": """Create a Text inequality filter expression. These expressions yield - negated filters on exact matches on the supplied term(s). Opposite of - an equality filter expression. + negated filters on exact matches on the supplied term(s). Opposite of an + equality filter expression. Args: other (str): The text value to filter on. @@ -420,8 +417,9 @@ def __ne__(self, other: str) -> "FilterExpression": def __mod__(self, other: str) -> "FilterExpression": """Create a Text "LIKE" filter expression. A flexible expression that - yields filters that can use a variety of additional operators like - wildcards (*), fuzzy matches (%%), or combinatorics (|) of the supplied term(s). + yields filters that can use a variety of additional operators like + wildcards (*), fuzzy matches (%%), or combinatorics (|) of the supplied + term(s). Args: other (str): The text value to filter on. diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 36efc014..c24ad385 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -210,7 +210,6 @@ def __init__( Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression - """ super().__init__( vector, @@ -290,7 +289,6 @@ def __init__( Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression - """ super().__init__( vector, diff --git a/redisvl/schema.py b/redisvl/schema.py index 55dbad7f..eebf7318 100644 --- a/redisvl/schema.py +++ b/redisvl/schema.py @@ -1,6 +1,6 @@ +from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Union -from uuid import uuid4 import yaml from pydantic import BaseModel, Field, validator @@ -17,6 +17,7 @@ class BaseField(BaseModel): name: str = Field(...) 
sortable: Optional[bool] = False + as_name: Optional[str] = None class TextFieldSchema(BaseField): @@ -32,6 +33,7 @@ def as_field(self): no_stem=self.no_stem, phonetic_matcher=self.phonetic_matcher, sortable=self.sortable, + as_name=self.as_name, ) @@ -45,17 +47,18 @@ def as_field(self): separator=self.separator, case_sensitive=self.case_sensitive, sortable=self.sortable, + as_name=self.as_name, ) class NumericFieldSchema(BaseField): def as_field(self): - return NumericField(self.name, sortable=self.sortable) + return NumericField(self.name, sortable=self.sortable, as_name=self.as_name) class GeoFieldSchema(BaseField): def as_field(self): - return GeoField(self.name, sortable=self.sortable) + return GeoField(self.name, sortable=self.sortable, as_name=self.as_name) class BaseVectorField(BaseModel): @@ -65,6 +68,7 @@ class BaseVectorField(BaseModel): datatype: str = Field(default="FLOAT32") distance_metric: str = Field(default="COSINE") initial_cap: Optional[int] = None + as_name: Optional[str] = None @validator("algorithm", "datatype", "distance_metric", pre=True) def uppercase_strings(cls, v): @@ -90,7 +94,7 @@ def as_field(self): field_data = super().as_field() if self.block_size is not None: field_data["BLOCK_SIZE"] = self.block_size - return VectorField(self.name, self.algorithm, field_data) + return VectorField(self.name, self.algorithm, field_data, as_name=self.as_name) class HNSWVectorField(BaseVectorField): @@ -111,13 +115,22 @@ def as_field(self): "EPSILON": self.epsilon, } ) - return VectorField(self.name, self.algorithm, field_data) + return VectorField(self.name, self.algorithm, field_data, as_name=self.as_name) + + +class StorageType(Enum): + HASH = "hash" + JSON = "json" class IndexModel(BaseModel): - name: str = Field(...) - prefix: Optional[str] = Field(default="") - storage_type: Optional[str] = Field(default="hash") + """Represents the schema for an index, including its name, optional prefix, + and the storage type used.""" + + name: str + prefix: str = "rvl" + key_separator: str = ":" + storage_type: StorageType = StorageType.HASH class FieldsModel(BaseModel): @@ -132,12 +145,6 @@ class SchemaModel(BaseModel): index: IndexModel = Field(...) fields: FieldsModel = Field(...) - @validator("index") - def validate_index(cls, v): - if v.storage_type not in ["hash", "json"]: - raise ValueError(f"Storage type {v.storage_type} not supported") - return v - @property def index_fields(self): redis_fields = [] @@ -160,21 +167,12 @@ def read_schema(file_path: str): return SchemaModel(**schema) -class MetadataSchemaGenerator: - """ - A class to generate a schema for metadata, categorizing fields into text, numeric, and tag types. - """ +class SchemaGenerator: + """A class to generate a schema for metadata, categorizing fields into text, + numeric, and tag types.""" def _test_numeric(self, value) -> bool: - """ - Test if the given value can be represented as a numeric value. - - Args: - value: The value to test. - - Returns: - bool: True if the value can be converted to float, False otherwise. - """ + """Test if a value is numeric.""" try: float(value) return True @@ -182,72 +180,62 @@ def _test_numeric(self, value) -> bool: return False def _infer_type(self, value) -> Optional[str]: - """ - Infer the type of the given value. - - Args: - value: The value to infer the type of. - - Returns: - Optional[str]: The inferred type of the value, or None if the type is unrecognized or the value is empty. 
- """ - if value is None or value == "": + """Infer the type of a value.""" + if value in [None, ""]: return None - elif self._test_numeric(value): + if self._test_numeric(value): return "numeric" - elif isinstance(value, (list, set, tuple)) and all( + if isinstance(value, (list, set, tuple)) and all( isinstance(v, str) for v in value ): return "tag" - elif isinstance(value, str): - return "text" - else: - return "unknown" + return "text" if isinstance(value, str) else "unknown" def generate( - self, metadata: Dict[str, Any], strict: Optional[bool] = False + self, metadata: Dict[str, Any], strict: bool = False ) -> Dict[str, List[Dict[str, Any]]]: - """ - Generate a schema from the provided metadata. - - This method categorizes each metadata field into text, numeric, or tag types based on the field values. - It also allows forcing strict type determination by raising an exception if a type cannot be inferred. + """Generate a schema from metadata. Args: - metadata: The metadata dictionary to generate the schema from. - strict: If True, the method will raise an exception for fields where the type cannot be determined. - - Returns: - Dict[str, List[Dict[str, Any]]]: A dictionary with keys 'text', 'numeric', and 'tag', each mapping to a list of field schemas. + metadata (Dict[str, Any]): Metadata object to validate and + generate schema. + strict (bool, optional): Whether to generate schema in strict + mode. Defaults to False. Raises: - ValueError: If the force parameter is True and a field's type cannot be determined. + ValueError: Unable to determine schema field type for a + key-value pair. + + Returns: + Dict[str, List[Dict[str, Any]]]: Output metadata schema. """ result: Dict[str, List[Dict[str, Any]]] = {"text": [], "numeric": [], "tag": []} + field_classes = { + "text": TextFieldSchema, + "tag": TagFieldSchema, + "numeric": NumericFieldSchema, + } for key, value in metadata.items(): field_type = self._infer_type(value) - if field_type in ["unknown", None]: + if field_type is None or field_type == "unknown": if strict: raise ValueError( - f"Unable to determine field type for key '{key}' with value '{value}'" + f"Unable to determine field type for key '{key}' with" + f" value '{value}'" ) print( - f"Warning: Unable to determine field type for key '{key}' with value '{value}'" + f"Warning: Unable to determine field type for key '{key}'" + f" with value '{value}'" ) continue - # Extract the field class with defaults - field_class = { - "text": TextFieldSchema, - "tag": TagFieldSchema, - "numeric": NumericFieldSchema, - }.get( - field_type # type: ignore - ) - - if field_class: - result[field_type].append(field_class(name=key).dict(exclude_none=True)) # type: ignore + if isinstance(field_type, str): + field_class = field_classes.get(field_type) + if field_class: + result[field_type].append( + field_class(name=key).dict(exclude_none=True) + ) return result diff --git a/redisvl/storage.py b/redisvl/storage.py new file mode 100644 index 00000000..a1b9daba --- /dev/null +++ b/redisvl/storage.py @@ -0,0 +1,496 @@ +import asyncio +import uuid +from typing import Any, Callable, Dict, Iterable, List, Optional + +from redis import Redis +from redis.asyncio import Redis as AsyncRedis +from redis.commands.search.indexDefinition import IndexType + +from redisvl.utils.utils import convert_bytes + + +class BaseStorage: + type: IndexType + DEFAULT_BATCH_SIZE: int = 200 + DEFAULT_WRITE_CONCURRENCY: int = 20 + + def __init__(self, prefix: str, key_separator: str): + """Initialize the BaseStorage with a 
specific prefix and key separator + for Redis keys. + + Args: + prefix (str): The prefix to prepend to each Redis key. + key_separator (str): The separator to use between the prefix and + the key value. + """ + self._prefix = prefix + self._key_separator = key_separator + + @staticmethod + def _key(key_value: str, prefix: str, key_separator: str) -> str: + """Create a Redis key using a combination of a prefix, separator, and + the key value. + + Args: + key_value (str): The unique identifier for the Redis entry. + prefix (str): A prefix to append before the key value. + key_separator (str): A separator to insert between prefix + and key value. + + Returns: + str: The fully formed Redis key. + """ + if not prefix: + return key_value + else: + return f"{prefix}{key_separator}{key_value}" + + def _create_key(self, obj: Dict[str, Any], key_field: Optional[str] = None) -> str: + """Construct a Redis key for a given object, optionally using a + specified field from the object as the key. + + Args: + obj (Dict[str, Any]): The object from which to construct the key. + key_field (Optional[str], optional): The field to use as the + key, if provided. + + Returns: + str: The constructed Redis key for the object. + + Raises: + ValueError: If the key_field is not found in the object. + """ + if key_field is None: + key_value = uuid.uuid4().hex + else: + try: + key_value = obj[key_field] # type: ignore + except KeyError: + raise ValueError(f"Key field {key_field} not found in record {obj}") + + return self._key( + key_value, prefix=self._prefix, key_separator=self._key_separator + ) + + @staticmethod + def _preprocess(obj: Any, preprocess: Optional[Callable] = None) -> Dict[str, Any]: + """Apply a preprocessing function to the object if provided. + + Args: + preprocess (Optional[Callable], optional): Function to + process the object. + obj (Any): Object to preprocess. + + Returns: + Dict[str, Any]: Processed object as a dictionary. + """ + # optionally preprocess object + if preprocess: + obj = preprocess(obj) + return obj + + @staticmethod + async def _apreprocess( + obj: Any, preprocess: Optional[Callable] = None + ) -> Dict[str, Any]: + """Asynchronously apply a preprocessing function to the object if + provided. + + Args: + preprocess (Optional[Callable], optional): Async function to + process the object. + obj (Any): Object to preprocess. + + Returns: + Dict[str, Any]: Processed object as a dictionary. + """ + # optionally async preprocess object + if preprocess: + obj = await preprocess(obj) + return obj + + def _validate(self, obj: Dict[str, Any]): + """Validate the object before writing to Redis. This method should be + implemented by subclasses. + + Args: + obj (Dict[str, Any]): The object to validate. + """ + raise NotImplementedError + + @staticmethod + def _set(client: Redis, key: str, obj: Dict[str, Any]): + """Synchronously set the value in Redis for the given key. + + Args: + client (Redis): The Redis client instance. + key (str): The key under which to store the object. + obj (Dict[str, Any]): The object to store in Redis. + """ + raise NotImplementedError + + @staticmethod + async def _aset(client: AsyncRedis, key: str, obj: Dict[str, Any]): + """Asynchronously set the value in Redis for the given key. + + Args: + client (AsyncRedis): The Redis client instance. + key (str): The key under which to store the object. + obj (Dict[str, Any]): The object to store in Redis. 
+ """ + raise NotImplementedError + + @staticmethod + def _get(client: Redis, key: str) -> Dict[str, Any]: + """Synchronously get the value from Redis for the given key. + + Args: + client (Redis): The Redis client instance. + key (str): The key for which to retrieve the object. + + Returns: + Dict[str, Any]: The retrieved object from Redis. + """ + raise NotImplementedError + + @staticmethod + async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]: + """Asynchronously get the value from Redis for the given key. + + Args: + client (AsyncRedis): The Redis client instance. + key (str): The key for which to retrieve the object. + + Returns: + Dict[str, Any]: The retrieved object from Redis. + """ + raise NotImplementedError + + def write( + self, + redis_client: Redis, + objects: Iterable[Any], + key_field: Optional[str] = None, + keys: Optional[Iterable[str]] = None, + ttl: Optional[int] = None, + preprocess: Optional[Callable] = None, + batch_size: Optional[int] = None, + ): + """Write a batch of objects to Redis as hash entries. + + Args: + redis_client (Redis): A Redis client used for writing data. + objects (Iterable[Any]): An iterable of objects to store. + key_field (Optional[str], optional): Field used as the key for + each object. Defaults to None. + keys (Optional[Iterable[str]], optional): Optional iterable of + keys, must match the length of objects if provided. + ttl (Optional[int], optional): Time-to-live in seconds for each + key. Defaults to None. + preprocess (Optional[Callable], optional): A function to preprocess + objects before storage. Defaults to None. + batch_size (Optional[int], optional): Number of objects to write + in a single Redis pipeline execution. + + Raises: + ValueError: If the length of provided keys does not match the + length of objects. + """ + if keys and len(keys) != len(objects): # type: ignore + raise ValueError("Length of keys does not match the length of objects") + + if batch_size is None: + batch_size = ( + self.DEFAULT_BATCH_SIZE + ) # Use default or calculate based on the input data + + keys_iterator = iter(keys) if keys else None + + with redis_client.pipeline(transaction=False) as pipe: + for i, obj in enumerate(objects, start=1): + key = ( + next(keys_iterator) + if keys_iterator + else self._create_key(obj, key_field) + ) + obj = self._preprocess(obj, preprocess) + self._validate(obj) + self._set(pipe, key, obj) + if ttl: + pipe.expire(key, ttl) # Set TTL if provided + # execute mini batch + if i % batch_size == 0: + pipe.execute() + # clean up batches if needed + if i % batch_size != 0: + pipe.execute() + + async def awrite( + self, + redis_client: AsyncRedis, + objects: Iterable[Any], + key_field: Optional[str] = None, + keys: Optional[Iterable[str]] = None, + ttl: Optional[int] = None, + preprocess: Optional[Callable] = None, + concurrency: Optional[int] = None, + ): + """Asynchronously write objects to Redis as hash entries with + concurrency control. + + Args: + redis_client (AsyncRedis): An asynchronous Redis client used + for writing data. + objects (Iterable[Any]): An iterable of objects to store. + key_field (Optional[str], optional): Field used as the key for each + object. Defaults to None. + keys (Optional[Iterable[str]], optional): Optional iterable of keys. + Must match the length of objects if provided. + ttl (Optional[int], optional): Time-to-live in seconds for each key. + Defaults to None. + preprocess (Optional[Callable], optional): An async function to + preprocess objects before storage. Defaults to None. 
+            concurrency (Optional[int], optional): The maximum number of
+                concurrent write operations. Defaults to class's default
+                concurrency level.
+
+        Raises:
+            ValueError: If the length of provided keys does not match the
+                length of objects.
+        """
+        if keys and len(keys) != len(objects):  # type: ignore
+            raise ValueError("Length of keys does not match the length of objects")
+
+        if not concurrency:
+            concurrency = self.DEFAULT_WRITE_CONCURRENCY
+
+        semaphore = asyncio.Semaphore(concurrency)
+        keys_iterator = iter(keys) if keys else None
+
+        async def _load(obj: Dict[str, Any], key: Optional[str] = None) -> None:
+            async with semaphore:
+                if key is None:
+                    key = self._create_key(obj, key_field)
+                obj = await self._apreprocess(obj, preprocess)
+                self._validate(obj)
+                await self._aset(redis_client, key, obj)
+                if ttl:
+                    await redis_client.expire(key, ttl)
+
+        if keys_iterator:
+            tasks = [
+                asyncio.create_task(_load(obj, next(keys_iterator))) for obj in objects
+            ]
+        else:
+            tasks = [asyncio.create_task(_load(obj)) for obj in objects]
+
+        await asyncio.gather(*tasks)
+
+    def get(
+        self, redis_client: Redis, keys: Iterable[str], batch_size: Optional[int] = None
+    ) -> List[Dict[str, Any]]:
+        """Retrieve objects from Redis by keys.
+
+        Args:
+            redis_client (Redis): Synchronous Redis client.
+            keys (Iterable[str]): Keys to retrieve from Redis.
+            batch_size (Optional[int], optional): Number of objects to
+                retrieve in a single Redis pipeline execution. Defaults to
+                class's default batch size.
+
+        Returns:
+            List[Dict[str, Any]]: List of objects pulled from Redis.
+        """
+        results: List = []
+
+        if not isinstance(keys, Iterable):  # type: ignore
+            raise TypeError("Keys must be an iterable of strings")
+
+        if len(keys) == 0:  # type: ignore
+            return []
+
+        if batch_size is None:
+            batch_size = (
+                self.DEFAULT_BATCH_SIZE
+            )  # Use default or calculate based on the input data
+
+        # Use a pipeline to batch the retrieval
+        with redis_client.pipeline(transaction=False) as pipe:
+            for i, key in enumerate(keys, start=1):
+                self._get(pipe, key)
+                if i % batch_size == 0:
+                    results.extend(pipe.execute())
+            if i % batch_size != 0:
+                results.extend(pipe.execute())
+
+        # Process results
+        return convert_bytes(results)
+
+    async def aget(
+        self,
+        redis_client: AsyncRedis,
+        keys: Iterable[str],
+        concurrency: Optional[int] = None,
+    ) -> List[Dict[str, Any]]:
+        """Asynchronously retrieve objects from Redis by keys, with concurrency
+        control.
+
+        Args:
+            redis_client (AsyncRedis): Asynchronous Redis client.
+            keys (Iterable[str]): Keys to retrieve from Redis.
+            concurrency (Optional[int], optional): The number of concurrent
+                requests to make.
+
+        Returns:
+            List[Dict[str, Any]]: List of objects pulled from Redis.
+        """
+        if not isinstance(keys, Iterable):  # type: ignore
+            raise TypeError("Keys must be an iterable of strings")
+
+        if len(keys) == 0:  # type: ignore
+            return []
+
+        if not concurrency:
+            concurrency = self.DEFAULT_WRITE_CONCURRENCY
+
+        semaphore = asyncio.Semaphore(concurrency)
+
+        async def _get(key: str) -> Dict[str, Any]:
+            async with semaphore:
+                result = await self._aget(redis_client, key)
+                return result
+
+        tasks = [asyncio.create_task(_get(key)) for key in keys]
+        results = await asyncio.gather(*tasks)
+        return convert_bytes(results)
+
+
+class HashStorage(BaseStorage):
+    type: IndexType = IndexType.HASH
+
+    def _validate(self, obj: Dict[str, Any]):
+        """Validate that the given object is a dictionary, suitable for storage
+        as a Redis hash.
+
+        Args:
+            obj (Dict[str, Any]): The object to validate.
+
+        Raises:
+            TypeError: If the object is not a dictionary.
+        """
+        if not isinstance(obj, dict):
+            raise TypeError("Object must be a dictionary.")
+
+    @staticmethod
+    def _set(client: Redis, key: str, obj: Dict[str, Any]):
+        """Synchronously set a hash value in Redis for the given key.
+
+        Args:
+            client (Redis): The Redis client instance.
+            key (str): The key under which to store the hash.
+            obj (Dict[str, Any]): The hash to store in Redis.
+        """
+        client.hset(name=key, mapping=obj)  # type: ignore
+
+    @staticmethod
+    async def _aset(client: AsyncRedis, key: str, obj: Dict[str, Any]):
+        """Asynchronously set a hash value in Redis for the given key.
+
+        Args:
+            client (AsyncRedis): The Redis client instance.
+            key (str): The key under which to store the hash.
+            obj (Dict[str, Any]): The hash to store in Redis.
+        """
+        await client.hset(name=key, mapping=obj)  # type: ignore
+
+    @staticmethod
+    def _get(client: Redis, key: str) -> Dict[str, Any]:
+        """Synchronously retrieve a hash value from Redis for the given key.
+
+        Args:
+            client (Redis): The Redis client instance.
+            key (str): The key for which to retrieve the hash.
+
+        Returns:
+            Dict[str, Any]: The retrieved hash from Redis.
+        """
+        return client.hgetall(key)
+
+    @staticmethod
+    async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]:
+        """Asynchronously retrieve a hash value from Redis for the given key.
+
+        Args:
+            client (AsyncRedis): The Redis client instance.
+            key (str): The key for which to retrieve the hash.
+
+        Returns:
+            Dict[str, Any]: The retrieved hash from Redis.
+        """
+        return await client.hgetall(key)
+
+
+class JsonStorage(BaseStorage):
+    type: IndexType = IndexType.JSON
+
+    def _validate(self, obj: Dict[str, Any]):
+        """Validate that the given object is a dictionary, suitable for JSON
+        serialization.
+
+        Args:
+            obj (Dict[str, Any]): The object to validate.
+
+        Raises:
+            TypeError: If the object is not a dictionary.
+        """
+        if not isinstance(obj, dict):
+            raise TypeError("Object must be a dictionary.")
+
+    @staticmethod
+    def _set(client: Redis, key: str, obj: Dict[str, Any]):
+        """Synchronously set a JSON obj in Redis for the given key.
+
+        Args:
+            client (Redis): The Redis client instance.
+            key (str): The key under which to store the JSON obj.
+            obj (Dict[str, Any]): The JSON obj to store in Redis.
+        """
+        client.json().set(key, "$", obj)
+
+    @staticmethod
+    async def _aset(client: AsyncRedis, key: str, obj: Dict[str, Any]):
+        """Asynchronously set a JSON obj in Redis for the given key.
+
+        Args:
+            client (AsyncRedis): The Redis client instance.
+            key (str): The key under which to store the JSON obj.
+            obj (Dict[str, Any]): The JSON obj to store in Redis.
+        """
+        await client.json().set(key, "$", obj)
+
+    @staticmethod
+    def _get(client: Redis, key: str) -> Dict[str, Any]:
+        """Synchronously retrieve a JSON obj from Redis for the given key.
+
+        Args:
+            client (Redis): The Redis client instance.
+            key (str): The key for which to retrieve the JSON obj.
+
+        Returns:
+            Dict[str, Any]: The retrieved JSON obj from Redis.
+        """
+        return client.json().get(key)
+
+    @staticmethod
+    async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]:
+        """Asynchronously retrieve a JSON obj from Redis for the given key.
+
+        Args:
+            client (AsyncRedis): The Redis client instance.
+            key (str): The key for which to retrieve the JSON obj.
+
+        Returns:
+            Dict[str, Any]: The retrieved JSON obj from Redis.
+ """ + return await client.json().get(key) diff --git a/redisvl/utils/connection.py b/redisvl/utils/connection.py index e9a04647..a0f9da4a 100644 --- a/redisvl/utils/connection.py +++ b/redisvl/utils/connection.py @@ -1,5 +1,4 @@ import os -from functools import wraps from typing import Optional # TODO: handle connection errors. @@ -32,7 +31,7 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs): def get_address_from_env(): - """Get a redis connection from environment variables + """Get a redis connection from environment variables. Returns: str: Redis URL @@ -41,18 +40,3 @@ def get_address_from_env(): if not addr: raise ValueError("REDIS_URL env var not set") return addr - - -def check_connected(client_variable_name: str): - def decorator(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if getattr(self, client_variable_name) is None: - raise ValueError( - f"SearchIndex.connect() must be called before calling {func.__name__}" - ) - return func(self, *args, **kwargs) - - return wrapper - - return decorator diff --git a/redisvl/utils/token_escaper.py b/redisvl/utils/token_escaper.py index 10260866..53e47a73 100644 --- a/redisvl/utils/token_escaper.py +++ b/redisvl/utils/token_escaper.py @@ -3,8 +3,9 @@ class TokenEscaper: - """ - Escape punctuation within an input string. Adapted from RedisOM Python. + """Escape punctuation within an input string. + + Adapted from RedisOM Python. """ # Characters that RediSearch requires us to escape during queries. diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index f9757f73..267359ed 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -1,8 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, List - -if TYPE_CHECKING: - from redis.commands.search.result import Result - from redis.commands.search.document import Document +from typing import Any, List import numpy as np @@ -57,18 +53,27 @@ def check_redis_modules_exist(client) -> None: raise ValueError(error_message) +async def check_async_redis_modules_exist(client) -> None: + """Check if the correct Redis modules are installed.""" + installed_modules = await client.module_list() + installed_modules = { + module[b"name"].decode("utf-8"): module for module in installed_modules + } + for module in REDIS_REQUIRED_MODULES: + if module["name"] in installed_modules and int( + installed_modules[module["name"]][b"ver"] + ) >= int( + module["ver"] + ): # type: ignore[call-overload] + return + # otherwise raise error + error_message = ( + "You must add the RediSearch (>= 2.4) module from Redis Stack. 
" + "Please refer to Redis Stack docs: https://redis.io/docs/stack/" + ) + raise ValueError(error_message) + + def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes: """Convert a list of floats into a numpy byte string.""" return np.array(array).astype(dtype).tobytes() - - -def process_results(results: "Result") -> List[Dict[str, Any]]: - """Convert a list of search Result objects into a list of document dicts""" - - def _process(doc: "Document") -> Dict[str, Any]: - d = doc.__dict__ - if "payload" in d: - del d["payload"] - return d - - return [_process(doc) for doc in results.docs] diff --git a/redisvl/vectorize/text/huggingface.py b/redisvl/vectorize/text/huggingface.py index 53db328f..50755716 100644 --- a/redisvl/vectorize/text/huggingface.py +++ b/redisvl/vectorize/text/huggingface.py @@ -61,8 +61,8 @@ def embed_many( batch_size: int = 1000, as_buffer: bool = False, ) -> List[List[float]]: - """Asynchronously embed many chunks of texts using the Hugging Face sentence - transformer. + """Asynchronously embed many chunks of texts using the Hugging Face + sentence transformer. Args: texts (List[str]): List of text chunks to embed. diff --git a/redisvl/vectorize/text/openai.py b/redisvl/vectorize/text/openai.py index b9a83162..368e760a 100644 --- a/redisvl/vectorize/text/openai.py +++ b/redisvl/vectorize/text/openai.py @@ -10,10 +10,11 @@ class OpenAITextVectorizer(BaseVectorizer): - """OpenAI text vectorizer + """OpenAI text vectorizer. - This vectorizer uses the OpenAI API to create embeddings for text. It requires an - API key to be passed in the api_config dictionary. The API key can be obtained from + This vectorizer uses the OpenAI API to create embeddings for text. It + requires an API key to be passed in the api_config dictionary. The API key + can be obtained from https://api.openai.com/. """ diff --git a/redisvl/vectorize/text/vertexai.py b/redisvl/vectorize/text/vertexai.py index a96d3aa9..7dbc2785 100644 --- a/redisvl/vectorize/text/vertexai.py +++ b/redisvl/vectorize/text/vertexai.py @@ -7,10 +7,11 @@ class VertexAITextVectorizer(BaseVectorizer): - """VertexAI text vectorizer + """VertexAI text vectorizer. - This vectorizer uses the VertexAI Palm 2 embedding model API to create embeddings for text. It requires an - active GCP project, location, and application credentials. + This vectorizer uses the VertexAI Palm 2 embedding model API to create + embeddings for text. It requires an active GCP project, location, and + application credentials. 
""" def __init__( diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 8436a5a5..49cad43e 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -211,7 +211,7 @@ def filter_test( location=None, distance_threshold=0.2, ): - """Utility function to test filters""" + """Utility function to test filters.""" # set the new filter query.set_filter(_filter) diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index 95aaf629..437519e4 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -1,9 +1,12 @@ from pprint import pprint import numpy as np +import pytest from redisvl.index import SearchIndex from redisvl.query import VectorQuery +from redisvl.schema import StorageType +from redisvl.utils.utils import array_to_buffer data = [ { @@ -12,7 +15,7 @@ "age": 1, "job": "engineer", "credit_score": "high", - "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), + "user_embedding": [0.1, 0.1, 0.5], }, { "id": 2, @@ -20,7 +23,7 @@ "age": 2, "job": "doctor", "credit_score": "low", - "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), + "user_embedding": [0.1, 0.1, 0.5], }, { "id": 3, @@ -28,14 +31,14 @@ "age": 3, "job": "dentist", "credit_score": "medium", - "user_embedding": np.array([0.9, 0.9, 0.1], dtype=np.float32).tobytes(), + "user_embedding": [0.9, 0.9, 0.1], }, ] -schema = { +hash_schema = { "index": { - "name": "user_index", - "prefix": "users", + "name": "user_index_hash", + "prefix": "users_hash", "storage_type": "hash", }, "fields": { @@ -54,16 +57,51 @@ }, } +json_schema = { + "index": { + "name": "user_index_json", + "prefix": "users_json", + "storage_type": "json", + }, + "fields": { + "tag": [ + {"name": "$.credit_score", "as_name": "credit_score"}, + {"name": "$.user", "as_name": "user"}, + ], + "text": [{"name": "$.job", "as_name": "job"}], + "numeric": [{"name": "$.age", "as_name": "age"}], + "vector": [ + { + "name": "$.user_embedding", + "as_name": "user_embedding", + "dims": 3, + "distance_metric": "cosine", + "algorithm": "flat", + "datatype": "float32", + } + ], + }, +} -def test_simple(client): + +@pytest.mark.parametrize("schema", [hash_schema, json_schema]) +def test_simple(client, schema): index = SearchIndex.from_dict(schema) # assign client (only for testing) index.set_client(client) # create the index index.create(overwrite=True) - # load data into the index in Redis - index.load(data) + # Prepare and load the data based on storage type + def hash_preprocess(item: dict) -> dict: + return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} + + if index.storage_type == StorageType.HASH: + index.load(data, preprocess=hash_preprocess) + else: + # Load the prepared data into the index + print("DATA", data, flush=True) + index.load(data) query = VectorQuery( vector=[0.1, 0.1, 0.5], @@ -80,6 +118,7 @@ def test_simple(client): # users = list(results.docs) # print(len(users)) users = [doc for doc in results.docs] + pprint(users) assert users[0].user in ["john", "mary"] assert users[1].user in ["john", "mary"] diff --git a/tests/sample_hash_schema.yaml b/tests/sample_hash_schema.yaml new file mode 100644 index 00000000..c4a603c3 --- /dev/null +++ b/tests/sample_hash_schema.yaml @@ -0,0 +1,14 @@ +index: + name: hash-test + prefix: hash + key_separator: ':' + storage_type: hash + +fields: + text: + - name: sentence + vector: + - name: embedding + dims: 768 + algorithm: flat + distance_metric: cosine \ No 
newline at end of file
diff --git a/tests/sample_json_schema.yaml b/tests/sample_json_schema.yaml
new file mode 100644
index 00000000..8f9fd564
--- /dev/null
+++ b/tests/sample_json_schema.yaml
@@ -0,0 +1,16 @@
+index:
+  name: json-test
+  prefix: json
+  key_separator: ':'
+  storage_type: json
+
+fields:
+  text:
+    - name: '$.sentence'
+      as_name: sentence
+  vector:
+    - name: '$.embedding'
+      as_name: embedding
+      dims: 768
+      algorithm: flat
+      distance_metric: cosine
\ No newline at end of file
diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py
index 5fef85a5..42088bde 100644
--- a/tests/unit/test_filter.py
+++ b/tests/unit/test_filter.py
@@ -98,6 +98,9 @@ def test_numeric_filter():
     nf = Num("numeric_field") <= 5
     assert str(nf) == "@numeric_field:[-inf 5]"
 
+    nf = Num("numeric_field") > 5.5
+    assert str(nf) == "@numeric_field:[(5.5 +inf]"
+
     nf = Num("numeric_field") <= None
     assert str(nf) == "*"
 
@@ -130,10 +133,10 @@ def test_text_filter():
 
 def test_geo_filter():
     geo_f = Geo("geo_field") == GeoRadius(1.0, 2.0, 3, "km")
-    assert str(geo_f) == "@geo_field:[1.000000 2.000000 3 km]"
+    assert str(geo_f) == "@geo_field:[1.0 2.0 3 km]"
 
     geo_f = Geo("geo_field") != GeoRadius(1.0, 2.0, 3, "km")
-    assert str(geo_f) != "(-@geo_field:[1.000000 2.000000 3 m])"
+    assert str(geo_f) != "(-@geo_field:[1.0 2.0 3 m])"
 
 
 @pytest.mark.parametrize(
@@ -215,8 +218,8 @@ def test_text_filter(operation, value, expected):
 @pytest.mark.parametrize(
     "operation, expected",
     [
-        ("__eq__", "@geo_field:[1.000000 2.000000 3 km]"),
-        ("__ne__", "(-@geo_field:[1.000000 2.000000 3 km])"),
+        ("__eq__", "@geo_field:[1.0 2.0 3 km]"),
+        ("__ne__", "(-@geo_field:[1.0 2.0 3 km])"),
     ],
     ids=["eq", "ne"],
 )
diff --git a/tests/unit/test_index.py b/tests/unit/test_index.py
index fb6a474f..40505a23 100644
--- a/tests/unit/test_index.py
+++ b/tests/unit/test_index.py
@@ -11,18 +11,18 @@ def test_search_index_get_key():
     si = SearchIndex("my_index", fields=fields)
     key = si.key("foo")
-    assert key.startswith(si._prefix)
+    assert key.startswith(si.prefix)
     assert "foo" in key
 
-    key = si._create_key({"id": "foo"})
-    assert key.startswith(si._prefix)
+    key = si._storage._create_key({"id": "foo"})
+    assert key.startswith(si.prefix)
    assert "foo" not in key
 
 
 def test_search_index_no_prefix():
     # specify None as the prefix...
- si = SearchIndex("my_index", prefix=None, fields=fields) + si = SearchIndex("my_index", prefix="", fields=fields) key = si.key("foo") - assert not si._prefix + assert not si.prefix assert key == "foo" @@ -40,7 +40,7 @@ def test_search_index_create(client, redis_url): assert si.exists() assert "my_index" in convert_bytes(si.client.execute_command("FT._LIST")) - s1_2 = SearchIndex.from_existing("my_index", url=redis_url) + s1_2 = SearchIndex.from_existing("my_index", redis_url=redis_url) assert s1_2.info()["index_name"] == si.info()["index_name"] si.create(overwrite=False) @@ -133,7 +133,7 @@ async def test_async_search_index_load(async_client): def test_search_index_delete_nonexistent(client): si = SearchIndex("my_index", fields=fields) si.set_client(client) - with pytest.raises(redis.exceptions.ResponseError): + with pytest.raises(ValueError): si.delete() @@ -141,7 +141,7 @@ def test_search_index_delete_nonexistent(client): async def test_async_search_index_delete_nonexistent(async_client): asi = AsyncSearchIndex("my_index", fields=fields) asi.set_client(async_client) - with pytest.raises(redis.exceptions.ResponseError): + with pytest.raises(ValueError): await asi.delete() diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py new file mode 100644 index 00000000..8fc380fe --- /dev/null +++ b/tests/unit/test_query_types.py @@ -0,0 +1,56 @@ +import pytest +from redis.commands.search.document import Document +from redis.commands.search.query import Query +from redis.commands.search.result import Result + +from redisvl.index import process_results +from redisvl.query import CountQuery, FilterQuery, VectorQuery +from redisvl.query.filter import FilterExpression, Tag + +# Sample data for testing +sample_vector = [0.1, 0.2, 0.3, 0.4] + + +# Test Cases + + +def test_count_query(): + # Create a filter expression + filter_expression = Tag("brand") == "Nike" + count_query = CountQuery(filter_expression) + + # Check properties + assert isinstance(count_query.query, Query) + assert isinstance(count_query.params, dict) + assert count_query.params == {} + + fake_result = Result([2], "") + assert process_results(fake_result, count_query, "json") == 2 + + +def test_filter_query(): + # Create a filter expression + filter_expression = Tag("brand") == "Nike" + return_fields = ["brand", "price"] + filter_query = FilterQuery(return_fields, filter_expression, 10) + + # Check properties + assert filter_query._return_fields == return_fields + assert filter_query._num_results == 10 + assert filter_query.get_filter() == filter_expression + assert isinstance(filter_query.query, Query) + assert isinstance(filter_query.params, dict) + assert filter_query.params == {} + + +def test_vector_query(): + # Create a vector query + vector_query = VectorQuery(sample_vector, "vector_field", ["field1", "field2"]) + + # Check properties + assert vector_query._vector == sample_vector + assert vector_query._field == "vector_field" + assert "field1" in vector_query._return_fields + assert isinstance(vector_query.query, Query) + assert isinstance(vector_query.params, dict) + assert vector_query.params != {} diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 560fc08a..2e71a769 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -1,3 +1,5 @@ +import pathlib + import pytest from pydantic import ValidationError from redis.commands.search.field import ( @@ -12,15 +14,21 @@ FlatVectorField, GeoFieldSchema, HNSWVectorField, - MetadataSchemaGenerator, + IndexModel, 
NumericFieldSchema, + SchemaGenerator, SchemaModel, + StorageType, TagFieldSchema, TextFieldSchema, read_schema, ) +def get_base_path(): + return pathlib.Path(__file__).parent.resolve() + + # Utility functions to create schema instances with default values def create_text_field_schema(**kwargs): defaults = {"name": "example_textfield", "sortable": False, "weight": 1.0} @@ -143,15 +151,43 @@ def test_flat_vector_field_block_size_not_set(): assert "INITIAL_CAP" not in field_exported.args -# Test for schema model validation -def test_schema_model_validation_success(): - valid_index = {"name": "test_index", "storage_type": "hash"} - valid_fields = {"text": [create_text_field_schema()]} - schema_model = SchemaModel(index=valid_index, fields=valid_fields) +# Tests for IndexModel + + +def test_index_model_defaults(): + index = IndexModel(name="test_index") + assert index.name == "test_index" + assert index.prefix == "rvl" + assert index.key_separator == ":" + assert index.storage_type == StorageType.HASH + + +def test_index_model_custom_settings(): + index = IndexModel( + name="test_index", prefix="custom", key_separator="_", storage_type="json" + ) + assert index.name == "test_index" + assert index.prefix == "custom" + assert index.key_separator == "_" + assert index.storage_type == StorageType.JSON + + +def test_index_model_validation_errors(): + # Missing required field + with pytest.raises(ValueError): + IndexModel() + + # Invalid type + with pytest.raises(ValidationError): + IndexModel(name="test_index", prefix=None) + + # Invalid type + with pytest.raises(ValidationError): + IndexModel(name="test_index", key_separator=None) - assert schema_model.index.name == "test_index" - assert schema_model.index.storage_type == "hash" - assert len(schema_model.fields.text) == 1 + # Invalid type + with pytest.raises(ValidationError): + IndexModel(name="test_index", storage_type=None) def test_schema_model_validation_failures(): @@ -165,6 +201,20 @@ def test_schema_model_validation_failures(): SchemaModel(index={}, fields={}) +def test_read_hash_schema(): + hash_schema = read_schema( + str(get_base_path().joinpath("../sample_hash_schema.yaml")) + ) + assert hash_schema.index.name == "hash-test" + + +def test_read_json_schema(): + json_schema = read_schema( + str(get_base_path().joinpath("../sample_json_schema.yaml")) + ) + assert json_schema.index.name == "json-test" + + def test_read_schema_file_not_found(): with pytest.raises(FileNotFoundError): read_schema("non_existent_file.yaml") @@ -173,7 +223,7 @@ def test_read_schema_file_not_found(): # Fixture for the generator instance @pytest.fixture def schema_generator(): - return MetadataSchemaGenerator() + return SchemaGenerator() # Test cases for _test_numeric diff --git a/tests/unit/test_storage.py b/tests/unit/test_storage.py new file mode 100644 index 00000000..d0da9dbc --- /dev/null +++ b/tests/unit/test_storage.py @@ -0,0 +1,80 @@ +import pytest + +from redisvl.storage import BaseStorage, HashStorage, JsonStorage + + +@pytest.fixture(params=[JsonStorage, HashStorage]) +def storage_instance(request): + StorageClass = request.param + instance = StorageClass(prefix="test", key_separator=":") + return instance + + +def test_key_formatting(storage_instance): + key = "1234" + generated_key = storage_instance._key(key, "", "") + assert generated_key == key, "The generated key does not match the expected format." + generated_key = storage_instance._key(key, "", ":") + assert generated_key == key, "The generated key does not match the expected format." 
+ generated_key = storage_instance._key(key, "test", ":") + assert ( + generated_key == f"test:{key}" + ), "The generated key does not match the expected format." + + +def test_create_key(storage_instance): + key_field = "id" + obj = {key_field: "1234"} + expected_key = ( + f"{storage_instance._prefix}{storage_instance._key_separator}{obj[key_field]}" + ) + generated_key = storage_instance._create_key(obj, key_field) + assert ( + generated_key == expected_key + ), "The generated key does not match the expected format." + + +def test_validate_success(storage_instance): + data = {"foo": "bar"} + try: + storage_instance._validate(data) + except Exception as e: + pytest.fail(f"_validate should not raise an exception here, but raised {e}") + + +def test_validate_failure(storage_instance): + data = "Some invalid data type" + with pytest.raises(TypeError): + storage_instance._validate(data) + data = 12345 + with pytest.raises(TypeError): + storage_instance._validate(data) + + +def test_preprocess(storage_instance): + data = {"key": "value"} + preprocessed_data = storage_instance._preprocess(preprocess=None, obj=data) + assert preprocessed_data == data + + def fn(d): + d["foo"] = "bar" + return d + + preprocessed_data = storage_instance._preprocess(fn, data) + assert "foo" in preprocessed_data + assert preprocessed_data["foo"] == "bar" + + +@pytest.mark.asyncio +async def test_preprocess(storage_instance): + data = {"key": "value"} + preprocessed_data = await storage_instance._apreprocess(preprocess=None, obj=data) + assert preprocessed_data == data + + async def fn(d): + d["foo"] = "bar" + return d + + preprocessed_data = await storage_instance._apreprocess(data, fn) + assert "foo" in preprocessed_data + assert preprocessed_data["foo"] == "bar"