Skip to content

Commit

Permalink
updates to schema and index and llmcache
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Dec 12, 2023
1 parent 07163d8 commit 7f536c7
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 52 deletions.
26 changes: 13 additions & 13 deletions redisvl/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from redis.commands.search.indexDefinition import IndexDefinition

from redisvl.query.query import BaseQuery, CountQuery, FilterQuery
from redisvl.schema import Schema, StorageType
from redisvl.schema import IndexSchema, StorageType
from redisvl.storage import HashStorage, JsonStorage
from redisvl.utils.connection import get_async_redis_connection, get_redis_connection
from redisvl.utils.utils import (
Expand Down Expand Up @@ -161,7 +161,7 @@ class SearchIndex:

def __init__(
self,
schema: Schema,
schema: IndexSchema,
redis_url: Optional[str] = None,
connection_args: Dict[str, Any] = {},
**kwargs,
Expand All @@ -170,7 +170,7 @@ def __init__(
redis_url, connection_args, and other kwargs.
"""
# final validation on schema object
if not schema or not isinstance(schema, Schema):
if not schema or not isinstance(schema, IndexSchema):
raise ValueError("Must provide a valid schema object")

# establish Redis connection
Expand All @@ -182,7 +182,7 @@ def __init__(
self.schema = schema

self._storage = self._STORAGE_MAP[self.schema.storage_type](
self.schema.index_prefix, self.schema.key_separator
self.schema.prefix, self.schema.key_separator
)

def set_client(self, client: redis.Redis) -> None:
Expand Down Expand Up @@ -250,7 +250,7 @@ def from_yaml(
Returns:
SearchIndex: A RedisVL SearchIndex object.
"""
schema = Schema.from_yaml(schema_path)
schema = IndexSchema.from_yaml(schema_path)
return cls(schema=schema, connection_args=connection_args, **kwargs)

@classmethod
Expand Down Expand Up @@ -285,7 +285,7 @@ def from_dict(
Returns:
SearchIndex: A RedisVL SearchIndex object.
"""
schema = Schema.from_dict(**schema_dict)
schema = IndexSchema.from_dict(**schema_dict)
return cls(schema=schema, connection_args=connection_args, **kwargs)

def connect(self, redis_url: Optional[str] = None, **kwargs):
Expand Down Expand Up @@ -319,7 +319,7 @@ def key(self, key_value: str) -> str:
Returns:
str: The full Redis key including key prefix and value as a string.
"""
return self._storage._key(key_value, self.schema.index_prefix, self.schema.key_separator)
return self._storage._key(key_value, self.schema.prefix, self.schema.key_separator)

@check_modules_present("_redis_conn")
def create(self, overwrite: bool = False) -> None:
Expand All @@ -334,8 +334,8 @@ def create(self, overwrite: bool = False) -> None:
ValueError: If no fields are defined for the index.
"""
# Check that fields are defined.
index_fields = self.schema.index_fields
if not index_fields:
redis_fields = self.schema.redis_fields
if not redis_fields:
raise ValueError("No fields defined for index")
if not isinstance(overwrite, bool):
raise TypeError("overwrite must be of type bool")
Expand All @@ -349,7 +349,7 @@ def create(self, overwrite: bool = False) -> None:

# Create the index with the specified fields and settings.
self._redis_conn.sync.ft(self.name).create_index( # type: ignore
fields=index_fields,
fields=redis_fields,
definition=IndexDefinition(
prefix=[self.prefix], index_type=self._storage.type
),
Expand Down Expand Up @@ -487,8 +487,8 @@ async def acreate(self, overwrite: bool = False) -> None:
Raises:
RuntimeError: If the index already exists and 'overwrite' is False.
"""
index_fields = self.schema.index_fields
if not index_fields:
redis_fields = self.schema.redis_fields
if not redis_fields:
raise ValueError("No fields defined for index")
if not isinstance(overwrite, bool):
raise TypeError("overwrite must be of type bool")
Expand All @@ -502,7 +502,7 @@ async def acreate(self, overwrite: bool = False) -> None:

# Create Index with proper IndexType
await self._redis_conn.a.ft(self.name).create_index( # type: ignore
fields=index_fields,
fields=redis_fields,
definition=IndexDefinition(
prefix=[self.prefix], index_type=self._storage.type
),
Expand Down
61 changes: 40 additions & 21 deletions redisvl/llmcache/semantic.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import warnings
from typing import Any, Dict, List, Optional

from redis.commands.search.field import Field, VectorField
from typing import Any, Dict, List, Optional

from redisvl.index import SearchIndex
from redisvl.llmcache.base import BaseLLMCache
from redisvl.query import VectorQuery
from redisvl.utils.utils import array_to_buffer
from redisvl.schema import Schema
from redisvl.schema import IndexSchema, StorageType
from redisvl.vectorize.base import BaseVectorizer
from redisvl.vectorize.text import HFTextVectorizer


class LLMCacheSchema(Schema):
"""Schema for the LLMCache."""
class LLMCacheSchema(IndexSchema):
"""RedisVL index schema for the LLMCache."""
# User should not be able to change these for the default LLMCache
prompt_field_name: str = "prompt"
vector_field_name: str = "prompt_vector"
response_field_name: str = "response"
Expand All @@ -22,25 +22,36 @@ def __init__(
self,
name: str = "cache",
prefix: str = "llmcache",
key_separator: str = ":",
storage_type: str = "hash",
**data,
vector_dims: int = 768,
**kwargs,
):
super.__init__(**data)
# Construct the base index schema
super().__init__(
name=name,
prefix=prefix,
**kwargs
)
# other schema kwargs will get consumed here
# otherwise fall back to index schema defaults

# Add fields specific to the LLMCacheSchema
self.add_field("text", name=self.prompt_field_name)
self.add_field("response", name=self.vector_field_name)
self.add_field("text", name=self.response_field_name)
self.add_field("vector",
name=self.vector_field_name,
dims=768,
dims=vector_dims,
datatype="float32",
distance_metric="cosine",
algorithm="flat"
)

class Config:
# ignore extra fields passed in kwargs
# Ignore extra fields passed in kwargs
ignore_extra = True

@property
def vector_field(self) -> Dict[str, Any]:
    """Return the first entry stored under the "vector" key of ``self.fields``.

    NOTE(review): ``SemanticCache.set_vectorizer`` reads ``.dims`` as an
    attribute of this value, which suggests it is a field object rather
    than a plain dict -- confirm and adjust the return annotation.
    """
    return self.fields["vector"][0]

class SemanticCache(BaseLLMCache):
"""Semantic Cache for Large Language Models."""
Expand All @@ -55,7 +66,7 @@ def __init__(
"sentence-transformers/all-mpnet-base-v2"
),
redis_url: str = "redis://localhost:6379",
connection_args: Optional[dict] = None,
connection_args: Optional[dict] = {},
**kwargs,
):
"""Semantic Cache for Large Language Models.
Expand Down Expand Up @@ -100,23 +111,29 @@ def __init__(
stacklevel=2,
)

if not isinstance(name, str) or not isinstance(prefix, str):
raise ValueError("A valid index name and prefix must be provided.")

if name is None or prefix is None:
raise ValueError("Index name and prefix must be provided.")
self._schema = LLMCacheSchema(
name=name,
prefix=prefix,
vector_dims=vectorizer.dims,
**kwargs
)

self._schema = LLMCacheSchema(**kwargs)
# set other components
self.set_vectorizer(vectorizer)
self.set_ttl(ttl)
self.set_threshold(distance_threshold)

# build search index
self._index = SearchIndex(
schema=self._schema,
redis_url=redis_url,
connection_args = connection_args
connection_args=connection_args
)
self._index.create(overwrite=False)


@classmethod
def from_index(cls, index: SearchIndex, **kwargs):
"""DEPRECATED: Create a SemanticCache from a pre-existing SearchIndex.
Expand Down Expand Up @@ -177,11 +194,13 @@ def set_vectorizer(self, vectorizer: BaseVectorizer) -> None:
if not isinstance(vectorizer, BaseVectorizer):
raise TypeError("Must provide a valid redisvl.vectorizer class.")

if self._vector_field.get("dims") != vectorizer.dims:
schema_vector_dims = self._schema.vector_field.dims

if schema_vector_dims != vectorizer.dims:
raise ValueError(
"Invalid vector dimensions!"
"Invalid vector dimensions! "
f"Vectorizer has dims defined as {vectorizer.dims}",
f"Vector field has dims defined as {self._vector_field.get('dims')}"
f"Vector field has dims defined as {schema_vector_dims}"
)

self._vectorizer = vectorizer
Expand Down
2 changes: 1 addition & 1 deletion redisvl/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from redisvl.schema.schema import (
StorageType,
Schema,
IndexSchema,
)


Expand Down
39 changes: 22 additions & 17 deletions redisvl/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import yaml
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Union, Tuple, Optional
from typing import Any, Dict, List, Union, Tuple, Optional, Type

from pydantic import BaseModel, ValidationError

Expand Down Expand Up @@ -64,7 +64,7 @@ class IndexSchema(BaseModel):
}

@property
def index_fields(self) -> list:
def redis_fields(self) -> list:
"""Returns a list of index fields in the Redis database."""
redis_fields = []
for field_list in self.fields.values():
Expand All @@ -77,6 +77,10 @@ def add_fields(self, fields: Dict[str, List[Dict[str, Any]]]):
for field_data in field_list:
self.add_field(field_type, **field_data)

def _get_field_class(self, field_type: str) -> Optional[Type[BaseField]]:
    """Return the field class registered for the named field type.

    Args:
        field_type (str): Key to look up in ``self._FIELD_TYPE_MAP``.

    Returns:
        Optional[Type[BaseField]]: The value mapped to ``field_type``, or
            None when ``field_type`` is not present (assuming
            ``_FIELD_TYPE_MAP`` is a plain dict -- verify).
    """
    # .get() does not raise for an unknown type; callers must handle None.
    return self._FIELD_TYPE_MAP.get(field_type)

def _create_field_instance(
self,
field_name: str,
Expand Down Expand Up @@ -166,16 +170,25 @@ def generate_fields(
ignore_fields: List[str] = [],
field_args: Dict[str, Dict[str, Any]] = {}
) -> Dict[str, List[Dict[str, Any]]]:
"""_summary_
"""Generate metadata fields for an index schema by inferring types from
a sample of provided data. Fields can be ignored with ignore_fields
and customized with field_args. Error handling behavior can be
enforced with the strict flag.
Args:
data (Dict[str, Any]): _description_
strict (bool, optional): _description_. Defaults to False.
ignore_fields (List[str], optional): _description_. Defaults to [].
field_args (Dict[str, Dict[str, Any]], optional): _description_. Defaults to {}.
data (Dict[str, Any]): Sample data used to infer field types.
strict (bool, optional): If True, raises an error when a field type
can't be inferred. Defaults to False.
ignore_fields (List[str], optional): List of field names to ignore.
Defaults to [].
field_args (Dict[str, Dict[str, Any]], optional): Additional
arguments for each field. Defaults to {}.
Raises:
ValueError: If strict is True and a field type cannot be inferred.
Returns:
Dict[str, List[Dict[str, Any]]]: _description_
Dict[str, List[Dict[str, Any]]]: A dictionary of fields.
"""
fields = {}
for field_name, value in data.items():
Expand Down Expand Up @@ -261,7 +274,7 @@ def from_yaml(cls, file_path: str) -> "IndexSchema":
FileNotFoundError: If the YAML file does not exist.
"""
# Check file path
fp = cls._check_yaml_path(file_path)
fp = cls._check_yaml_path(cls, file_path)
if not fp.exists():
raise FileNotFoundError(f"Schema file {file_path} does not exist")

Expand Down Expand Up @@ -292,14 +305,6 @@ def _should_ignore_field(self, field_name: str, ignore_fields: List[str]) -> boo
return field_name in ignore_fields







import re
from typing import Any

class TypeInferrer:
"""
Infers the type of a field based on its value.
Expand Down
1 change: 1 addition & 0 deletions redisvl/vectorize/text/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, model: str, api_config: Optional[Dict] = None):
"OpenAI vectorizer requires the openai library. Please install with `pip install openai`"
)

# TODO: should read this from environment to prevent verbose UX
if not api_config or "api_key" not in api_config:
raise ValueError("OpenAI API key is required in api_config")

Expand Down

0 comments on commit 7f536c7

Please sign in to comment.