moves extension field names into constants file (#225)
Field names in the cache, session, and router classes and in their corresponding schemas are hard-coded. Since these names must match between class and schema and must not be modified, they are moved to a shared constants file.
justin-cechmanek authored Sep 30, 2024
1 parent 951d630 commit 51e58aa
Showing 10 changed files with 191 additions and 131 deletions.
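The pattern the commit applies is single-sourcing: each Redis field name is defined once in redisvl/extensions/constants.py, and both the extension class and its index schema import that constant. A minimal, self-contained sketch of the idea (the names below are simplified stand-ins, not the actual redisvl modules):

```python
# Hypothetical single-file illustration of the shared-constants pattern;
# simplified stand-ins, not redisvl's actual modules.

PROMPT_FIELD_NAME: str = "prompt"  # defined once, imported everywhere


def schema_fields() -> list:
    """The index schema definition references the shared constant."""
    return [{"name": PROMPT_FIELD_NAME, "type": "text"}]


def entry_record(prompt: str) -> dict:
    """The class writing cache entries uses the same constant, so record
    keys always match the indexed field names."""
    return {PROMPT_FIELD_NAME: prompt}


if __name__ == "__main__":
    # The schema field and the record key can no longer drift apart.
    assert schema_fields()[0]["name"] in entry_record("what is redis?")
```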
29 changes: 29 additions & 0 deletions redisvl/extensions/constants.py
@@ -0,0 +1,29 @@
"""
Constants used within the extension classes SemanticCache, BaseSessionManager,
StandardSessionManager,SemanticSessionManager and SemanticRouter.
These constants are also used within theses classes corresponding schema.
"""

# BaseSessionManager
ID_FIELD_NAME: str = "entry_id"
ROLE_FIELD_NAME: str = "role"
CONTENT_FIELD_NAME: str = "content"
TOOL_FIELD_NAME: str = "tool_call_id"
TIMESTAMP_FIELD_NAME: str = "timestamp"
SESSION_FIELD_NAME: str = "session_tag"

# SemanticSessionManager
SESSION_VECTOR_FIELD_NAME: str = "vector_field"

# SemanticCache
REDIS_KEY_FIELD_NAME: str = "key"
ENTRY_ID_FIELD_NAME: str = "entry_id"
PROMPT_FIELD_NAME: str = "prompt"
RESPONSE_FIELD_NAME: str = "response"
CACHE_VECTOR_FIELD_NAME: str = "prompt_vector"
INSERTED_AT_FIELD_NAME: str = "inserted_at"
UPDATED_AT_FIELD_NAME: str = "updated_at"
METADATA_FIELD_NAME: str = "metadata"

# SemanticRouter
ROUTE_VECTOR_FIELD_NAME: str = "vector"
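Downstream code can import these constants rather than re-typing the strings. A hedged usage sketch (assumes redisvl is installed; the redisvl.query import path for RangeQuery is inferred from how the modules changed below use it):

```python
# Build a range query against the semantic cache using the shared
# field-name constants instead of hard-coded strings.
from redisvl.extensions.constants import (
    CACHE_VECTOR_FIELD_NAME,
    PROMPT_FIELD_NAME,
    RESPONSE_FIELD_NAME,
)
from redisvl.query import RangeQuery  # assumed import path

query = RangeQuery(
    vector=[0.1, 0.2, 0.3],                     # toy embedding
    vector_field_name=CACHE_VECTOR_FIELD_NAME,  # "prompt_vector"
    return_fields=[PROMPT_FIELD_NAME, RESPONSE_FIELD_NAME],
    distance_threshold=0.2,
)
```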
17 changes: 12 additions & 5 deletions redisvl/extensions/llmcache/schema.py
@@ -2,6 +2,13 @@

from pydantic.v1 import BaseModel, Field, root_validator, validator

+from redisvl.extensions.constants import (
+    CACHE_VECTOR_FIELD_NAME,
+    INSERTED_AT_FIELD_NAME,
+    PROMPT_FIELD_NAME,
+    RESPONSE_FIELD_NAME,
+    UPDATED_AT_FIELD_NAME,
+)
from redisvl.redis.utils import array_to_buffer, hashify
from redisvl.schema import IndexSchema
from redisvl.utils.utils import current_timestamp, deserialize, serialize
@@ -110,12 +117,12 @@ def from_params(cls, name: str, prefix: str, vector_dims: int):
        return cls(
            index={"name": name, "prefix": prefix}, # type: ignore
            fields=[ # type: ignore
-                {"name": "prompt", "type": "text"},
-                {"name": "response", "type": "text"},
-                {"name": "inserted_at", "type": "numeric"},
-                {"name": "updated_at", "type": "numeric"},
+                {"name": PROMPT_FIELD_NAME, "type": "text"},
+                {"name": RESPONSE_FIELD_NAME, "type": "text"},
+                {"name": INSERTED_AT_FIELD_NAME, "type": "numeric"},
+                {"name": UPDATED_AT_FIELD_NAME, "type": "numeric"},
                {
-                    "name": "prompt_vector",
+                    "name": CACHE_VECTOR_FIELD_NAME,
                    "type": "vector",
                    "attrs": {
                        "dims": vector_dims,
65 changes: 30 additions & 35 deletions redisvl/extensions/llmcache/semantic.py
@@ -3,6 +3,16 @@

from redis import Redis

+from redisvl.extensions.constants import (
+    CACHE_VECTOR_FIELD_NAME,
+    ENTRY_ID_FIELD_NAME,
+    INSERTED_AT_FIELD_NAME,
+    METADATA_FIELD_NAME,
+    PROMPT_FIELD_NAME,
+    REDIS_KEY_FIELD_NAME,
+    RESPONSE_FIELD_NAME,
+    UPDATED_AT_FIELD_NAME,
+)
from redisvl.extensions.llmcache.base import BaseLLMCache
from redisvl.extensions.llmcache.schema import (
    CacheEntry,
@@ -19,15 +29,6 @@
class SemanticCache(BaseLLMCache):
    """Semantic Cache for Large Language Models."""

-    redis_key_field_name: str = "key"
-    entry_id_field_name: str = "entry_id"
-    prompt_field_name: str = "prompt"
-    response_field_name: str = "response"
-    vector_field_name: str = "prompt_vector"
-    inserted_at_field_name: str = "inserted_at"
-    updated_at_field_name: str = "updated_at"
-    metadata_field_name: str = "metadata"
-
    _index: SearchIndex
    _aindex: Optional[AsyncSearchIndex] = None

@@ -94,12 +95,12 @@ def __init__(
        # Process fields and other settings
        self.set_threshold(distance_threshold)
        self.return_fields = [
-            self.entry_id_field_name,
-            self.prompt_field_name,
-            self.response_field_name,
-            self.inserted_at_field_name,
-            self.updated_at_field_name,
-            self.metadata_field_name,
+            ENTRY_ID_FIELD_NAME,
+            PROMPT_FIELD_NAME,
+            RESPONSE_FIELD_NAME,
+            INSERTED_AT_FIELD_NAME,
+            UPDATED_AT_FIELD_NAME,
+            METADATA_FIELD_NAME,
        ]

        # Create semantic cache schema and index
@@ -133,7 +134,7 @@ def __init__(

        validate_vector_dims(
            vectorizer.dims,
-            self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore
+            self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore
        )
        self._vectorizer = vectorizer

@@ -145,9 +146,7 @@ def _modify_schema(
        """Modify the base cache schema using the provided filterable fields"""

        if filterable_fields is not None:
-            protected_field_names = set(
-                self.return_fields + [self.redis_key_field_name]
-            )
+            protected_field_names = set(self.return_fields + [REDIS_KEY_FIELD_NAME])
            for filter_field in filterable_fields:
                field_name = filter_field["name"]
                if field_name in protected_field_names:
@@ -300,7 +299,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
    def _check_vector_dims(self, vector: List[float]):
        """Checks the size of the provided vector and raises an error if it
        doesn't match the search index vector dimensions."""
-        schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore
+        schema_vector_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims # type: ignore
        validate_vector_dims(len(vector), schema_vector_dims)

    def check(
@@ -363,7 +362,7 @@ def check(

        query = RangeQuery(
            vector=vector,
-            vector_field_name=self.vector_field_name,
+            vector_field_name=CACHE_VECTOR_FIELD_NAME,
            return_fields=self.return_fields,
            distance_threshold=distance_threshold,
            num_results=num_results,
@@ -444,7 +443,7 @@ async def acheck(

        query = RangeQuery(
            vector=vector,
-            vector_field_name=self.vector_field_name,
+            vector_field_name=CACHE_VECTOR_FIELD_NAME,
            return_fields=self.return_fields,
            distance_threshold=distance_threshold,
            num_results=num_results,
@@ -479,7 +478,7 @@ def _process_cache_results(
            cache_hit_dict = {
                k: v for k, v in cache_hit_dict.items() if k in return_fields
            }
-            cache_hit_dict[self.redis_key_field_name] = redis_key
+            cache_hit_dict[REDIS_KEY_FIELD_NAME] = redis_key
            cache_hits.append(cache_hit_dict)
        return redis_keys, cache_hits

@@ -541,7 +540,7 @@ def store(
        keys = self._index.load(
            data=[cache_entry.to_dict()],
            ttl=ttl,
-            id_field=self.entry_id_field_name,
+            id_field=ENTRY_ID_FIELD_NAME,
        )
        return keys[0]

@@ -605,7 +604,7 @@ async def astore(
        keys = await aindex.load(
            data=[cache_entry.to_dict()],
            ttl=ttl,
-            id_field=self.entry_id_field_name,
+            id_field=ENTRY_ID_FIELD_NAME,
        )
        return keys[0]

@@ -629,21 +628,19 @@ def update(self, key: str, **kwargs) -> None:
        for k, v in kwargs.items():

            # Make sure the item is in the index schema
-            if k not in set(
-                self._index.schema.field_names + [self.metadata_field_name]
-            ):
+            if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
                raise ValueError(f"{k} is not a valid field within the cache entry")

            # Check for metadata and deserialize
-            if k == self.metadata_field_name:
+            if k == METADATA_FIELD_NAME:
                if isinstance(v, dict):
                    kwargs[k] = serialize(v)
                else:
                    raise TypeError(
                        "If specified, cached metadata must be a dictionary."
                    )

-        kwargs.update({self.updated_at_field_name: current_timestamp()})
+        kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})

        self._index.client.hset(key, mapping=kwargs) # type: ignore
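For context, a hedged sketch of how the update() method above is typically called (assumes a local Redis and redisvl's default vectorizer; the prompt, response, and metadata values are invented):

```python
# Hypothetical SemanticCache.update() usage: update() validates each kwarg
# against the index schema, serializes metadata dicts, refreshes the
# updated_at timestamp, and writes the hash back to Redis.
from redisvl.extensions.llmcache import SemanticCache

cache = SemanticCache(name="llmcache", redis_url="redis://localhost:6379")
key = cache.store(prompt="What is Redis?", response="An in-memory data store.")
cache.update(key, metadata={"reviewed": True})
```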

@@ -674,21 +671,19 @@ async def aupdate(self, key: str, **kwargs) -> None:
        for k, v in kwargs.items():

            # Make sure the item is in the index schema
-            if k not in set(
-                self._index.schema.field_names + [self.metadata_field_name]
-            ):
+            if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
                raise ValueError(f"{k} is not a valid field within the cache entry")

            # Check for metadata and deserialize
-            if k == self.metadata_field_name:
+            if k == METADATA_FIELD_NAME:
                if isinstance(v, dict):
                    kwargs[k] = serialize(v)
                else:
                    raise TypeError(
                        "If specified, cached metadata must be a dictionary."
                    )

-        kwargs.update({self.updated_at_field_name: current_timestamp()})
+        kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})

        await aindex.load(data=[kwargs], keys=[key])

3 changes: 2 additions & 1 deletion redisvl/extensions/router/schema.py
@@ -3,6 +3,7 @@

from pydantic.v1 import BaseModel, Field, validator

+from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redisvl.schema import IndexInfo, IndexSchema


@@ -104,7 +105,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema"
                {"name": "route_name", "type": "tag"},
                {"name": "reference", "type": "text"},
                {
-                    "name": "vector",
+                    "name": ROUTE_VECTOR_FIELD_NAME,
                    "type": "vector",
                    "attrs": {
                        "algorithm": "flat",
5 changes: 3 additions & 2 deletions redisvl/extensions/router/semantic.py
@@ -8,6 +8,7 @@
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
from redis.exceptions import ResponseError

+from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redisvl.extensions.router.schema import (
    DistanceAggregationMethod,
    Route,
@@ -226,7 +227,7 @@ def _classify_route(
        """Classify to a single route using a vector."""
        vector_range_query = RangeQuery(
            vector=vector,
-            vector_field_name="vector",
+            vector_field_name=ROUTE_VECTOR_FIELD_NAME,
            distance_threshold=distance_threshold,
            return_fields=["route_name"],
        )
@@ -278,7 +279,7 @@ def _classify_multi_route(
        """Classify to multiple routes, up to max_k (int), using a vector."""
        vector_range_query = RangeQuery(
            vector=vector,
-            vector_field_name="vector",
+            vector_field_name=ROUTE_VECTOR_FIELD_NAME,
            distance_threshold=distance_threshold,
            return_fields=["route_name"],
        )
17 changes: 8 additions & 9 deletions redisvl/extensions/session_manager/base_session.py
@@ -1,16 +1,15 @@
from typing import Any, Dict, List, Optional, Union

+from redisvl.extensions.constants import (
+    CONTENT_FIELD_NAME,
+    ROLE_FIELD_NAME,
+    TOOL_FIELD_NAME,
+)
from redisvl.extensions.session_manager.schema import ChatMessage
from redisvl.utils.utils import create_uuid


class BaseSessionManager:
-    id_field_name: str = "entry_id"
-    role_field_name: str = "role"
-    content_field_name: str = "content"
-    tool_field_name: str = "tool_call_id"
-    timestamp_field_name: str = "timestamp"
-    session_field_name: str = "session_tag"

    def __init__(
        self,
@@ -107,11 +106,11 @@ def _format_context(
                context.append(chat_message.content)
            else:
                chat_message_dict = {
-                    self.role_field_name: chat_message.role,
-                    self.content_field_name: chat_message.content,
+                    ROLE_FIELD_NAME: chat_message.role,
+                    CONTENT_FIELD_NAME: chat_message.content,
                }
                if chat_message.tool_call_id is not None:
-                    chat_message_dict[self.tool_field_name] = chat_message.tool_call_id
+                    chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id

                context.append(chat_message_dict) # type: ignore

47 changes: 30 additions & 17 deletions redisvl/extensions/session_manager/schema.py
@@ -2,6 +2,15 @@

from pydantic.v1 import BaseModel, Field, root_validator

+from redisvl.extensions.constants import (
+    CONTENT_FIELD_NAME,
+    ID_FIELD_NAME,
+    ROLE_FIELD_NAME,
+    SESSION_FIELD_NAME,
+    SESSION_VECTOR_FIELD_NAME,
+    TIMESTAMP_FIELD_NAME,
+    TOOL_FIELD_NAME,
+)
from redisvl.redis.utils import array_to_buffer
from redisvl.schema import IndexSchema
from redisvl.utils.utils import current_timestamp
@@ -31,18 +40,22 @@ class Config:
    @root_validator(pre=True)
    @classmethod
    def generate_id(cls, values):
-        if "timestamp" not in values:
-            values["timestamp"] = current_timestamp()
-        if "entry_id" not in values:
-            values["entry_id"] = f'{values["session_tag"]}:{values["timestamp"]}'
+        if TIMESTAMP_FIELD_NAME not in values:
+            values[TIMESTAMP_FIELD_NAME] = current_timestamp()
+        if ID_FIELD_NAME not in values:
+            values[ID_FIELD_NAME] = (
+                f"{values[SESSION_FIELD_NAME]}:{values[TIMESTAMP_FIELD_NAME]}"
+            )
        return values

    def to_dict(self) -> Dict:
        data = self.dict(exclude_none=True)

        # handle optional fields
-        if "vector_field" in data:
-            data["vector_field"] = array_to_buffer(data["vector_field"])
+        if SESSION_VECTOR_FIELD_NAME in data:
+            data[SESSION_VECTOR_FIELD_NAME] = array_to_buffer(
+                data[SESSION_VECTOR_FIELD_NAME]
+            )

        return data

@@ -55,11 +68,11 @@ def from_params(cls, name: str, prefix: str):
        return cls(
            index={"name": name, "prefix": prefix}, # type: ignore
            fields=[ # type: ignore
-                {"name": "role", "type": "tag"},
-                {"name": "content", "type": "text"},
-                {"name": "tool_call_id", "type": "tag"},
-                {"name": "timestamp", "type": "numeric"},
-                {"name": "session_tag", "type": "tag"},
+                {"name": ROLE_FIELD_NAME, "type": "tag"},
+                {"name": CONTENT_FIELD_NAME, "type": "text"},
+                {"name": TOOL_FIELD_NAME, "type": "tag"},
+                {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
+                {"name": SESSION_FIELD_NAME, "type": "tag"},
            ],
        )

@@ -72,13 +85,13 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
        return cls(
            index={"name": name, "prefix": prefix}, # type: ignore
            fields=[ # type: ignore
-                {"name": "role", "type": "tag"},
-                {"name": "content", "type": "text"},
-                {"name": "tool_call_id", "type": "tag"},
-                {"name": "timestamp", "type": "numeric"},
-                {"name": "session_tag", "type": "tag"},
+                {"name": ROLE_FIELD_NAME, "type": "tag"},
+                {"name": CONTENT_FIELD_NAME, "type": "text"},
+                {"name": TOOL_FIELD_NAME, "type": "tag"},
+                {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
+                {"name": SESSION_FIELD_NAME, "type": "tag"},
                {
-                    "name": "vector_field",
+                    "name": SESSION_VECTOR_FIELD_NAME,
                    "type": "vector",
                    "attrs": {
                        "dims": vectorizer_dims,
