Skip to content

Commit

Permalink
add validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 15, 2025
1 parent d5cfac3 commit 75e2af6
Showing 1 changed file with 74 additions and 92 deletions.
166 changes: 74 additions & 92 deletions agixt/Memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,59 +730,52 @@ async def get_memories_data(

try:
if DATABASE_TYPE == "postgresql":
# First convert the embedding to string format for binding
embedding_str = f"[{','.join(map(str, query_embedding))}]"

stmt = text(
"""
WITH vector_matches AS (
SELECT
m.*,
1 - (m.embedding <=> %(embedding)s::vector) as similarity
FROM memory m
WHERE m.agent_id = %(agent_id)s::uuid
AND (
CASE
WHEN %(conversation_id)s IS NULL THEN m.conversation_id IS NULL
WHEN %(conversation_id)s ~ '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' THEN m.conversation_id = %(conversation_id)s::uuid
ELSE m.conversation_id IS NULL
END
try:
stmt = text(
"""
WITH vector_matches AS (
SELECT
m.*,
1 - (m.embedding <=> :embedding::vector) as similarity
FROM memory m
WHERE m.agent_id = :agent_id
AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
ORDER BY similarity DESC
LIMIT :limit
)
ORDER BY similarity DESC
LIMIT %(limit)s
SELECT
text,
external_source,
description,
additional_metadata,
timestamp,
similarity
FROM vector_matches
WHERE similarity >= :min_score;
"""
)
SELECT
text,
external_source,
description,
additional_metadata,
timestamp,
similarity
FROM vector_matches
WHERE similarity >= %(min_score)s;
"""
)

try:
# Convert embedding to string representation
embedding_str = f"[{','.join(map(str, query_embedding))}]"

results = session.execute(
stmt,
{
"embedding": embedding_str,
"agent_id": str(self.agent_id),
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"agent_id": self.agent_id,
"conversation_id": conversation_id,
"limit": limit,
"min_score": min_relevance_score,
},
).fetchall()

except Exception as e:
logging.warning(
f"Search failed, falling back to basic search: {e}"
f"Vector search failed, falling back to basic search: {e}"
)
session.rollback() # Important: rollback the failed transaction
session.rollback()

# Fall back to basic query
# Simpler fallback query
basic_stmt = text(
"""
SELECT
Expand All @@ -792,20 +785,18 @@ async def get_memories_data(
additional_metadata,
timestamp,
0.5 as similarity
FROM memory
WHERE agent_id = %(agent_id)s::uuid
AND (conversation_id IS NULL OR conversation_id = %(conversation_id)s::uuid)
LIMIT %(limit)s
FROM memory m
WHERE m.agent_id = :agent_id
AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
LIMIT :limit
"""
)

results = session.execute(
basic_stmt,
{
"agent_id": str(self.agent_id),
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"agent_id": self.agent_id,
"conversation_id": conversation_id,
"limit": limit,
},
).fetchall()
Expand Down Expand Up @@ -964,59 +955,52 @@ async def get_memories(

try:
if DATABASE_TYPE == "postgresql":
# First convert the embedding to string format for binding
embedding_str = f"[{','.join(map(str, query_embedding))}]"

stmt = text(
"""
WITH vector_matches AS (
SELECT
m.*,
1 - (m.embedding <=> %(embedding)s::vector) as similarity
FROM memory m
WHERE m.agent_id = %(agent_id)s::uuid
AND (
CASE
WHEN %(conversation_id)s IS NULL THEN m.conversation_id IS NULL
WHEN %(conversation_id)s ~ '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' THEN m.conversation_id = %(conversation_id)s::uuid
ELSE m.conversation_id IS NULL
END
try:
stmt = text(
"""
WITH vector_matches AS (
SELECT
m.*,
1 - (m.embedding <=> :embedding::vector) as similarity
FROM memory m
WHERE m.agent_id = :agent_id
AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
ORDER BY similarity DESC
LIMIT :limit
)
ORDER BY similarity DESC
LIMIT %(limit)s
SELECT
text,
external_source,
description,
additional_metadata,
timestamp,
similarity
FROM vector_matches
WHERE similarity >= :min_score;
"""
)
SELECT
text,
external_source,
description,
additional_metadata,
timestamp,
similarity
FROM vector_matches
WHERE similarity >= %(min_score)s;
"""
)

try:
# Convert embedding to string representation
embedding_str = f"[{','.join(map(str, query_embedding))}]"

results = session.execute(
stmt,
{
"embedding": embedding_str,
"agent_id": str(self.agent_id),
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"agent_id": self.agent_id,
"conversation_id": conversation_id,
"limit": limit,
"min_score": min_relevance_score,
},
).fetchall()

except Exception as e:
logging.warning(
f"Search failed, falling back to basic search: {e}"
f"Vector search failed, falling back to basic search: {e}"
)
session.rollback() # Important: rollback the failed transaction
session.rollback()

# Fall back to basic query
# Simpler fallback query
basic_stmt = text(
"""
SELECT
Expand All @@ -1026,20 +1010,18 @@ async def get_memories(
additional_metadata,
timestamp,
0.5 as similarity
FROM memory
WHERE agent_id = %(agent_id)s::uuid
AND (conversation_id IS NULL OR conversation_id = %(conversation_id)s::uuid)
LIMIT %(limit)s
FROM memory m
WHERE m.agent_id = :agent_id
AND (m.conversation_id = :conversation_id OR m.conversation_id IS NULL)
LIMIT :limit
"""
)

results = session.execute(
basic_stmt,
{
"agent_id": str(self.agent_id),
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"agent_id": self.agent_id,
"conversation_id": conversation_id,
"limit": limit,
},
).fetchall()
Expand Down

0 comments on commit 75e2af6

Please sign in to comment.