From 75e2af6f0ccb5c3c15a7ca27ec307938d4dd250a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 14 Jan 2025 22:50:21 -0500 Subject: [PATCH] add validation --- agixt/Memories.py | 166 +++++++++++++++++++++------------------------- 1 file changed, 74 insertions(+), 92 deletions(-) diff --git a/agixt/Memories.py b/agixt/Memories.py index 5ed5e906a6cb..fd5607340d78 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -730,59 +730,52 @@ async def get_memories_data( try: if DATABASE_TYPE == "postgresql": - # First convert the embedding to string format for binding - embedding_str = f"[{','.join(map(str, query_embedding))}]" - - stmt = text( - """ - WITH vector_matches AS ( - SELECT - m.*, - 1 - (m.embedding <=> %(embedding)s::vector) as similarity - FROM memory m - WHERE m.agent_id = %(agent_id)s::uuid - AND ( - CASE - WHEN %(conversation_id)s IS NULL THEN m.conversation_id IS NULL - WHEN %(conversation_id)s ~ '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' THEN m.conversation_id = %(conversation_id)s::uuid - ELSE m.conversation_id IS NULL - END + try: + stmt = text( + """ + WITH vector_matches AS ( + SELECT + m.*, + 1 - (m.embedding <=> :embedding::vector) as similarity + FROM memory m + WHERE m.agent_id = :agent_id + AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL) + ORDER BY similarity DESC + LIMIT :limit ) - ORDER BY similarity DESC - LIMIT %(limit)s + SELECT + text, + external_source, + description, + additional_metadata, + timestamp, + similarity + FROM vector_matches + WHERE similarity >= :min_score; + """ ) - SELECT - text, - external_source, - description, - additional_metadata, - timestamp, - similarity - FROM vector_matches - WHERE similarity >= %(min_score)s; - """ - ) - try: + # Convert embedding to string representation + embedding_str = f"[{','.join(map(str, query_embedding))}]" + results = session.execute( stmt, { "embedding": embedding_str, - "agent_id": str(self.agent_id), - "conversation_id": ( - str(conversation_id) if conversation_id else None - ), + "agent_id": self.agent_id, + "conversation_id": conversation_id, "limit": limit, "min_score": min_relevance_score, }, ).fetchall() + except Exception as e: logging.warning( - f"Search failed, falling back to basic search: {e}" + f"Vector search failed, falling back to basic search: {e}" ) - session.rollback() # Important: rollback the failed transaction + session.rollback() - # Fall back to basic query + # Simpler fallback query basic_stmt = text( """ SELECT @@ -792,20 +785,18 @@ async def get_memories_data( additional_metadata, timestamp, 0.5 as similarity - FROM memory - WHERE agent_id = %(agent_id)s::uuid - AND (conversation_id IS NULL OR conversation_id = %(conversation_id)s::uuid) - LIMIT %(limit)s + FROM memory m + WHERE m.agent_id = :agent_id + AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL) + LIMIT :limit """ ) results = session.execute( basic_stmt, { - "agent_id": str(self.agent_id), - "conversation_id": ( - str(conversation_id) if conversation_id else None - ), + "agent_id": self.agent_id, + "conversation_id": conversation_id, "limit": limit, }, ).fetchall() @@ -964,59 +955,52 @@ async def get_memories( try: if DATABASE_TYPE == "postgresql": - # First convert the embedding to string format for binding - embedding_str = f"[{','.join(map(str, query_embedding))}]" - - stmt = text( - """ - WITH vector_matches AS ( - SELECT - m.*, - 1 - (m.embedding <=> %(embedding)s::vector) as similarity - FROM memory m - WHERE m.agent_id = %(agent_id)s::uuid - AND ( - CASE - WHEN %(conversation_id)s IS NULL THEN m.conversation_id IS NULL - WHEN %(conversation_id)s ~ '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' THEN m.conversation_id = %(conversation_id)s::uuid - ELSE m.conversation_id IS NULL - END + try: + stmt = text( + """ + WITH vector_matches AS ( + SELECT + m.*, + 1 - (m.embedding <=> :embedding::vector) as similarity + FROM memory m + WHERE m.agent_id = :agent_id + AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL) + ORDER BY similarity DESC + LIMIT :limit ) - ORDER BY similarity DESC - LIMIT %(limit)s + SELECT + text, + external_source, + description, + additional_metadata, + timestamp, + similarity + FROM vector_matches + WHERE similarity >= :min_score; + """ ) - SELECT - text, - external_source, - description, - additional_metadata, - timestamp, - similarity - FROM vector_matches - WHERE similarity >= %(min_score)s; - """ - ) - try: + # Convert embedding to string representation + embedding_str = f"[{','.join(map(str, query_embedding))}]" + results = session.execute( stmt, { "embedding": embedding_str, - "agent_id": str(self.agent_id), - "conversation_id": ( - str(conversation_id) if conversation_id else None - ), + "agent_id": self.agent_id, + "conversation_id": conversation_id, "limit": limit, "min_score": min_relevance_score, }, ).fetchall() + except Exception as e: logging.warning( - f"Search failed, falling back to basic search: {e}" + f"Vector search failed, falling back to basic search: {e}" ) - session.rollback() # Important: rollback the failed transaction + session.rollback() - # Fall back to basic query + # Simpler fallback query basic_stmt = text( """ SELECT @@ -1026,20 +1010,18 @@ async def get_memories( additional_metadata, timestamp, 0.5 as similarity - FROM memory - WHERE agent_id = %(agent_id)s::uuid - AND (conversation_id IS NULL OR conversation_id = %(conversation_id)s::uuid) - LIMIT %(limit)s + FROM memory m + WHERE m.agent_id = :agent_id + AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL) + LIMIT :limit """ ) results = session.execute( basic_stmt, { - "agent_id": str(self.agent_id), - "conversation_id": ( - str(conversation_id) if conversation_id else None - ), + "agent_id": self.agent_id, + "conversation_id": conversation_id, "limit": limit, }, ).fetchall()