Skip to content

Commit

Permalink
Use selected database as vector database
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 14, 2025
1 parent 5fb84d8 commit 00424d7
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 0 deletions.
130 changes: 130 additions & 0 deletions agixt/Agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
UserOAuth,
OAuthProvider,
TaskItem,
Memory,
)
from Providers import Providers
from Extensions import Extensions
Expand Down Expand Up @@ -1083,3 +1084,132 @@ def get_commands_prompt(self, conversation_id):
- **THE ASSISTANT CANNOT EXECUTE A COMMAND THAT IS NOT ON THE LIST OF EXAMPLES!**"""
return agent_commands
return ""

async def write_text_to_memory(
self,
user_input: str,
text: str,
conversation_id: str = None,
external_source: str = "user input",
):
session = get_session()
try:
# Remove old content for external sources
if external_source.startswith(("file", "http://", "https://")):
session.query(Memory).filter_by(
agent_id=self.agent_id,
conversation_id=conversation_id,
external_source=external_source,
).delete()

# Split into chunks and embed
chunks = await self.chunk_content(text=text, chunk_size=self.chunk_size)
if self.summarize_content:
chunks = [await self.summarize_text(chunk) for chunk in chunks]

for chunk in chunks:
embedding = self.embeddings(chunk)
memory = Memory(
agent_id=self.agent_id,
conversation_id=conversation_id,
embedding=embedding,
text=chunk,
external_source=external_source,
description=user_input,
additional_metadata=chunk,
)
session.add(memory)

session.commit()
return True
except Exception as e:
session.rollback()
logging.error(f"Error writing to memory: {e}")
return False
finally:
session.close()

async def get_memories(
self,
user_input: str,
limit: int = 5,
min_relevance_score: float = 0.0,
conversation_id: str = None,
):
session = get_session()
DATABASE_TYPE = getenv("DATABASE_TYPE")
try:
query_embedding = self.embeddings(user_input)

if DATABASE_TYPE == "sqlite":
# SQLite VSS query
results = session.execute(
"""
SELECT m.*, distance
FROM memory m
WHERE m.agent_id = :agent_id
AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
AND vss_memories MATCH :embedding
ORDER BY distance DESC
LIMIT :limit
""",
{
"agent_id": self.agent_id,
"conversation_id": conversation_id,
"embedding": str(query_embedding.tolist()),
"limit": limit,
},
)
else:
# PostgreSQL vector query
results = session.execute(
"""
SELECT m.*,
1 - (m.embedding <=> :embedding::vector) as similarity
FROM memory m
WHERE m.agent_id = :agent_id
AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
ORDER BY m.embedding <=> :embedding::vector
LIMIT :limit
""",
{
"agent_id": self.agent_id,
"conversation_id": conversation_id,
"embedding": query_embedding.tolist(),
"limit": limit,
},
)

memories = []
for row in results:
score = row.distance if DATABASE_TYPE == "sqlite" else row.similarity
if score >= min_relevance_score:
memories.append(
{
"text": row.text,
"external_source_name": row.external_source,
"description": row.description,
"additional_metadata": row.additional_metadata,
"relevance_score": score,
}
)

return memories
finally:
session.close()

async def wipe_memory(self, conversation_id: str = None):
session = get_session()
try:
query = session.query(Memory).filter_by(agent_id=self.agent_id)
if conversation_id:
query = query.filter_by(conversation_id=conversation_id)
query.delete()
session.commit()
return True
except Exception as e:
session.rollback()
logging.error(f"Error wiping memory: {e}")
return False
finally:
session.close()
100 changes: 100 additions & 0 deletions agixt/DB.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
ForeignKey,
DateTime,
Boolean,
DDL,
event,
func,
)
from sqlalchemy.orm import sessionmaker, relationship, declarative_base
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.types import TypeDecorator, VARCHAR
from cryptography.fernet import Fernet
from Globals import getenv
import numpy as np

logging.basicConfig(
level=getenv("LOG_LEVEL"),
Expand Down Expand Up @@ -714,6 +718,102 @@ class Prompt(Base):
arguments = relationship("Argument", backref="prompt", cascade="all, delete-orphan")


class Memory(Base):
__tablename__ = "memory"
id = Column(
UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String,
primary_key=True,
default=get_new_id if DATABASE_TYPE == "sqlite" else uuid.uuid4,
)
agent_id = Column(
UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String,
ForeignKey("agent.id"),
nullable=False,
)
conversation_id = Column(
UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String,
ForeignKey("conversation.id"),
nullable=True,
) # Null for core memories
embedding = Column(Vector, nullable=False)
text = Column(Text, nullable=False)
external_source = Column(String, default="user input")
description = Column(Text)
timestamp = Column(DateTime, server_default=func.now())
additional_metadata = Column(Text)

agent = relationship("Agent", backref="memories")
conversation = relationship("Conversation", backref="memories")


class Vector(TypeDecorator):
impl = VARCHAR

def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
from sqlalchemy.dialects.postgresql import FLOAT
from sqlalchemy.sql.sqltypes import ARRAY

return dialect.type_descriptor(ARRAY(FLOAT))
return dialect.type_descriptor(VARCHAR)

def process_bind_param(self, value, dialect):
if dialect.name == "postgresql":
if isinstance(value, np.ndarray):
return value.tolist()
return value
# SQLite needs string representation
if value is not None:
if isinstance(value, np.ndarray):
return f'[{",".join(map(str, value.flatten()))}]'
elif isinstance(value, list):
return f'[{",".join(map(str, value))}]'
return value

def process_result_value(self, value, dialect):
if dialect.name == "postgresql":
return np.array(value) if value else None
if value is not None:
return np.array(eval(value))
return None


@event.listens_for(Memory.__table__, "after_create")
def create_vector_store(target, connection, **kw):
if DATABASE_TYPE == "sqlite":
connection.execute(
DDL(
"""
CREATE VIRTUAL TABLE IF NOT EXISTS vss_memories
USING vss0(embedding(1536));
"""
)
)
else:
# For PostgreSQL, we need to create the vector extension and then convert our ARRAY column
connection.execute(DDL("CREATE EXTENSION IF NOT EXISTS vector;"))
connection.execute(
DDL(
"""
ALTER TABLE memory
ALTER COLUMN embedding
TYPE vector(1536)
USING embedding::vector(1536);
"""
)
)
connection.execute(
DDL(
"""
CREATE INDEX IF NOT EXISTS memory_embedding_idx
ON memory
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);
"""
)
)


def setup_default_roles():
with get_session() as db:
default_roles = [
Expand Down

0 comments on commit 00424d7

Please sign in to comment.