
Commit

use memories
Josh-XT committed Jan 14, 2025
1 parent 9c1ec76 commit 2cfe266
Showing 2 changed files with 111 additions and 238 deletions.
129 changes: 0 additions & 129 deletions agixt/Agent.py
@@ -1100,132 +1100,3 @@ def get_commands_prompt(self, conversation_id):
- **THE ASSISTANT CANNOT EXECUTE A COMMAND THAT IS NOT ON THE LIST OF EXAMPLES!**"""
            return agent_commands
        return ""

    async def write_text_to_memory(
        self,
        user_input: str,
        text: str,
        conversation_id: str = None,
        external_source: str = "user input",
    ):
        session = get_session()
        try:
            # Remove old content for external sources
            if external_source.startswith(("file", "http://", "https://")):
                session.query(Memory).filter_by(
                    agent_id=self.agent_id,
                    conversation_id=conversation_id,
                    external_source=external_source,
                ).delete()

            # Split into chunks and embed
            chunks = await self.chunk_content(text=text, chunk_size=self.chunk_size)
            if self.summarize_content:
                chunks = [await self.summarize_text(chunk) for chunk in chunks]

            for chunk in chunks:
                embedding = self.embeddings(chunk)
                memory = Memory(
                    agent_id=self.agent_id,
                    conversation_id=conversation_id,
                    embedding=embedding,
                    text=chunk,
                    external_source=external_source,
                    description=user_input,
                    additional_metadata=chunk,
                )
                session.add(memory)

            session.commit()
            return True
        except Exception as e:
            session.rollback()
            logging.error(f"Error writing to memory: {e}")
            return False
        finally:
            session.close()

    async def get_memories(
        self,
        user_input: str,
        limit: int = 5,
        min_relevance_score: float = 0.0,
        conversation_id: str = None,
    ):
        session = get_session()
        DATABASE_TYPE = getenv("DATABASE_TYPE")
        try:
            query_embedding = self.embeddings(user_input)

            if DATABASE_TYPE == "sqlite":
                # SQLite VSS query
                results = session.execute(
                    """
                    SELECT m.*, distance
                    FROM memory m
                    WHERE m.agent_id = :agent_id
                    AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
                    AND vss_memories MATCH :embedding
                    ORDER BY distance DESC
                    LIMIT :limit
                    """,
                    {
                        "agent_id": self.agent_id,
                        "conversation_id": conversation_id,
                        "embedding": str(query_embedding.tolist()),
                        "limit": limit,
                    },
                )
            else:
                # PostgreSQL vector query
                results = session.execute(
                    """
                    SELECT m.*,
                        1 - (m.embedding <=> :embedding::vector) as similarity
                    FROM memory m
                    WHERE m.agent_id = :agent_id
                    AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
                    ORDER BY m.embedding <=> :embedding::vector
                    LIMIT :limit
                    """,
                    {
                        "agent_id": self.agent_id,
                        "conversation_id": conversation_id,
                        "embedding": query_embedding.tolist(),
                        "limit": limit,
                    },
                )

            memories = []
            for row in results:
                score = row.distance if DATABASE_TYPE == "sqlite" else row.similarity
                if score >= min_relevance_score:
                    memories.append(
                        {
                            "text": row.text,
                            "external_source_name": row.external_source,
                            "description": row.description,
                            "additional_metadata": row.additional_metadata,
                            "relevance_score": score,
                        }
                    )

            return memories
        finally:
            session.close()

    async def wipe_memory(self, conversation_id: str = None):
        session = get_session()
        try:
            query = session.query(Memory).filter_by(agent_id=self.agent_id)
            if conversation_id:
                query = query.filter_by(conversation_id=conversation_id)
            query.delete()
            session.commit()
            return True
        except Exception as e:
            session.rollback()
            logging.error(f"Error wiping memory: {e}")
            return False
        finally:
            session.close()
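
With these methods removed, Agent.py no longer talks to the memory tables directly; the equivalent functionality moves onto the Memories class in the next file. A minimal sketch of the consolidated call path follows — the Memories constructor arguments shown here are assumptions for illustration, not taken from this diff:

    # Hypothetical call site; constructor arguments are illustrative assumptions.
    from Memories import Memories

    memories = Memories(
        agent_name="XT",
        agent_config=agent_config,
        collection_number=conversation_id,  # "0" selects the agent's core memories
        user="user@example.com",
    )
    await memories.write_text_to_memory(
        user_input="What is AGiXT?",
        text="AGiXT is an agent automation platform.",
    )
    results = await memories.get_memories(
        user_input="Tell me about AGiXT",
        limit=5,
        min_relevance_score=0.3,
    )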
220 changes: 111 additions & 109 deletions agixt/Memories.py
@@ -2,15 +2,13 @@
import os
import asyncio
import sys
-import time
+from DB import Memory, get_session, DATABASE_TYPE
import spacy
import chromadb
from chromadb.config import Settings
from chromadb.api.types import QueryResult
from numpy import array, linalg, ndarray
-from hashlib import sha256
from Providers import Providers
-from datetime import datetime
from collections import Counter
from typing import List
from Globals import getenv, DEFAULT_USER
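
The new `from DB import Memory, get_session, DATABASE_TYPE` import marks the shift from ChromaDB collections to rows in a relational `memory` table. The real model lives in DB.py and is not shown in this diff; a hypothetical SQLAlchemy shape consistent with the fields the diff references:

    # Hypothetical model inferred from the fields used in this diff;
    # the actual definition in DB.py may differ (including the vector dimension).
    from sqlalchemy import Column, Integer, String, Text
    from sqlalchemy.orm import declarative_base
    from pgvector.sqlalchemy import Vector

    Base = declarative_base()

    class Memory(Base):
        __tablename__ = "memory"
        id = Column(Integer, primary_key=True)
        agent_id = Column(String, index=True)
        conversation_id = Column(String, nullable=True, index=True)  # NULL = core memory
        embedding = Column(Vector(1536))  # pgvector on PostgreSQL; VSS virtual table on SQLite
        text = Column(Text)
        external_source = Column(String, default="user input")
        description = Column(Text)
        additional_metadata = Column(Text)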
@@ -295,12 +293,21 @@ def __init__(
        self.summarize_content = summarize_content
        self.failures = 0

-    async def wipe_memory(self):
+    async def wipe_memory(self, conversation_id: str = None):
+        session = get_session()
        try:
-            self.chroma_client.delete_collection(name=self.collection_name)
+            query = session.query(Memory).filter_by(agent_id=self.agent_id)
+            if conversation_id:
+                query = query.filter_by(conversation_id=conversation_id)
+            query.delete()
+            session.commit()
            return True
-        except:
+        except Exception as e:
+            session.rollback()
+            logging.error(f"Error wiping memory: {e}")
            return False
+        finally:
+            session.close()
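
The added `conversation_id` parameter lets a wipe be scoped to a single conversation, while omitting it clears every memory row for the agent. A brief usage sketch (the `memories` instance is assumed to exist; the UUID is hypothetical):

    # Scoped wipe: deletes only rows tied to one conversation.
    await memories.wipe_memory(conversation_id="1b2c3d4e-conversation-uuid")
    # Unscoped wipe: deletes all Memory rows for this agent.
    await memories.wipe_memory()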

    async def export_collection_to_json(self):
        collection = await self.get_collection()
@@ -413,72 +420,45 @@ async def summarize_text(self, text: str) -> str:
    async def write_text_to_memory(
        self, user_input: str, text: str, external_source: str = "user input"
    ):
-        # Log the collection number and agent name
-        logging.info(f"Saving to collection name: {self.collection_name}")
-        collection = await self.get_collection()
-        if text:
        if not isinstance(text, str):
            text = str(text)
-        # Check for duplicates from external sources (files or URLs)
+        session = get_session()
+        chunks = await self.chunk_content(text=text, chunk_size=self.chunk_size)

+        try:
+            # Handle core memories vs conversation memories
+            conversation_id = (
+                None if self.collection_number == "0" else self.collection_number
+            )

+            # If replacing external source content, delete old entries
            if external_source.startswith(("file", "http://", "https://")):
-                try:
-                    # Get all external sources in this collection
-                    existing_sources = await self.get_external_data_sources()

-                    # Check if this source already exists in memory
-                    for source in existing_sources:
-                        if source == external_source:
-                            logging.info(
-                                f"Found existing content in memory from source: {external_source}"
-                            )
-                            # Delete existing memories for this source
-                            await self.delete_memories_from_external_source(
-                                external_source
-                            )
-                            logging.info(
-                                f"Deleted existing content from source: {external_source}"
-                            )
-                            break
-                except Exception as e:
-                    logging.warning(f"Error checking for existing content: {e}")
-        if self.summarize_content:
-            text = await self.summarize_text(text=text)
-        chunks = await self.chunk_content(text=text, chunk_size=self.chunk_size)
+                session.query(Memory).filter_by(
+                    agent_id=self.agent_id,
+                    conversation_id=conversation_id,
+                    external_source=external_source,
+                ).delete()

            for chunk in chunks:
-                metadata = {
-                    "timestamp": datetime.now().isoformat(),
-                    "is_reference": str(False),
-                    "external_source_name": external_source,
-                    "description": user_input,
-                    "additional_metadata": chunk,
-                    "id": sha256(
-                        (chunk + datetime.now().isoformat()).encode()
-                    ).hexdigest(),
-                }
-                try:
-                    collection.add(
-                        ids=metadata["id"],
-                        metadatas=metadata,
-                        documents=chunk,
-                    )
-                except:
-                    self.failures += 1
-                    for i in range(5):
-                        try:
-                            time.sleep(0.1)
-                            collection.add(
-                                ids=metadata["id"],
-                                metadatas=metadata,
-                                documents=chunk,
-                            )
-                            self.failures = 0
-                            break
-                        except:
-                            self.failures += 1
-                            if self.failures > 5:
-                                break
-                            continue
-        return True
+                embedding = self.embedding_provider.embeddings(chunk)
+                memory = Memory(
+                    agent_id=self.agent_id,
+                    conversation_id=conversation_id,
+                    embedding=embedding,
+                    text=chunk,
+                    external_source=external_source,
+                    description=user_input,
+                    additional_metadata=chunk,
+                )
+                session.add(memory)

+            session.commit()
+            return True

+        except Exception as e:
+            session.rollback()
+            logging.error(f"Error writing to memory: {e}")
+            return False
+        finally:
+            session.close()
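
One behavioral detail of the rewritten method: when `external_source` is a file path or URL, prior rows for that source (within the same conversation scope) are deleted before the new chunks are written, so re-ingesting a document replaces its memories rather than duplicating them. A hedged usage sketch (`memories` and `page_text` are assumed to exist):

    # Re-ingesting the same URL replaces, rather than duplicates, its stored chunks.
    await memories.write_text_to_memory(
        user_input="company documentation",          # stored as each chunk's description
        text=page_text,                              # chunked, then embedded per chunk
        external_source="https://example.com/docs",  # prior rows for this source are removed first
    )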

    async def get_memories_data(
        self,
@@ -526,49 +506,71 @@ async def get_memories(
        user_input: str,
        limit: int,
        min_relevance_score: float = 0.0,
-    ) -> List[str]:
-        # If this is a conversation ID, update the collection name
-        if len(self.collection_number) > 4:
-            self.collection_name = normalize_collection_name(
-                user=self.user,
-                agent_name=self.agent_name,
-                collection_id=self.collection_number,
-            )
+    ):
+        session = get_session()
+        try:
+            query_embedding = self.embedding_provider.embeddings(user_input)
+            conversation_id = (
+                None if self.collection_number == "0" else self.collection_number
+            )

-        logging.info(
-            f"Retrieving Memories from collection name: {self.collection_name}"
-        )
-        results = await self.get_memories_data(
-            user_input=user_input,
-            limit=limit,
-            min_relevance_score=min_relevance_score,
-        )
-        logging.info(f"{len(results)} user results found in {self.collection_name}")
-        if isinstance(results, str):
-            results = [results]
-        response = []
-        if results:
-            for result in results:
-                metadata = (
-                    result["additional_metadata"]
-                    if "additional_metadata" in result
-                    else ""
-                )
-                external_source = (
-                    result["external_source_name"]
-                    if "external_source_name" in result
-                    else None
-                )
+            if DATABASE_TYPE == "sqlite":
+                # SQLite VSS query with agent and conversation filtering
+                results = session.execute(
+                    """
+                    SELECT m.*, distance
+                    FROM memory m
+                    WHERE m.agent_id = :agent_id
+                    AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
+                    AND vss_memories MATCH :embedding
+                    ORDER BY distance DESC
+                    LIMIT :limit
+                    """,
+                    {
+                        "agent_id": self.agent_id,
+                        "conversation_id": conversation_id,
+                        "embedding": str(query_embedding.tolist()),
+                        "limit": limit,
+                    },
+                )
-                timestamp = (
-                    result["timestamp"]
-                    if "timestamp" in result
-                    else datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-                )
+            else:
+                # PostgreSQL vector query
+                results = session.execute(
+                    """
+                    SELECT m.*,
+                        1 - (m.embedding <=> :embedding::vector) as similarity
+                    FROM memory m
+                    WHERE m.agent_id = :agent_id
+                    AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
+                    ORDER BY m.embedding <=> :embedding::vector
+                    LIMIT :limit
+                    """,
+                    {
+                        "agent_id": self.agent_id,
+                        "conversation_id": conversation_id,
+                        "embedding": query_embedding.tolist(),
+                        "limit": limit,
+                    },
+                )
-                if external_source:
-                    metadata = f"Sourced from {external_source}:\nSourced on: {timestamp}\n{metadata}"
-                if metadata not in response and metadata != "":
-                    response.append(metadata)
-        return response

+            memories = []
+            for row in results:
+                score = row.distance if DATABASE_TYPE == "sqlite" else row.similarity
+                if score >= min_relevance_score:
+                    memories.append(
+                        {
+                            "text": row.text,
+                            "external_source_name": row.external_source,
+                            "description": row.description,
+                            "additional_metadata": row.additional_metadata,
+                            "relevance_score": score,
+                        }
+                    )
+            return memories
+        finally:
+            session.close()
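
A note on scoring: the PostgreSQL branch converts pgvector's cosine distance operator (`<=>`) into a similarity with `1 - distance`, so higher means more relevant, while the SQLite branch passes the raw VSS distance through as the score, where lower means closer. If a single "higher is better" scale is wanted for `min_relevance_score`, one illustrative adjustment (not part of this commit) would be:

    # Illustrative normalization only; assumes cosine distance d, similarity = 1 - d.
    if DATABASE_TYPE == "sqlite":
        score = 1 - row.distance  # convert distance (lower = closer) to similarity
    else:
        score = row.similarity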

    async def get_external_data_sources(self):
        """Get a list of all unique external source names from memory collection."""
