From cebc2ce9de3500c43264e4e47a7070f672e6135c Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Date: Fri, 25 Oct 2024 11:22:18 -0700 Subject: [PATCH] Feature/move logging (#1492) * move logging provider out * move logging provider to own directory, remove singleton * cleanup --- py/core/__init__.py | 5 +- py/core/base/__init__.py | 6 +- py/core/base/logging/__init__.py | 10 +- py/core/base/logging/r2r_logger.py | 836 +----------------- py/core/base/logging/run_manager.py | 4 +- py/core/base/pipeline/base_pipeline.py | 12 +- py/core/base/pipes/__init__.py | 4 +- py/core/base/pipes/base_pipe.py | 32 +- py/core/main/abstractions.py | 2 + py/core/main/assembly/builder.py | 17 +- py/core/main/assembly/factory.py | 49 +- py/core/main/config.py | 6 +- py/core/main/services/auth_service.py | 5 +- py/core/main/services/base.py | 5 +- py/core/main/services/ingestion_service.py | 12 +- py/core/main/services/kg_service.py | 5 +- py/core/main/services/management_service.py | 8 +- py/core/main/services/retrieval_service.py | 13 +- py/core/pipelines/graph_enrichment.py | 6 +- py/core/pipelines/rag_pipeline.py | 14 +- py/core/pipelines/search_pipeline.py | 14 +- py/core/pipes/abstractions/generator_pipe.py | 15 +- py/core/pipes/abstractions/search_pipe.py | 15 +- py/core/pipes/ingestion/embedding_pipe.py | 9 +- py/core/pipes/ingestion/parsing_pipe.py | 9 +- .../pipes/ingestion/vector_storage_pipe.py | 16 +- py/core/pipes/kg/clustering.py | 9 +- py/core/pipes/kg/community_summary.py | 9 +- py/core/pipes/kg/deduplication.py | 10 +- py/core/pipes/kg/deduplication_summary.py | 9 +- py/core/pipes/kg/entity_description.py | 9 +- py/core/pipes/kg/prompt_tuning.py | 9 +- py/core/pipes/kg/storage.py | 9 +- py/core/pipes/kg/triples_extraction.py | 9 +- py/core/pipes/retrieval/kg_search_pipe.py | 9 +- py/core/pipes/retrieval/multi_search.py | 3 +- .../pipes/retrieval/query_transform_pipe.py | 3 - py/core/pipes/retrieval/search_rag_pipe.py | 3 - py/core/pipes/retrieval/streaming_rag_pipe.py | 3 - py/core/pipes/retrieval/vector_search_pipe.py | 3 - py/core/providers/__init__.py | 3 + py/core/providers/database/collection.py | 2 +- py/core/providers/database/document.py | 2 +- py/core/providers/ingestion/r2r/base.py | 12 +- .../providers/ingestion/unstructured/base.py | 12 +- py/core/providers/logging/__init__.py | 3 + py/core/providers/logging/r2r_logging.py | 817 +++++++++++++++++ py/shared/abstractions/document.py | 6 +- py/tests/conftest.py | 6 +- .../core/pipelines/test_pipeline_logic.py | 5 +- .../database/relational/test_collection_db.py | 2 +- .../database/relational/test_document_db.py | 10 +- .../logging/test_logging_provider.py | 2 +- 53 files changed, 1028 insertions(+), 1080 deletions(-) create mode 100644 py/core/providers/logging/__init__.py create mode 100644 py/core/providers/logging/r2r_logging.py diff --git a/py/core/__init__.py b/py/core/__init__.py index 5e6baafcd..9d8917895 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -110,9 +110,7 @@ "LogFilterCriteria", "LogProcessor", # Logging Providers - "SqlitePersistentLoggingProvider", - "LoggingConfig", - "R2RLoggingProvider", + "PersistentLoggingConfig", # Run Manager "RunManager", "manage_run", @@ -125,7 +123,6 @@ ## PIPES "AsyncPipe", "AsyncState", - "PipeType", ## PROVIDERS # Base provider classes "AppConfig", diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py index 56ec03571..11a4e958c 100644 --- a/py/core/base/__init__.py +++ b/py/core/base/__init__.py @@ -82,10 +82,7 @@ 
"LogAnalyticsConfig", "LogFilterCriteria", "LogProcessor", - # Logging Providers - "SqlitePersistentLoggingProvider", - "LoggingConfig", - "R2RLoggingProvider", + "PersistentLoggingConfig", # Run Manager "RunManager", "manage_run", @@ -98,7 +95,6 @@ ## PIPES "AsyncPipe", "AsyncState", - "PipeType", ## PROVIDERS # Base provider classes "AppConfig", diff --git a/py/core/base/logging/__init__.py b/py/core/base/logging/__init__.py index fa62a04c0..6a01fc33c 100644 --- a/py/core/base/logging/__init__.py +++ b/py/core/base/logging/__init__.py @@ -6,11 +6,7 @@ LogFilterCriteria, LogProcessor, ) -from .r2r_logger import ( - LoggingConfig, - R2RLoggingProvider, - SqlitePersistentLoggingProvider, -) +from .r2r_logger import PersistentLoggingConfig from .run_manager import RunManager, manage_run __all__ = [ @@ -22,9 +18,7 @@ "LogFilterCriteria", "LogProcessor", # Logging Providers - "SqlitePersistentLoggingProvider", - "LoggingConfig", - "R2RLoggingProvider", + "PersistentLoggingConfig", # Run Manager "RunManager", "manage_run", diff --git a/py/core/base/logging/r2r_logger.py b/py/core/base/logging/r2r_logger.py index e79ac493d..c8ec382b8 100644 --- a/py/core/base/logging/r2r_logger.py +++ b/py/core/base/logging/r2r_logger.py @@ -1,7 +1,4 @@ -import json import logging -import os -import uuid from abc import abstractmethod from datetime import datetime from typing import Any, Optional, Tuple, Union @@ -24,7 +21,7 @@ class RunInfoLog(BaseModel): user_id: UUID -class LoggingConfig(ProviderConfig): +class PersistentLoggingConfig(ProviderConfig): provider: str = "local" log_table: str = "logs" log_info_table: str = "log_info" @@ -38,7 +35,7 @@ def supported_providers(self) -> list[str]: return ["local", "postgres"] -class RunLoggingProvider(Provider): +class PersistentLoggingProvider(Provider): @abstractmethod async def close(self): pass @@ -78,832 +75,3 @@ async def get_info_logs( user_ids: Optional[list[UUID]] = None, ) -> list[RunInfoLog]: pass - - -class SqlitePersistentLoggingProvider(RunLoggingProvider): - def __init__(self, config: LoggingConfig): - self.log_table = config.log_table - self.log_info_table = config.log_info_table - # TODO - Should we re-consider this naming convention? - self.project_name = os.getenv("R2R_PROJECT_NAME", "r2r_default") - self.logging_path = config.logging_path or os.getenv( - "LOCAL_DB_PATH", "local.sqlite" - ) - if not self.logging_path: - raise ValueError( - "Please set the environment variable LOCAL_DB_PATH." - ) - self.conn = None - try: - import aiosqlite - - self.aiosqlite = aiosqlite - except ImportError: - raise ImportError( - "Please install aiosqlite to use the SqlitePersistentLoggingProvider." 
- ) - - async def _init(self): - self.conn = await self.aiosqlite.connect(self.logging_path) - - await self.conn.execute( - f""" - CREATE TABLE IF NOT EXISTS {self.project_name}_{self.log_table} ( - timestamp DATETIME, - run_id TEXT, - key TEXT, - value TEXT - ) - """ - ) - await self.conn.execute( - f""" - CREATE TABLE IF NOT EXISTS {self.project_name}_{self.log_info_table} ( - timestamp DATETIME, - run_id TEXT UNIQUE, - run_type TEXT, - user_id TEXT - ) - """ - ) - await self.conn.executescript( - """ - CREATE TABLE IF NOT EXISTS conversations ( - id TEXT PRIMARY KEY, - created_at REAL - ); - - CREATE TABLE IF NOT EXISTS messages ( - id TEXT PRIMARY KEY, - conversation_id TEXT, - parent_id TEXT, - content TEXT, - created_at REAL, - metadata TEXT, - FOREIGN KEY (conversation_id) REFERENCES conversations(id), - FOREIGN KEY (parent_id) REFERENCES messages(id) - ); - - CREATE TABLE IF NOT EXISTS branches ( - id TEXT PRIMARY KEY, - conversation_id TEXT, - branch_point_id TEXT, - created_at REAL, - FOREIGN KEY (conversation_id) REFERENCES conversations(id), - FOREIGN KEY (branch_point_id) REFERENCES messages(id) - ); - - CREATE TABLE IF NOT EXISTS message_branches ( - message_id TEXT, - branch_id TEXT, - PRIMARY KEY (message_id, branch_id), - FOREIGN KEY (message_id) REFERENCES messages(id), - FOREIGN KEY (branch_id) REFERENCES branches(id) - ); - """ - ) - await self.conn.commit() - - async def __aenter__(self): - if self.conn is None: - await self._init() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - async def close(self): - if self.conn: - await self.conn.close() - self.conn = None - - async def log( - self, - run_id: UUID, - key: str, - value: str, - ): - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - await self.conn.execute( - f""" - INSERT INTO {self.project_name}_{self.log_table} (timestamp, run_id, key, value) - VALUES (datetime('now'), ?, ?, ?) - """, - (str(run_id), key, value), - ) - await self.conn.commit() - - async def info_log( - self, - run_id: UUID, - run_type: RunType, - user_id: UUID, - ): - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - await self.conn.execute( - f""" - INSERT INTO {self.project_name}_{self.log_info_table} (timestamp, run_id, run_type, user_id) - VALUES (datetime('now'), ?, ?, ?) - ON CONFLICT(run_id) DO UPDATE SET - timestamp = datetime('now'), - run_type = excluded.run_type, - user_id = excluded.user_id - """, - (str(run_id), run_type, str(user_id)), - ) - await self.conn.commit() - - async def get_info_logs( - self, - offset: int = 0, - limit: int = 100, - run_type_filter: Optional[RunType] = None, - user_ids: Optional[list[UUID]] = None, - ) -> list[RunInfoLog]: - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - cursor = await self.conn.cursor() - query = "SELECT run_id, run_type, timestamp, user_id" - query += f" FROM {self.project_name}_{self.log_info_table}" - conditions = [] - params = [] - if run_type_filter: - conditions.append("run_type = ?") - params.append(run_type_filter) - if user_ids: - conditions.append(f"user_id IN ({','.join(['?']*len(user_ids))})") - params.extend([str(user_id) for user_id in user_ids]) - if conditions: - query += " WHERE " + " AND ".join(conditions) - query += " ORDER BY timestamp DESC LIMIT ? OFFSET ?" 
- params.extend([limit, offset]) - await cursor.execute(query, params) - rows = await cursor.fetchall() - return [ - RunInfoLog( - run_id=UUID(row[0]), - run_type=row[1], - timestamp=datetime.fromisoformat(row[2]), - user_id=UUID(row[3]), - ) - for row in rows - ] - - async def create_conversation(self) -> str: - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - conversation_id = str(uuid.uuid4()) - created_at = datetime.utcnow().timestamp() - - await self.conn.execute( - "INSERT INTO conversations (id, created_at) VALUES (?, ?)", - (conversation_id, created_at), - ) - await self.conn.commit() - return conversation_id - - async def get_conversations_overview( - self, - conversation_ids: Optional[list[UUID]] = None, - offset: int = 0, - limit: int = -1, - ) -> dict[str, Union[list[dict], int]]: - """Get an overview of conversations, optionally filtered by conversation IDs, with pagination.""" - query = """ - WITH conversation_overview AS ( - SELECT c.id, c.created_at - FROM conversations c - {where_clause} - ), - counted_overview AS ( - SELECT *, COUNT(*) OVER() AS total_entries - FROM conversation_overview - ) - SELECT * FROM counted_overview - ORDER BY created_at DESC - LIMIT ? OFFSET ? - """ - - where_clause = ( - f"WHERE c.id IN ({','.join(['?' for _ in conversation_ids])})" - if conversation_ids - else "" - ) - query = query.format(where_clause=where_clause) - - params: list = [] - if conversation_ids: - params.extend(conversation_ids) - params.extend((limit if limit != -1 else -1, offset)) - - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - async with self.conn.execute(query, params) as cursor: - results = await cursor.fetchall() - - if not results: - logger.info("No conversations found.") - return {"results": [], "total_entries": 0} - - conversations = [ - { - "conversation_id": row[0], - "created_at": row[1], - } - for row in results - ] - - total_entries = results[0][-1] if results else 0 - - return {"results": conversations, "total_entries": total_entries} - - async def add_message( - self, - conversation_id: str, - content: Message, - parent_id: Optional[str] = None, - metadata: Optional[dict] = None, - ) -> str: - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - message_id = str(uuid.uuid4()) - created_at = datetime.utcnow().timestamp() - - await self.conn.execute( - "INSERT INTO messages (id, conversation_id, parent_id, content, created_at, metadata) VALUES (?, ?, ?, ?, ?, ?)", - ( - message_id, - conversation_id, - parent_id, - content.json(), - created_at, - json.dumps(metadata or {}), - ), - ) - - if parent_id is not None: - await self.conn.execute( - """ - INSERT INTO message_branches (message_id, branch_id) - SELECT ?, branch_id FROM message_branches WHERE message_id = ? - """, - (message_id, parent_id), - ) - else: - # For messages with no parent, use the most recent branch, or create a new one - async with self.conn.execute( - """ - SELECT id FROM branches - WHERE conversation_id = ? 
- ORDER BY created_at DESC - LIMIT 1 - """, - (conversation_id,), - ) as cursor: - row = await cursor.fetchone() - if row is not None: - branch_id = row[0] - else: - # Create a new branch if none exists - branch_id = str(uuid.uuid4()) - await self.conn.execute( - """ - INSERT INTO branches (id, conversation_id, branch_point_id) VALUES (?, ?, NULL) - """, - (branch_id, conversation_id), - ) - await self.conn.execute( - """ - INSERT INTO message_branches (message_id, branch_id) VALUES (?, ?) - """, - (message_id, branch_id), - ) - - await self.conn.commit() - return message_id - - async def edit_message( - self, message_id: str, new_content: str - ) -> Tuple[str, str]: - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - # Get the original message details - async with self.conn.execute( - "SELECT conversation_id, parent_id, content FROM messages WHERE id = ?", - (message_id,), - ) as cursor: - row = await cursor.fetchone() - if row is None: - raise ValueError(f"Message {message_id} not found") - conversation_id, parent_id, old_content_json = row - # Parse the old content to get the original Message object - old_message = Message.parse_raw(old_content_json) - - # Create a new Message object with the updated content - edited_message = Message( - role=old_message.role, - content=new_content, - name=old_message.name, - function_call=old_message.function_call, - tool_calls=old_message.tool_calls, - ) - - # Create a new branch - new_branch_id = str(uuid.uuid4()) - created_at = datetime.utcnow().timestamp() - await self.conn.execute( - "INSERT INTO branches (id, conversation_id, branch_point_id, created_at) VALUES (?, ?, ?, ?)", - (new_branch_id, conversation_id, message_id, created_at), - ) - - # Add the edited message with the same parent_id - new_message_id = str(uuid.uuid4()) - message_created_at = datetime.utcnow().timestamp() - await self.conn.execute( - "INSERT INTO messages (id, conversation_id, parent_id, content, created_at, metadata) VALUES (?, ?, ?, ?, ?, ?)", - ( - new_message_id, - conversation_id, - parent_id, - edited_message.json(), - message_created_at, - json.dumps({"edited": True}), - ), - ) - # Link the new message to the new branch - await self.conn.execute( - "INSERT INTO message_branches (message_id, branch_id) VALUES (?, ?)", - (new_message_id, new_branch_id), - ) - - # Link ancestor messages (excluding the original message) to the new branch - await self.conn.execute( - """ - WITH RECURSIVE ancestors(id) AS ( - SELECT parent_id FROM messages WHERE id = ? - UNION ALL - SELECT m.parent_id FROM messages m JOIN ancestors a ON m.id = a.id WHERE m.parent_id IS NOT NULL - ) - INSERT OR IGNORE INTO message_branches (message_id, branch_id) - SELECT id, ? FROM ancestors WHERE id IS NOT NULL - """, - (message_id, new_branch_id), - ) - - # Update the parent_id of the edited message's descendants in the new branch - await self.conn.execute( - """ - WITH RECURSIVE descendants(id) AS ( - SELECT id FROM messages WHERE parent_id = ? - UNION ALL - SELECT m.id FROM messages m JOIN descendants d ON m.parent_id = d.id - ) - UPDATE messages - SET parent_id = ? - WHERE id IN (SELECT id FROM descendants) - """, - (message_id, new_message_id), - ) - - await self.conn.commit() - return new_message_id, new_branch_id - - async def get_conversation( - self, conversation_id: str, branch_id: Optional[str] = None - ) -> Tuple[str, list[Message]]: - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." 
- ) - - if branch_id is None: - # Get the most recent branch by created_at timestamp - async with self.conn.execute( - """ - SELECT id FROM branches - WHERE conversation_id = ? - ORDER BY created_at DESC - LIMIT 1 - """, - (conversation_id,), - ) as cursor: - row = await cursor.fetchone() - branch_id = row[0] if row else None - - if branch_id is None: - return [] # No branches found for the conversation - - # Get all messages for this branch - async with self.conn.execute( - """ - WITH RECURSIVE branch_messages(id, content, parent_id, depth, created_at) AS ( - SELECT m.id, m.content, m.parent_id, 0, m.created_at - FROM messages m - JOIN message_branches mb ON m.id = mb.message_id - WHERE mb.branch_id = ? AND m.parent_id IS NULL - UNION - SELECT m.id, m.content, m.parent_id, bm.depth + 1, m.created_at - FROM messages m - JOIN message_branches mb ON m.id = mb.message_id - JOIN branch_messages bm ON m.parent_id = bm.id - WHERE mb.branch_id = ? - ) - SELECT id, content, parent_id FROM branch_messages - ORDER BY created_at ASC - """, - (branch_id, branch_id), - ) as cursor: - rows = await cursor.fetchall() - return [(row[0], Message.parse_raw(row[1])) for row in rows] - - async def get_branches_overview(self, conversation_id: str) -> list[dict]: - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - async with self.conn.execute( - """ - SELECT b.id, b.branch_point_id, m.content, b.created_at - FROM branches b - LEFT JOIN messages m ON b.branch_point_id = m.id - WHERE b.conversation_id = ? - ORDER BY b.created_at - """, - (conversation_id,), - ) as cursor: - rows = await cursor.fetchall() - return [ - { - "branch_id": row[0], - "branch_point_id": row[1], - "content": row[2], - "created_at": row[3], - } - for row in rows - ] - - async def get_next_branch(self, current_branch_id: str) -> Optional[str]: - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - async with self.conn.execute( - """ - SELECT id FROM branches - WHERE conversation_id = (SELECT conversation_id FROM branches WHERE id = ?) - AND created_at > (SELECT created_at FROM branches WHERE id = ?) - ORDER BY created_at - LIMIT 1 - """, - (current_branch_id, current_branch_id), - ) as cursor: - row = await cursor.fetchone() - return row[0] if row else None - - async def get_prev_branch(self, current_branch_id: str) -> Optional[str]: - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - async with self.conn.execute( - """ - SELECT id FROM branches - WHERE conversation_id = (SELECT conversation_id FROM branches WHERE id = ?) - AND created_at < (SELECT created_at FROM branches WHERE id = ?) - ORDER BY created_at DESC - LIMIT 1 - """, - (current_branch_id, current_branch_id), - ) as cursor: - row = await cursor.fetchone() - return row[0] if row else None - - async def branch_at_message(self, message_id: str) -> str: - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." 
- ) - - # Get the conversation_id of the message - async with self.conn.execute( - "SELECT conversation_id FROM messages WHERE id = ?", - (message_id,), - ) as cursor: - row = await cursor.fetchone() - if row is None: - raise ValueError(f"Message {message_id} not found") - conversation_id = row[0] - - # Check if the message is already a branch point - async with self.conn.execute( - "SELECT id FROM branches WHERE branch_point_id = ?", - (message_id,), - ) as cursor: - row = await cursor.fetchone() - if row is not None: - return row[0] # Return the existing branch ID - - # Create a new branch starting from message_id - new_branch_id = str(uuid.uuid4()) - await self.conn.execute( - "INSERT INTO branches (id, conversation_id, branch_point_id) VALUES (?, ?, ?)", - (new_branch_id, conversation_id, message_id), - ) - - # Link ancestor messages to the new branch - await self.conn.execute( - """ - WITH RECURSIVE ancestors(id) AS ( - SELECT id FROM messages WHERE id = ? - UNION ALL - SELECT m.parent_id FROM messages m JOIN ancestors a ON m.id = a.id WHERE m.parent_id IS NOT NULL - ) - INSERT OR IGNORE INTO message_branches (message_id, branch_id) - SELECT id, ? FROM ancestors - """, - (message_id, new_branch_id), - ) - - await self.conn.commit() - return new_branch_id - - async def delete_conversation(self, conversation_id: str): - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - # Begin a transaction - async with self.conn.execute("BEGIN TRANSACTION"): - # Delete all message branches associated with the conversation - await self.conn.execute( - "DELETE FROM message_branches WHERE message_id IN (SELECT id FROM messages WHERE conversation_id = ?)", - (conversation_id,), - ) - # Delete all branches associated with the conversation - await self.conn.execute( - "DELETE FROM branches WHERE conversation_id = ?", - (conversation_id,), - ) - # Delete all messages associated with the conversation - await self.conn.execute( - "DELETE FROM messages WHERE conversation_id = ?", - (conversation_id,), - ) - # Finally, delete the conversation itself - await self.conn.execute( - "DELETE FROM conversations WHERE id = ?", (conversation_id,) - ) - # Commit the transaction - await self.conn.commit() - - async def get_logs( - self, - run_ids: list[UUID], - limit_per_run: int = 10, - ) -> list: - if not run_ids: - raise ValueError("No run ids provided.") - if not self.conn: - raise ValueError( - "Initialize the connection pool before attempting to log." - ) - - cursor = await self.conn.cursor() - placeholders = ",".join(["?" 
for _ in run_ids]) - query = f""" - SELECT run_id, key, value, timestamp - FROM {self.project_name}_{self.log_table} - WHERE run_id IN ({placeholders}) - ORDER BY timestamp DESC - """ - - params = [str(run_id) for run_id in run_ids] - - await cursor.execute(query, params) - rows = await cursor.fetchall() - - # Post-process the results to limit per run_id and ensure only requested run_ids are included - result = [] - run_id_count = {str(run_id): 0 for run_id in run_ids} - for row in rows: - row_dict = dict(zip([d[0] for d in cursor.description], row)) - row_run_id = row_dict["run_id"] - if ( - row_run_id in run_id_count - and run_id_count[row_run_id] < limit_per_run - ): - row_dict["run_id"] = UUID(row_dict["run_id"]) - result.append(row_dict) - run_id_count[row_run_id] += 1 - return result - - -class HatchetLogger: - def __init__(self, context: Any): - self.context = context - - def _log(self, level: str, message: str, function: Optional[str] = None): - if function: - log_message = f"[{level}]: {function}: {message}" - else: - log_message = f"[{level}]: {message}" - self.context.log(log_message) - - def debug(self, message: str, function: Optional[str] = None): - self._log("DEBUG", message, function) - - def info(self, message: str, function: Optional[str] = None): - self._log("INFO", message, function) - - def warning(self, message: str, function: Optional[str] = None): - self._log("WARNING", message, function) - - def error(self, message: str, function: Optional[str] = None): - self._log("ERROR", message, function) - - def critical(self, message: str, function: Optional[str] = None): - self._log("CRITICAL", message, function) - - -class R2RLoggingProvider: - _instance = None - _is_configured = False - _config: Optional[LoggingConfig] = None - - PERSISTENT_PROVIDERS = { - "r2r": SqlitePersistentLoggingProvider, - # TODO - Mark this as deprecated - "local": SqlitePersistentLoggingProvider, - } - - @classmethod - def get_persistent_logger(cls): - return cls.PERSISTENT_PROVIDERS[cls._config.provider](cls._config) - - @classmethod - def configure(cls, logging_config: LoggingConfig): - if logging_config.provider == "local": - logger.warning( - "Local logging provider is deprecated. Please use 'r2r' instead." 
- ) - if not cls._is_configured: - cls._config = logging_config - cls._is_configured = True - else: - raise Exception("R2RLoggingProvider is already configured.") - - @classmethod - async def log( - cls, - run_id: UUID, - key: str, - value: str, - ): - try: - async with cls.get_persistent_logger() as provider: - await provider.log(run_id, key, value) - except Exception as e: - logger.error(f"Error logging data {(run_id, key, value)}: {e}") - - @classmethod - async def info_log( - cls, - run_id: UUID, - run_type: RunType, - user_id: UUID, - ): - try: - async with cls.get_persistent_logger() as provider: - await provider.info_log(run_id, run_type, user_id) - except Exception as e: - logger.error( - f"Error logging info data {(run_id, run_type, user_id)}: {e}" - ) - - @classmethod - async def get_info_logs( - cls, - offset: int = 0, - limit: int = 100, - run_type_filter: Optional[RunType] = None, - user_ids: Optional[list[UUID]] = None, - ) -> list[RunInfoLog]: - async with cls.get_persistent_logger() as provider: - return await provider.get_info_logs( - offset=offset, - limit=limit, - run_type_filter=run_type_filter, - user_ids=user_ids, - ) - - @classmethod - async def get_logs( - cls, - run_ids: list[UUID], - limit_per_run: int = 10, - ) -> list: - async with cls.get_persistent_logger() as provider: - return await provider.get_logs(run_ids, limit_per_run) - - @classmethod - async def create_conversation(cls) -> str: - async with cls.get_persistent_logger() as provider: - return await provider.create_conversation() - - @classmethod - async def get_conversations_overview( - cls, - conversation_ids: Optional[list[UUID]] = None, - offset: int = 0, - limit: int = 100, - ) -> list[dict]: - async with cls.get_persistent_logger() as provider: - return await provider.get_conversations_overview( - conversation_ids=conversation_ids, - offset=offset, - limit=limit, - ) - - @classmethod - async def add_message( - cls, - conversation_id: str, - content: Message, - parent_id: Optional[str] = None, - metadata: Optional[dict] = None, - ) -> str: - async with cls.get_persistent_logger() as provider: - return await provider.add_message( - conversation_id, content, parent_id, metadata - ) - - @classmethod - async def edit_message( - cls, message_id: str, new_content: str - ) -> Tuple[str, str]: - async with cls.get_persistent_logger() as provider: - return await provider.edit_message(message_id, new_content) - - @classmethod - async def get_conversation( - cls, conversation_id: str, branch_id: Optional[str] = None - ) -> list[dict]: - async with cls.get_persistent_logger() as provider: - return await provider.get_conversation(conversation_id, branch_id) - - @classmethod - async def get_branches_overview(cls, conversation_id: str) -> list[dict]: - async with cls.get_persistent_logger() as provider: - return await provider.get_branches_overview(conversation_id) - - @classmethod - async def get_next_branch(cls, current_branch_id: str) -> Optional[str]: - async with cls.get_persistent_logger() as provider: - return await provider.get_next_branch(current_branch_id) - - @classmethod - async def get_prev_branch(cls, current_branch_id: str) -> Optional[str]: - async with cls.get_persistent_logger() as provider: - return await provider.get_prev_branch(current_branch_id) - - @classmethod - async def branch_at_message(cls, message_id: str) -> str: - async with cls.get_persistent_logger() as provider: - return await provider.branch_at_message(message_id) - - @classmethod - async def delete_conversation(cls, 
conversation_id: str): - async with cls.get_persistent_logger() as provider: - await provider.delete_conversation(conversation_id) - - @classmethod - async def close(cls): - async with cls.get_persistent_logger() as provider: - await provider.close() diff --git a/py/core/base/logging/run_manager.py b/py/core/base/logging/run_manager.py index 4680068a3..84f576143 100644 --- a/py/core/base/logging/run_manager.py +++ b/py/core/base/logging/run_manager.py @@ -8,13 +8,13 @@ from core.base.logging.base import RunType from core.base.utils import generate_run_id -from .r2r_logger import R2RLoggingProvider +from .r2r_logger import PersistentLoggingProvider run_id_var = contextvars.ContextVar("run_id", default=generate_run_id()) class RunManager: - def __init__(self, logger: R2RLoggingProvider): + def __init__(self, logger: PersistentLoggingProvider): self.logger = logger self.run_info: dict[UUID, dict] = {} diff --git a/py/core/base/pipeline/base_pipeline.py b/py/core/base/pipeline/base_pipeline.py index fa5dc2c0b..99f718102 100644 --- a/py/core/base/pipeline/base_pipeline.py +++ b/py/core/base/pipeline/base_pipeline.py @@ -5,7 +5,7 @@ import traceback from typing import Any, AsyncGenerator, Optional -from ..logging.r2r_logger import R2RLoggingProvider +from ..logging.r2r_logger import PersistentLoggingProvider from ..logging.run_manager import RunManager, manage_run from ..pipes.base_pipe import AsyncPipe, AsyncState @@ -17,13 +17,17 @@ class AsyncPipeline: def __init__( self, - pipe_logger: Optional[R2RLoggingProvider] = None, + logging_provider: PersistentLoggingProvider, run_manager: Optional[RunManager] = None, ): + # TODO - Deprecate + if logging_provider is None: + raise ValueError("Pipe logger is required.") + self.pipes: list[AsyncPipe] = [] self.upstream_outputs: list[list[dict[str, str]]] = [] - self.pipe_logger = pipe_logger or R2RLoggingProvider() - self.run_manager = run_manager or RunManager(self.pipe_logger) + self.logging_provider = logging_provider + self.run_manager = run_manager or RunManager(self.logging_provider) self.futures: dict[str, asyncio.Future] = {} self.level = 0 diff --git a/py/core/base/pipes/__init__.py b/py/core/base/pipes/__init__.py index ff9035fab..d5529b49b 100644 --- a/py/core/base/pipes/__init__.py +++ b/py/core/base/pipes/__init__.py @@ -1,3 +1,3 @@ -from .base_pipe import AsyncPipe, AsyncState, PipeType +from .base_pipe import AsyncPipe, AsyncState -__all__ = ["AsyncPipe", "AsyncState", "PipeType"] +__all__ = ["AsyncPipe", "AsyncState"] diff --git a/py/core/base/pipes/base_pipe.py b/py/core/base/pipes/base_pipe.py index 144feedc3..8544a0494 100644 --- a/py/core/base/pipes/base_pipe.py +++ b/py/core/base/pipes/base_pipe.py @@ -8,20 +8,12 @@ from pydantic import BaseModel from core.base.logging import RunType -from core.base.logging.r2r_logger import R2RLoggingProvider +from core.base.logging.r2r_logger import PersistentLoggingProvider from core.base.logging.run_manager import RunManager, manage_run logger = logging.getLogger() -class PipeType(Enum): - INGESTOR = "ingestor" - GENERATOR = "generator" - SEARCH = "search" - TRANSFORM = "transform" - OTHER = "other" - - class AsyncState: """A state object for storing data between pipes.""" @@ -91,34 +83,30 @@ class Config: def __init__( self, config: PipeConfig, - type: PipeType = PipeType.OTHER, - pipe_logger: Optional[R2RLoggingProvider] = None, + logging_provider: PersistentLoggingProvider, run_manager: Optional[RunManager] = None, ): + # TODO - Deprecate + if logging_provider is None: + raise 
ValueError("Pipe logger is required.") + self._config = config or self.PipeConfig() - self._type = type - self.pipe_logger = pipe_logger or R2RLoggingProvider() + self.logging_provider = logging_provider self.log_queue: asyncio.Queue = asyncio.Queue() self.log_worker_task = None - self._run_manager = run_manager or RunManager(self.pipe_logger) + self._run_manager = run_manager or RunManager(self.logging_provider) - logger.debug( - f"Initialized pipe {self.config.name} of type {self.type}" - ) + logger.debug(f"Initialized pipe {self.config.name}") @property def config(self) -> PipeConfig: return self._config - @property - def type(self) -> PipeType: - return self._type - async def log_worker(self): while True: log_data = await self.log_queue.get() run_id, key, value = log_data - await self.pipe_logger.log(run_id, key, value) + await self.logging_provider.log(run_id, key, value) self.log_queue.task_done() async def enqueue_log(self, run_id: UUID, key: str, value: str): diff --git a/py/core/main/abstractions.py b/py/core/main/abstractions.py index 380724533..63bba5cd1 100644 --- a/py/core/main/abstractions.py +++ b/py/core/main/abstractions.py @@ -15,6 +15,7 @@ R2RAuthProvider, R2RIngestionProvider, SimpleOrchestrationProvider, + SqlitePersistentLoggingProvider, SupabaseAuthProvider, UnstructuredIngestionProvider, ) @@ -29,6 +30,7 @@ class R2RProviders(BaseModel): orchestration: Union[ HatchetOrchestrationProvider, SimpleOrchestrationProvider ] + logging: SqlitePersistentLoggingProvider class Config: arbitrary_types_allowed = True diff --git a/py/core/main/assembly/builder.py b/py/core/main/assembly/builder.py index d0ae3b347..f4b1a6ec8 100644 --- a/py/core/main/assembly/builder.py +++ b/py/core/main/assembly/builder.py @@ -11,11 +11,12 @@ DatabaseProvider, EmbeddingProvider, OrchestrationProvider, - R2RLoggingProvider, RunManager, ) from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider +from ..abstractions import R2RProviders from ..api.auth_router import AuthRouter from ..api.ingestion_router import IngestionRouter from ..api.kg_router import KGRouter @@ -159,6 +160,7 @@ def _create_pipes( def _create_pipelines( self, pipeline_factory: type[R2RPipelineFactory], + providers: R2RProviders, pipes: Any, *args, **kwargs, @@ -169,9 +171,9 @@ def _create_pipelines( if v is not None } kwargs.update(override_dict) - return pipeline_factory(self.config, pipes).create_pipelines( - *args, **kwargs - ) + return pipeline_factory( + self.config, providers, pipes + ).create_pipelines(*args, **kwargs) def _create_services( self, service_params: Dict[str, Any] @@ -198,7 +200,7 @@ async def build(self, *args, **kwargs) -> R2RApp: pipe_factory, providers, *args, **kwargs ) pipelines = self._create_pipelines( - pipeline_factory, pipes, *args, **kwargs + pipeline_factory, providers, pipes, *args, **kwargs ) except Exception as e: logger.error(f"Error creating providers, pipes, or pipelines: {e}") @@ -211,8 +213,7 @@ async def build(self, *args, **kwargs) -> R2RApp: overrides={"rag_agent": self.rag_agent_override}, *args, **kwargs ) - run_singleton = R2RLoggingProvider() - run_manager = RunManager(run_singleton) + run_manager = RunManager(providers.logging) service_params = { "config": self.config, @@ -221,7 +222,7 @@ async def build(self, *args, **kwargs) -> R2RApp: "pipelines": pipelines, "agents": agents, "run_manager": run_manager, - "logging_connection": run_singleton, + "logging_connection": 
providers.logging, } services = self._create_services(service_params) diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index 220fdb61c..d941494f8 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -14,10 +14,10 @@ EmbeddingProvider, IngestionConfig, OrchestrationConfig, - R2RLoggingProvider, ) from core.pipelines import RAGPipeline, SearchPipeline from core.pipes import GeneratorPipe, MultiSearchPipe, SearchPipe +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders from ..config import R2RConfig @@ -225,6 +225,9 @@ async def create_providers( Union[OpenAICompletionProvider, LiteLLMCompletionProvider] ] = None, orchestration_provider_override: Optional[Any] = None, + r2r_logging_provider_override: Optional[ + SqlitePersistentLoggingProvider + ] = None, *args, **kwargs, ) -> R2RProviders: @@ -272,6 +275,12 @@ async def create_providers( or self.create_orchestration_provider(self.config.orchestration) ) + logging_provider = ( + r2r_logging_provider_override + or SqlitePersistentLoggingProvider(self.config.logging) + ) + await logging_provider.initialize() + return R2RProviders( auth=auth_provider, database=database_provider, @@ -279,6 +288,7 @@ async def create_providers( ingestion=ingestion_provider, llm=llm_provider, orchestration=orchestration_provider, + logging=logging_provider, ) @@ -350,6 +360,7 @@ def create_parsing_pipe(self, *args, **kwargs) -> Any: from core.pipes import ParsingPipe return ParsingPipe( + logging_provider=self.providers.logging, ingestion_provider=self.providers.ingestion, database_provider=self.providers.database, config=AsyncPipe.PipeConfig(name="parsing_pipe"), @@ -362,6 +373,7 @@ def create_embedding_pipe(self, *args, **kwargs) -> Any: from core.pipes import EmbeddingPipe return EmbeddingPipe( + logging_provider=self.providers.logging, embedding_provider=self.providers.embedding, database_provider=self.providers.database, embedding_batch_size=self.config.embedding.batch_size, @@ -375,6 +387,7 @@ def create_vector_storage_pipe(self, *args, **kwargs) -> Any: from core.pipes import VectorStoragePipe return VectorStoragePipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, config=AsyncPipe.PipeConfig(name="vector_storage_pipe"), ) @@ -386,6 +399,7 @@ def create_default_vector_search_pipe(self, *args, **kwargs) -> Any: from core.pipes import VectorSearchPipe return VectorSearchPipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, embedding_provider=self.providers.embedding, config=SearchPipe.SearchConfig(name="vector_search_pipe"), @@ -407,6 +421,7 @@ def create_multi_search_pipe( ) query_transform_pipe = QueryTransformPipe( + logging_provider=self.providers.logging, llm_provider=self.providers.llm, database_provider=self.providers.database, config=QueryTransformPipe.QueryTransformConfig( @@ -416,6 +431,7 @@ def create_multi_search_pipe( ) return MultiSearchPipe( + logging_provider=self.providers.logging, query_transform_pipe=query_transform_pipe, inner_search_pipe=inner_search_pipe, config=multi_search_config, @@ -446,6 +462,7 @@ def create_vector_search_pipe(self, *args, **kwargs) -> Any: from core.pipes import RoutingSearchPipe return RoutingSearchPipe( + logging_provider=self.providers.logging, search_pipes={ "vanilla": vanilla_vector_search_pipe, "hyde": hyde_search_pipe, @@ -459,6 +476,7 @@ def 
create_kg_triples_extraction_pipe(self, *args, **kwargs) -> Any: from core.pipes import KGTriplesExtractionPipe return KGTriplesExtractionPipe( + logging_provider=self.providers.logging, llm_provider=self.providers.llm, database_provider=self.providers.database, config=AsyncPipe.PipeConfig(name="kg_triples_extraction_pipe"), @@ -468,6 +486,7 @@ def create_kg_storage_pipe(self, *args, **kwargs) -> Any: from core.pipes import KGStoragePipe return KGStoragePipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, config=AsyncPipe.PipeConfig(name="kg_storage_pipe"), ) @@ -476,6 +495,7 @@ def create_kg_search_pipe(self, *args, **kwargs) -> Any: from core.pipes import KGSearchSearchPipe return KGSearchSearchPipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, llm_provider=self.providers.llm, embedding_provider=self.providers.embedding, @@ -489,6 +509,7 @@ def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any: from core.pipes import StreamingSearchRAGPipe return StreamingSearchRAGPipe( + logging_provider=self.providers.logging, llm_provider=self.providers.llm, database_provider=self.providers.database, config=GeneratorPipe.PipeConfig( @@ -499,6 +520,7 @@ def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any: from core.pipes import SearchRAGPipe return SearchRAGPipe( + logging_provider=self.providers.logging, llm_provider=self.providers.llm, database_provider=self.providers.database, config=GeneratorPipe.PipeConfig( @@ -510,6 +532,7 @@ def create_kg_entity_description_pipe(self, *args, **kwargs) -> Any: from core.pipes import KGEntityDescriptionPipe return KGEntityDescriptionPipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, llm_provider=self.providers.llm, embedding_provider=self.providers.embedding, @@ -520,6 +543,7 @@ def create_kg_clustering_pipe(self, *args, **kwargs) -> Any: from core.pipes import KGClusteringPipe return KGClusteringPipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, llm_provider=self.providers.llm, embedding_provider=self.providers.embedding, @@ -530,6 +554,7 @@ def create_kg_deduplication_summary_pipe(self, *args, **kwargs) -> Any: from core.pipes import KGEntityDeduplicationSummaryPipe return KGEntityDeduplicationSummaryPipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, llm_provider=self.providers.llm, embedding_provider=self.providers.embedding, @@ -540,6 +565,7 @@ def create_kg_community_summary_pipe(self, *args, **kwargs) -> Any: from core.pipes import KGCommunitySummaryPipe return KGCommunitySummaryPipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, llm_provider=self.providers.llm, embedding_provider=self.providers.embedding, @@ -550,6 +576,7 @@ def create_kg_entity_deduplication_pipe(self, *args, **kwargs) -> Any: from core.pipes import KGEntityDeduplicationPipe return KGEntityDeduplicationPipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, llm_provider=self.providers.llm, embedding_provider=self.providers.embedding, @@ -562,6 +589,7 @@ def create_kg_entity_deduplication_summary_pipe( from core.pipes import KGEntityDeduplicationSummaryPipe return KGEntityDeduplicationSummaryPipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, llm_provider=self.providers.llm, embedding_provider=self.providers.embedding, @@ -574,6 +602,7 @@ 
def create_kg_prompt_tuning_pipe(self, *args, **kwargs) -> Any: from core.pipes import KGPromptTuningPipe return KGPromptTuningPipe( + logging_provider=self.providers.logging, database_provider=self.providers.database, llm_provider=self.providers.llm, config=AsyncPipe.PipeConfig(name="kg_prompt_tuning_pipe"), @@ -581,13 +610,18 @@ def create_kg_prompt_tuning_pipe(self, *args, **kwargs) -> Any: class R2RPipelineFactory: - def __init__(self, config: R2RConfig, pipes: R2RPipes): + def __init__( + self, config: R2RConfig, providers: R2RProviders, pipes: R2RPipes + ): self.config = config + self.providers = providers self.pipes = pipes def create_search_pipeline(self, *args, **kwargs) -> SearchPipeline: """factory method to create an ingestion pipeline.""" - search_pipeline = SearchPipeline() + search_pipeline = SearchPipeline( + logging_provider=self.providers.logging + ) # Add vector search pipes if embedding provider and vector provider is set if ( @@ -614,7 +648,7 @@ def create_rag_pipeline( self.pipes.streaming_rag_pipe if stream else self.pipes.rag_pipe ) - rag_pipeline = RAGPipeline() + rag_pipeline = RAGPipeline(logging_provider=self.providers.logging) rag_pipeline.set_search_pipeline(search_pipeline) rag_pipeline.add_pipe(rag_pipe) return rag_pipeline @@ -627,10 +661,6 @@ def create_pipelines( *args, **kwargs, ) -> R2RPipelines: - try: - self.configure_logging() - except Exception as e: - logger.warning(f"Error configuring logging: {e}") search_pipeline = search_pipeline or self.create_search_pipeline( *args, **kwargs ) @@ -652,9 +682,6 @@ def create_pipelines( ), ) - def configure_logging(self): - R2RLoggingProvider.configure(self.config.logging) - class R2RAgentFactory: def __init__( diff --git a/py/core/main/config.py b/py/core/main/config.py index a16a52bfc..2e962962f 100644 --- a/py/core/main/config.py +++ b/py/core/main/config.py @@ -9,7 +9,7 @@ from ..base.abstractions import GenerationConfig from ..base.agent.agent import AgentConfig -from ..base.logging.r2r_logger import LoggingConfig +from ..base.logging.r2r_logger import PersistentLoggingConfig from ..base.providers import AppConfig from ..base.providers.auth import AuthConfig from ..base.providers.crypto import CryptoConfig @@ -65,7 +65,7 @@ class R2RConfig: database: DatabaseConfig embedding: EmbeddingConfig ingestion: IngestionConfig - logging: LoggingConfig + logging: PersistentLoggingConfig agent: AgentConfig orchestration: OrchestrationConfig @@ -117,7 +117,7 @@ def __init__(self, config_data: dict[str, Any]): self.database = DatabaseConfig.create(**self.database, app=self.app) # type: ignore self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app) # type: ignore self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app) # type: ignore - self.logging = LoggingConfig.create(**self.logging, app=self.app) # type: ignore + self.logging = PersistentLoggingConfig.create(**self.logging, app=self.app) # type: ignore self.agent = AgentConfig.create(**self.agent, app=self.app) # type: ignore self.orchestration = OrchestrationConfig.create(**self.orchestration, app=self.app) # type: ignore diff --git a/py/core/main/services/auth_service.py b/py/core/main/services/auth_service.py index 4923f7649..8180a1c03 100644 --- a/py/core/main/services/auth_service.py +++ b/py/core/main/services/auth_service.py @@ -2,8 +2,9 @@ from typing import Optional from uuid import UUID -from core.base import R2RException, R2RLoggingProvider, RunManager, Token +from core.base import R2RException, RunManager, Token from 
core.base.api.models import UserResponse +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider from core.telemetry.telemetry_decorator import telemetry_event from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders @@ -20,7 +21,7 @@ def __init__( pipelines: R2RPipelines, agents: R2RAgents, run_manager: RunManager, - logging_connection: R2RLoggingProvider, + logging_connection: SqlitePersistentLoggingProvider, ): super().__init__( config, diff --git a/py/core/main/services/base.py b/py/core/main/services/base.py index a9ad965f4..ae4ef3a94 100644 --- a/py/core/main/services/base.py +++ b/py/core/main/services/base.py @@ -1,6 +1,7 @@ from abc import ABC -from core.base import R2RLoggingProvider, RunManager +from core.base import RunManager +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders from ..config import R2RConfig @@ -15,7 +16,7 @@ def __init__( pipelines: R2RPipelines, agents: R2RAgents, run_manager: RunManager, - logging_connection: R2RLoggingProvider, + logging_connection: SqlitePersistentLoggingProvider, ): self.config = config self.providers = providers diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index a332825d3..423cdd56a 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -13,7 +13,6 @@ DocumentType, IngestionStatus, R2RException, - R2RLoggingProvider, RawChunk, RunManager, Vector, @@ -29,6 +28,7 @@ VectorTableName, ) from core.base.api.models import UserResponse +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider from core.telemetry.telemetry_decorator import telemetry_event from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders @@ -51,7 +51,7 @@ def __init__( pipelines: R2RPipelines, agents: R2RAgents, run_manager: RunManager, - logging_connection: R2RLoggingProvider, + logging_connection: SqlitePersistentLoggingProvider, ) -> None: super().__init__( config, @@ -161,7 +161,7 @@ def _create_document_info_from_file( id=document_id, user_id=user.id, collection_ids=metadata.get("collection_ids", []), - type=DocumentType[file_extension.upper()], + document_type=DocumentType[file_extension.upper()], title=metadata.get("title", file_name.split("/")[-1]), metadata=metadata, version=version, @@ -186,7 +186,7 @@ def _create_document_info_from_chunks( id=document_id, user_id=user.id, collection_ids=metadata.get("collection_ids", []), - type=DocumentType.TXT, + document_type=DocumentType.TXT, title=metadata.get("title", f"Ingested Chunks - {document_id}"), metadata=metadata, version=version, @@ -207,11 +207,11 @@ async def parse_file( id=document_info.id, collection_ids=document_info.collection_ids, user_id=document_info.user_id, - type=document_info.type, metadata={ - "document_type": document_info.type.value, + "document_type": document_info.document_type.value, **document_info.metadata, }, + document_type=document_info.document_type, ) ), state=None, diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index cc3a4e27b..ff16be4a0 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -4,7 +4,7 @@ from typing import AsyncGenerator, Optional from uuid import UUID -from core.base import KGExtractionStatus, R2RLoggingProvider, RunManager +from core.base import KGExtractionStatus, RunManager from 
core.base.abstractions import ( GenerationConfig, KGCreationSettings, @@ -13,6 +13,7 @@ KGEntityDeduplicationType, R2RException, ) +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider from core.telemetry.telemetry_decorator import telemetry_event from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders @@ -40,7 +41,7 @@ def __init__( pipelines: R2RPipelines, agents: R2RAgents, run_manager: RunManager, - logging_connection: R2RLoggingProvider, + logging_connection: SqlitePersistentLoggingProvider, ): super().__init__( config, diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index 726014a91..3216294b7 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -14,12 +14,12 @@ Message, Prompt, R2RException, - R2RLoggingProvider, RunManager, RunType, UserResponse, ) from core.base.utils import validate_uuid +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider from core.telemetry.telemetry_decorator import telemetry_event from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders @@ -38,7 +38,7 @@ def __init__( pipelines: R2RPipelines, agents: R2RAgents, run_manager: RunManager, - logging_connection: R2RLoggingProvider, + logging_connection: SqlitePersistentLoggingProvider, ): super().__init__( config, @@ -669,7 +669,7 @@ async def get_conversation( conversation_id: str, branch_id: Optional[str] = None, auth_user=None, - ) -> list[dict]: + ) -> Tuple[str, list[Message]]: return await self.logging_connection.get_conversation( conversation_id, branch_id ) @@ -685,7 +685,7 @@ async def conversations_overview( offset: int = 0, limit: int = 100, auth_user=None, - ) -> list[Dict]: + ) -> dict[str, Union[list[dict], int]]: return await self.logging_connection.get_conversations_overview( conversation_ids=conversation_ids, offset=offset, diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index 558d37ab0..d388484a0 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -10,7 +10,6 @@ KGSearchSettings, Message, R2RException, - R2RLoggingProvider, RunManager, RunType, VectorSearchSettings, @@ -18,6 +17,7 @@ to_async_generator, ) from core.base.api.models import RAGResponse, SearchResponse, UserResponse +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider from core.telemetry.telemetry_decorator import telemetry_event from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders @@ -36,7 +36,7 @@ def __init__( pipelines: R2RPipelines, agents: R2RAgents, run_manager: RunManager, - logging_connection: R2RLoggingProvider, + logging_connection: SqlitePersistentLoggingProvider, ): super().__init__( config, @@ -267,6 +267,11 @@ async def agent( ids = None if not messages: + if not message: + raise R2RException( + status_code=400, + message="Message not provided", + ) # Fetch or create conversation if conversation_id: conversation = ( @@ -282,7 +287,7 @@ async def agent( status_code=404, message=f"Conversation not found: {conversation_id}", ) - messages = [conv[1] for conv in conversation] + [ + messages = [conv[1] for conv in conversation] + [ # type: ignore message ] ids = [conv[0] for conv in conversation] @@ -306,7 +311,7 @@ async def agent( ) ) - current_message = messages[-1] + current_message = messages[-1] # type: ignore # Save the new message to the conversation message_id = 
await self.logging_connection.add_message( diff --git a/py/core/pipelines/graph_enrichment.py b/py/core/pipelines/graph_enrichment.py index 7f32b4780..1ea8ed88d 100644 --- a/py/core/pipelines/graph_enrichment.py +++ b/py/core/pipelines/graph_enrichment.py @@ -1,10 +1,10 @@ import logging from typing import Optional -from ..base.logging.r2r_logger import R2RLoggingProvider from ..base.logging.run_manager import RunManager from ..base.pipeline.base_pipeline import AsyncPipeline from ..base.pipes.base_pipe import AsyncPipe +from ..providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -16,10 +16,10 @@ class KGEnrichmentPipeline(AsyncPipeline): def __init__( self, - pipe_logger: Optional[R2RLoggingProvider] = None, + logging_provider: SqlitePersistentLoggingProvider, run_manager: Optional[RunManager] = None, ): - super().__init__(pipe_logger, run_manager) + super().__init__(logging_provider, run_manager) def add_pipe( self, diff --git a/py/core/pipelines/rag_pipeline.py b/py/core/pipelines/rag_pipeline.py index 31cda148a..9b031227b 100644 --- a/py/core/pipelines/rag_pipeline.py +++ b/py/core/pipelines/rag_pipeline.py @@ -8,11 +8,11 @@ VectorSearchSettings, ) from ..base.logging import RunType -from ..base.logging.r2r_logger import R2RLoggingProvider from ..base.logging.run_manager import RunManager, manage_run from ..base.pipeline.base_pipeline import AsyncPipeline from ..base.pipes.base_pipe import AsyncPipe, AsyncState from ..base.utils import to_async_generator +from ..providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -22,10 +22,10 @@ class RAGPipeline(AsyncPipeline): def __init__( self, - pipe_logger: Optional[R2RLoggingProvider] = None, + logging_provider: SqlitePersistentLoggingProvider, run_manager: Optional[RunManager] = None, ): - super().__init__(pipe_logger, run_manager) + super().__init__(logging_provider, run_manager) self._search_pipeline: Optional[AsyncPipeline] = None self._rag_pipeline: Optional[AsyncPipeline] = None @@ -47,7 +47,9 @@ async def run( # type: ignore self.state = state or AsyncState() # TODO - This feels anti-pattern. 
run_manager = ( - run_manager or self.run_manager or RunManager(self.pipe_logger) + run_manager + or self.run_manager + or RunManager(self.logging_provider) ) async with manage_run(run_manager, RunType.RETRIEVAL): if not self._search_pipeline: @@ -107,7 +109,9 @@ def add_pipe( "Only pipes that are part of the RAG pipeline can be added to the RAG pipeline" ) if not self._rag_pipeline: - self._rag_pipeline = AsyncPipeline() + self._rag_pipeline = AsyncPipeline( + logging_provider=self.logging_provider + ) self._rag_pipeline.add_pipe( pipe, add_upstream_outputs, *args, **kwargs ) diff --git a/py/core/pipelines/search_pipeline.py b/py/core/pipelines/search_pipeline.py index 0a18d62cf..94ba3c8fc 100644 --- a/py/core/pipelines/search_pipeline.py +++ b/py/core/pipelines/search_pipeline.py @@ -8,10 +8,10 @@ KGSearchSettings, VectorSearchSettings, ) -from ..base.logging.r2r_logger import R2RLoggingProvider from ..base.logging.run_manager import RunManager, manage_run from ..base.pipeline.base_pipeline import AsyncPipeline, dequeue_requests from ..base.pipes.base_pipe import AsyncPipe, AsyncState +from ..providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -21,10 +21,10 @@ class SearchPipeline(AsyncPipeline): def __init__( self, - pipe_logger: Optional[R2RLoggingProvider] = None, + logging_provider: SqlitePersistentLoggingProvider, run_manager: Optional[RunManager] = None, ): - super().__init__(pipe_logger, run_manager) + super().__init__(logging_provider, run_manager) self._parsing_pipe: Optional[AsyncPipe] = None self._vector_search_pipeline: Optional[AsyncPipeline] = None self._kg_search_pipeline: Optional[AsyncPipeline] = None @@ -125,7 +125,9 @@ def add_pipe( if kg_search_pipe: if not self._kg_search_pipeline: - self._kg_search_pipeline = AsyncPipeline() + self._kg_search_pipeline = AsyncPipeline( + logging_provider=self.logging_provider + ) if not self._kg_search_pipeline: raise ValueError( "KG search pipeline not found" @@ -136,7 +138,9 @@ def add_pipe( ) elif vector_search_pipe: if not self._vector_search_pipeline: - self._vector_search_pipeline = AsyncPipeline() + self._vector_search_pipeline = AsyncPipeline( + logging_provider=self.logging_provider + ) if not self._vector_search_pipeline: raise ValueError( "Vector search pipeline not found" diff --git a/py/core/pipes/abstractions/generator_pipe.py b/py/core/pipes/abstractions/generator_pipe.py index 0ebe5ea43..8ddd8f1ac 100644 --- a/py/core/pipes/abstractions/generator_pipe.py +++ b/py/core/pipes/abstractions/generator_pipe.py @@ -2,15 +2,10 @@ from typing import Any, AsyncGenerator, Optional from uuid import UUID -from core.base import ( - AsyncState, - CompletionProvider, - DatabaseProvider, - PipeType, - R2RLoggingProvider, -) +from core.base import AsyncState, CompletionProvider, DatabaseProvider from core.base.abstractions import GenerationConfig from core.base.pipes.base_pipe import AsyncPipe +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider class GeneratorPipe(AsyncPipe): @@ -24,15 +19,13 @@ def __init__( llm_provider: CompletionProvider, database_provider: DatabaseProvider, config: AsyncPipe.PipeConfig, - type: PipeType = PipeType.GENERATOR, - pipe_logger: Optional[R2RLoggingProvider] = None, + logging_provider: SqlitePersistentLoggingProvider, *args, **kwargs, ): super().__init__( config, - type, - pipe_logger, + logging_provider, *args, **kwargs, ) diff --git a/py/core/pipes/abstractions/search_pipe.py b/py/core/pipes/abstractions/search_pipe.py index 
8db2c506a..6ac60116e 100644 --- a/py/core/pipes/abstractions/search_pipe.py +++ b/py/core/pipes/abstractions/search_pipe.py @@ -3,13 +3,8 @@ from typing import Any, AsyncGenerator, Optional, Union from uuid import UUID -from core.base import ( - AsyncPipe, - AsyncState, - PipeType, - R2RLoggingProvider, - VectorSearchResult, -) +from core.base import AsyncPipe, AsyncState, VectorSearchResult +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -26,15 +21,13 @@ class Input(AsyncPipe.Input): def __init__( self, config: AsyncPipe.PipeConfig, - type: PipeType = PipeType.SEARCH, - pipe_logger: Optional[R2RLoggingProvider] = None, + logging_provider: SqlitePersistentLoggingProvider, *args, **kwargs, ): super().__init__( config, - type, - pipe_logger, + logging_provider, *args, **kwargs, ) diff --git a/py/core/pipes/ingestion/embedding_pipe.py b/py/core/pipes/ingestion/embedding_pipe.py index a663bedf3..97c45bad6 100644 --- a/py/core/pipes/ingestion/embedding_pipe.py +++ b/py/core/pipes/ingestion/embedding_pipe.py @@ -6,13 +6,12 @@ AsyncState, DocumentExtraction, EmbeddingProvider, - PipeType, R2RDocumentProcessingError, - R2RLoggingProvider, Vector, VectorEntry, ) from core.base.pipes.base_pipe import AsyncPipe +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -29,16 +28,14 @@ def __init__( self, embedding_provider: EmbeddingProvider, config: AsyncPipe.PipeConfig, + logging_provider: SqlitePersistentLoggingProvider, embedding_batch_size: int = 1, - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.INGESTOR, *args, **kwargs, ): super().__init__( config, - type, - pipe_logger, + logging_provider, ) self.embedding_provider = embedding_provider self.embedding_batch_size = embedding_batch_size diff --git a/py/core/pipes/ingestion/parsing_pipe.py b/py/core/pipes/ingestion/parsing_pipe.py index 0cd1141cc..6eea6834f 100644 --- a/py/core/pipes/ingestion/parsing_pipe.py +++ b/py/core/pipes/ingestion/parsing_pipe.py @@ -7,12 +7,11 @@ DatabaseProvider, Document, DocumentExtraction, - PipeType, - R2RLoggingProvider, ) from core.base.abstractions import R2RDocumentProcessingError from core.base.pipes.base_pipe import AsyncPipe from core.base.providers.ingestion import IngestionProvider +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider from core.utils import generate_extraction_id logger = logging.getLogger() @@ -27,15 +26,13 @@ def __init__( database_provider: DatabaseProvider, ingestion_provider: IngestionProvider, config: AsyncPipe.PipeConfig, - type: PipeType = PipeType.INGESTOR, - pipe_logger: Optional[R2RLoggingProvider] = None, + logging_provider: SqlitePersistentLoggingProvider, *args, **kwargs, ): super().__init__( config, - type, - pipe_logger, + logging_provider, *args, **kwargs, ) diff --git a/py/core/pipes/ingestion/vector_storage_pipe.py b/py/core/pipes/ingestion/vector_storage_pipe.py index 9f2c5925a..86f9ac6d3 100644 --- a/py/core/pipes/ingestion/vector_storage_pipe.py +++ b/py/core/pipes/ingestion/vector_storage_pipe.py @@ -2,15 +2,9 @@ from typing import Any, AsyncGenerator, Optional from uuid import UUID -from core.base import ( - AsyncState, - DatabaseProvider, - PipeType, - R2RLoggingProvider, - StorageResult, - VectorEntry, -) +from core.base import AsyncState, DatabaseProvider, StorageResult, VectorEntry from core.base.pipes.base_pipe import AsyncPipe +from core.providers.logging.r2r_logging import 
SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -23,9 +17,8 @@ def __init__( self, database_provider: DatabaseProvider, config: AsyncPipe.PipeConfig, + logging_provider: SqlitePersistentLoggingProvider, storage_batch_size: int = 128, - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.INGESTOR, *args, **kwargs, ): @@ -34,8 +27,7 @@ def __init__( """ super().__init__( config, - type, - pipe_logger, + logging_provider, *args, **kwargs, ) diff --git a/py/core/pipes/kg/clustering.py b/py/core/pipes/kg/clustering.py index 082a35d39..967995123 100644 --- a/py/core/pipes/kg/clustering.py +++ b/py/core/pipes/kg/clustering.py @@ -8,9 +8,8 @@ CompletionProvider, DatabaseProvider, EmbeddingProvider, - PipeType, - R2RLoggingProvider, ) +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -26,8 +25,7 @@ def __init__( llm_provider: CompletionProvider, embedding_provider: EmbeddingProvider, config: AsyncPipe.PipeConfig, - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.OTHER, + logging_provider: SqlitePersistentLoggingProvider, *args, **kwargs, ): @@ -35,8 +33,7 @@ def __init__( Initializes the KG clustering pipe with necessary components and configurations. """ super().__init__( - pipe_logger=pipe_logger, - type=type, + logging_provider=logging_provider, config=config or AsyncPipe.PipeConfig(name="kg_cluster_pipe"), ) self.database_provider = database_provider diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py index b8783bc54..457e84c67 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -14,10 +14,9 @@ DatabaseProvider, EmbeddingProvider, GenerationConfig, - PipeType, - R2RLoggingProvider, ) from core.base.abstractions import Entity, Triple +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -33,8 +32,7 @@ def __init__( llm_provider: CompletionProvider, embedding_provider: EmbeddingProvider, config: AsyncPipe.PipeConfig, - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.OTHER, + logging_provider: SqlitePersistentLoggingProvider, *args, **kwargs, ): @@ -42,8 +40,7 @@ def __init__( Initializes the KG clustering pipe with necessary components and configurations. 
""" super().__init__( - pipe_logger=pipe_logger, - type=type, + logging_provider=logging_provider, config=config or AsyncPipe.PipeConfig(name="kg_community_summary_pipe"), ) diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index fe0061507..4119d9e6c 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -4,8 +4,7 @@ from core.base import AsyncState, R2RException from core.base.abstractions import Entity, KGEntityDeduplicationType -from core.base.logging import R2RLoggingProvider -from core.base.pipes import AsyncPipe, PipeType +from core.base.pipes import AsyncPipe from core.providers import ( LiteLLMCompletionProvider, LiteLLMEmbeddingProvider, @@ -13,6 +12,7 @@ OpenAIEmbeddingProvider, PostgresDBProvider, ) +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -28,13 +28,11 @@ def __init__( embedding_provider: Union[ LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider ], - type: PipeType = PipeType.OTHER, - pipe_logger: Optional[R2RLoggingProvider] = None, + logging_provider: SqlitePersistentLoggingProvider, **kwargs, ): super().__init__( - pipe_logger=pipe_logger, - type=type, + logging_provider=logging_provider, config=config or AsyncPipe.PipeConfig(name="kg_entity_deduplication_pipe"), ) diff --git a/py/core/pipes/kg/deduplication_summary.py b/py/core/pipes/kg/deduplication_summary.py index 919543677..7819cfaad 100644 --- a/py/core/pipes/kg/deduplication_summary.py +++ b/py/core/pipes/kg/deduplication_summary.py @@ -5,8 +5,7 @@ from core.base import AsyncState from core.base.abstractions import Entity, GenerationConfig -from core.base.logging import R2RLoggingProvider -from core.base.pipes import AsyncPipe, PipeType +from core.base.pipes import AsyncPipe from core.providers import ( LiteLLMCompletionProvider, LiteLLMEmbeddingProvider, @@ -14,6 +13,7 @@ OpenAIEmbeddingProvider, PostgresDBProvider, ) +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -33,12 +33,11 @@ def __init__( LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider ], config: AsyncPipe.PipeConfig, - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.OTHER, + logging_provider: SqlitePersistentLoggingProvider, **kwargs, ): super().__init__( - pipe_logger=pipe_logger, type=type, config=config, **kwargs + logging_provider=logging_provider, config=config, **kwargs ) self.database_provider = database_provider self.llm_provider = llm_provider diff --git a/py/core/pipes/kg/entity_description.py b/py/core/pipes/kg/entity_description.py index 40194b597..2e82366b0 100644 --- a/py/core/pipes/kg/entity_description.py +++ b/py/core/pipes/kg/entity_description.py @@ -12,11 +12,10 @@ CompletionProvider, DatabaseProvider, EmbeddingProvider, - PipeType, - R2RLoggingProvider, ) from core.base.abstractions import Entity from core.base.pipes.base_pipe import AsyncPipe +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -35,14 +34,12 @@ def __init__( llm_provider: CompletionProvider, embedding_provider: EmbeddingProvider, config: AsyncPipe.PipeConfig, - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.OTHER, + logging_provider: SqlitePersistentLoggingProvider, *args, **kwargs, ): super().__init__( - pipe_logger=pipe_logger, - type=type, + logging_provider=logging_provider, config=config, ) self.database_provider = database_provider diff --git 
a/py/core/pipes/kg/prompt_tuning.py b/py/core/pipes/kg/prompt_tuning.py index 543756a5a..ff39e80e0 100644 --- a/py/core/pipes/kg/prompt_tuning.py +++ b/py/core/pipes/kg/prompt_tuning.py @@ -10,11 +10,10 @@ AsyncState, CompletionProvider, DatabaseProvider, - PipeType, R2RException, - R2RLoggingProvider, ) from core.base.pipes.base_pipe import AsyncPipe +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -29,14 +28,12 @@ def __init__( database_provider: DatabaseProvider, llm_provider: CompletionProvider, config: AsyncPipe.PipeConfig, - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.OTHER, + logging_provider: SqlitePersistentLoggingProvider, *args, **kwargs, ): super().__init__( - pipe_logger=pipe_logger, - type=type, + logging_provider=logging_provider, config=config, ) self.database_provider = database_provider diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index a98e93ecb..f159e3ab9 100644 --- a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -8,11 +8,10 @@ DatabaseProvider, EmbeddingProvider, KGExtraction, - PipeType, R2RDocumentProcessingError, - R2RLoggingProvider, ) from core.base.pipes.base_pipe import AsyncPipe +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -26,9 +25,8 @@ def __init__( self, database_provider: DatabaseProvider, config: AsyncPipe.PipeConfig, + logging_provider: SqlitePersistentLoggingProvider, storage_batch_size: int = 1, - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.INGESTOR, *args, **kwargs, ): @@ -41,8 +39,7 @@ def __init__( super().__init__( config, - type, - pipe_logger, + logging_provider, *args, **kwargs, ) diff --git a/py/core/pipes/kg/triples_extraction.py b/py/core/pipes/kg/triples_extraction.py index 56b77ebd4..c3dd9f0ad 100644 --- a/py/core/pipes/kg/triples_extraction.py +++ b/py/core/pipes/kg/triples_extraction.py @@ -13,13 +13,12 @@ Entity, GenerationConfig, KGExtraction, - PipeType, R2RDocumentProcessingError, R2RException, - R2RLoggingProvider, Triple, ) from core.base.pipes.base_pipe import AsyncPipe +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -47,17 +46,15 @@ def __init__( database_provider: DatabaseProvider, llm_provider: CompletionProvider, config: AsyncPipe.PipeConfig, + logging_provider: SqlitePersistentLoggingProvider, kg_batch_size: int = 1, graph_rag: bool = True, id_prefix: str = "demo", - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.INGESTOR, *args, **kwargs, ): super().__init__( - pipe_logger=pipe_logger, - type=type, + logging_provider=logging_provider, config=config or AsyncPipe.PipeConfig(name="default_kg_triples_extraction_pipe"), ) diff --git a/py/core/pipes/retrieval/kg_search_pipe.py b/py/core/pipes/retrieval/kg_search_pipe.py index b57920f00..7f718e6db 100644 --- a/py/core/pipes/retrieval/kg_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_pipe.py @@ -8,8 +8,6 @@ CompletionProvider, DatabaseProvider, EmbeddingProvider, - PipeType, - R2RLoggingProvider, ) from core.base.abstractions import ( KGCommunityResult, @@ -19,6 +17,7 @@ KGSearchResultType, KGSearchSettings, ) +from core.providers.logging.r2r_logging import SqlitePersistentLoggingProvider from ..abstractions.generator_pipe import GeneratorPipe @@ -36,8 +35,7 @@ def __init__( database_provider: DatabaseProvider, embedding_provider: 
EmbeddingProvider, config: GeneratorPipe.PipeConfig, - pipe_logger: Optional[R2RLoggingProvider] = None, - type: PipeType = PipeType.INGESTOR, + logging_provider: SqlitePersistentLoggingProvider, *args, **kwargs, ): @@ -48,8 +46,7 @@ def __init__( llm_provider, database_provider, config, - type, - pipe_logger, + logging_provider, *args, **kwargs, ) diff --git a/py/core/pipes/retrieval/multi_search.py b/py/core/pipes/retrieval/multi_search.py index f2e38d139..6e2dffebd 100644 --- a/py/core/pipes/retrieval/multi_search.py +++ b/py/core/pipes/retrieval/multi_search.py @@ -7,7 +7,7 @@ VectorSearchResult, VectorSearchSettings, ) -from core.base.pipes.base_pipe import AsyncPipe, PipeType +from core.base.pipes.base_pipe import AsyncPipe from ..abstractions.search_pipe import SearchPipe from .query_transform_pipe import QueryTransformPipe @@ -37,7 +37,6 @@ def __init__( ) super().__init__( config, - PipeType.SEARCH, *args, **kwargs, ) diff --git a/py/core/pipes/retrieval/query_transform_pipe.py b/py/core/pipes/retrieval/query_transform_pipe.py index 271d5346b..3a949235d 100644 --- a/py/core/pipes/retrieval/query_transform_pipe.py +++ b/py/core/pipes/retrieval/query_transform_pipe.py @@ -7,7 +7,6 @@ AsyncState, CompletionProvider, DatabaseProvider, - PipeType, ) from core.base.abstractions import GenerationConfig @@ -30,7 +29,6 @@ def __init__( llm_provider: CompletionProvider, database_provider: DatabaseProvider, config: QueryTransformConfig, - type: PipeType = PipeType.TRANSFORM, *args, **kwargs, ): @@ -39,7 +37,6 @@ def __init__( llm_provider, database_provider, config, - type, *args, **kwargs, ) diff --git a/py/core/pipes/retrieval/search_rag_pipe.py b/py/core/pipes/retrieval/search_rag_pipe.py index a44b153c9..a7cc4635d 100644 --- a/py/core/pipes/retrieval/search_rag_pipe.py +++ b/py/core/pipes/retrieval/search_rag_pipe.py @@ -7,7 +7,6 @@ AsyncState, CompletionProvider, DatabaseProvider, - PipeType, ) from core.base.abstractions import GenerationConfig, RAGCompletion @@ -23,7 +22,6 @@ def __init__( llm_provider: CompletionProvider, database_provider: DatabaseProvider, config: GeneratorPipe.PipeConfig, - type: PipeType = PipeType.GENERATOR, *args, **kwargs, ): @@ -31,7 +29,6 @@ def __init__( llm_provider, database_provider, config, - type, *args, **kwargs, ) diff --git a/py/core/pipes/retrieval/streaming_rag_pipe.py b/py/core/pipes/retrieval/streaming_rag_pipe.py index fcc9ea407..8871519db 100644 --- a/py/core/pipes/retrieval/streaming_rag_pipe.py +++ b/py/core/pipes/retrieval/streaming_rag_pipe.py @@ -8,7 +8,6 @@ CompletionProvider, DatabaseProvider, LLMChatCompletionChunk, - PipeType, format_search_results_for_llm, format_search_results_for_stream, ) @@ -32,7 +31,6 @@ def __init__( llm_provider: CompletionProvider, database_provider: DatabaseProvider, config: GeneratorPipe.PipeConfig, - type: PipeType = PipeType.GENERATOR, *args, **kwargs, ): @@ -40,7 +38,6 @@ def __init__( llm_provider, database_provider, config, - type, *args, **kwargs, ) diff --git a/py/core/pipes/retrieval/vector_search_pipe.py b/py/core/pipes/retrieval/vector_search_pipe.py index dbeccebdd..645cfe91e 100644 --- a/py/core/pipes/retrieval/vector_search_pipe.py +++ b/py/core/pipes/retrieval/vector_search_pipe.py @@ -9,7 +9,6 @@ DatabaseProvider, EmbeddingProvider, EmbeddingPurpose, - PipeType, VectorSearchResult, VectorSearchSettings, ) @@ -25,13 +24,11 @@ def __init__( database_provider: DatabaseProvider, embedding_provider: EmbeddingProvider, config: SearchPipe.SearchConfig, - type: PipeType = PipeType.SEARCH, *args, 
**kwargs, ): super().__init__( config, - type, *args, **kwargs, ) diff --git a/py/core/providers/__init__.py b/py/core/providers/__init__.py index fabb97825..393b72e0c 100644 --- a/py/core/providers/__init__.py +++ b/py/core/providers/__init__.py @@ -9,6 +9,7 @@ UnstructuredIngestionProvider, ) from .llm import LiteLLMCompletionProvider, OpenAICompletionProvider +from .logging import SqlitePersistentLoggingProvider from .orchestration import ( HatchetOrchestrationProvider, SimpleOrchestrationProvider, @@ -37,4 +38,6 @@ # LLM "OpenAICompletionProvider", "LiteLLMCompletionProvider", + # Logging + "SqlitePersistentLoggingProvider", ] diff --git a/py/core/providers/database/collection.py b/py/core/providers/database/collection.py index d78f7ec12..7dff5d8b6 100644 --- a/py/core/providers/database/collection.py +++ b/py/core/providers/database/collection.py @@ -325,7 +325,7 @@ async def documents_in_collection( id=row["document_id"], collection_ids=[collection_id], user_id=row["user_id"], - type=DocumentType(row["type"]), + document_type=DocumentType(row["type"]), metadata=json.loads(row["metadata"]), title=row["title"], version=row["version"], diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py index d889e92af..0abb90b61 100644 --- a/py/core/providers/database/document.py +++ b/py/core/providers/database/document.py @@ -398,7 +398,7 @@ async def get_documents_overview( id=row["document_id"], collection_ids=row["collection_ids"], user_id=row["user_id"], - type=DocumentType(row["type"]), + document_type=DocumentType(row["type"]), metadata=json.loads(row["metadata"]), title=row["title"], version=row["version"], diff --git a/py/core/providers/ingestion/r2r/base.py b/py/core/providers/ingestion/r2r/base.py index 94ca81e63..2632644dc 100644 --- a/py/core/providers/ingestion/r2r/base.py +++ b/py/core/providers/ingestion/r2r/base.py @@ -179,10 +179,10 @@ async def parse( # type: ignore ) -> AsyncGenerator[ Union[DocumentExtraction, R2RDocumentProcessingError], None ]: - if document.type not in self.parsers: + if document.document_type not in self.parsers: yield R2RDocumentProcessingError( document_id=document.id, - error_message=f"Parser for {document.type} not found in `R2RIngestionProvider`.", + error_message=f"Parser for {document.document_type} not found in `R2RIngestionProvider`.", ) else: t0 = time.time() @@ -190,13 +190,13 @@ async def parse( # type: ignore parser_overrides = ingestion_config_override.get( "parser_overrides", {} ) - if document.type.value in parser_overrides: + if document.document_type.value in parser_overrides: logger.info( - f"Using parser_override for {document.type} with input value {parser_overrides[document.type.value]}" + f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}" ) # TODO - Cleanup this approach to be less hardcoded if ( - document.type != DocumentType.PDF + document.document_type != DocumentType.PDF or parser_overrides[DocumentType.PDF.value] != "zerox" ): raise ValueError( @@ -207,7 +207,7 @@ async def parse( # type: ignore ].ingest(file_content, **ingestion_config_override): contents += text + "\n" else: - async for text in self.parsers[document.type].ingest( + async for text in self.parsers[document.document_type].ingest( file_content, **ingestion_config_override ): contents += text + "\n" diff --git a/py/core/providers/ingestion/unstructured/base.py b/py/core/providers/ingestion/unstructured/base.py index 228ae5df5..a1d57af9e 100644 --- 
a/py/core/providers/ingestion/unstructured/base.py +++ b/py/core/providers/ingestion/unstructured/base.py @@ -215,9 +215,9 @@ async def parse( # TODO - Cleanup this approach to be less hardcoded # TODO - Remove code duplication between Unstructured & R2R - if document.type.value in parser_overrides: + if document.document_type.value in parser_overrides: logger.info( - f"Using parser_override for {document.type} with input value {parser_overrides[document.type.value]}" + f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}" ) async for element in self.parse_fallback( file_content, @@ -226,19 +226,19 @@ async def parse( ): elements.append(element) - elif document.type in self.R2R_FALLBACK_PARSERS.keys(): + elif document.document_type in self.R2R_FALLBACK_PARSERS.keys(): logger.info( - f"Parsing {document.type}: {document.id} with fallback parser" + f"Parsing {document.document_type}: {document.id} with fallback parser" ) async for element in self.parse_fallback( file_content, ingestion_config=ingestion_config, - parser_name=document.type, + parser_name=document.document_type, ): elements.append(element) else: logger.info( - f"Parsing {document.type}: {document.id} with unstructured" + f"Parsing {document.document_type}: {document.id} with unstructured" ) if isinstance(file_content, bytes): file_content = BytesIO(file_content) # type: ignore diff --git a/py/core/providers/logging/__init__.py b/py/core/providers/logging/__init__.py new file mode 100644 index 000000000..7f18514b5 --- /dev/null +++ b/py/core/providers/logging/__init__.py @@ -0,0 +1,3 @@ +from .r2r_logging import SqlitePersistentLoggingProvider + +__all_ = ["SqlitePersistentLoggingProvider"] diff --git a/py/core/providers/logging/r2r_logging.py b/py/core/providers/logging/r2r_logging.py new file mode 100644 index 000000000..e6bcac2e2 --- /dev/null +++ b/py/core/providers/logging/r2r_logging.py @@ -0,0 +1,817 @@ +import json +import os +import uuid +from datetime import datetime +from typing import Optional, Tuple, Union +from uuid import UUID + +from core.base import Message +from core.base.logging.base import RunType +from core.base.logging.r2r_logger import ( + PersistentLoggingConfig, + PersistentLoggingProvider, + RunInfoLog, + logger, +) + + +class SqlitePersistentLoggingProvider(PersistentLoggingProvider): + def __init__(self, config: PersistentLoggingConfig): + self.log_table = config.log_table + self.log_info_table = config.log_info_table + # TODO - Should we re-consider this naming convention? + self.project_name = os.getenv("R2R_PROJECT_NAME", "r2r_default") + self.logging_path = config.logging_path or os.getenv( + "LOCAL_DB_PATH", "local.sqlite" + ) + if not self.logging_path: + raise ValueError( + "Please set the environment variable LOCAL_DB_PATH." + ) + self.conn = None + try: + import aiosqlite + + self.aiosqlite = aiosqlite + except ImportError: + raise ImportError( + "Please install aiosqlite to use the SqlitePersistentLoggingProvider." 
+ ) + + async def initialize(self): + self.conn = await self.aiosqlite.connect(self.logging_path) + + await self.conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.project_name}_{self.log_table} ( + timestamp DATETIME, + run_id TEXT, + key TEXT, + value TEXT + ) + """ + ) + await self.conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.project_name}_{self.log_info_table} ( + timestamp DATETIME, + run_id TEXT UNIQUE, + run_type TEXT, + user_id TEXT + ) + """ + ) + await self.conn.executescript( + """ + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + created_at REAL + ); + + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + conversation_id TEXT, + parent_id TEXT, + content TEXT, + created_at REAL, + metadata TEXT, + FOREIGN KEY (conversation_id) REFERENCES conversations(id), + FOREIGN KEY (parent_id) REFERENCES messages(id) + ); + + CREATE TABLE IF NOT EXISTS branches ( + id TEXT PRIMARY KEY, + conversation_id TEXT, + branch_point_id TEXT, + created_at REAL, + FOREIGN KEY (conversation_id) REFERENCES conversations(id), + FOREIGN KEY (branch_point_id) REFERENCES messages(id) + ); + + CREATE TABLE IF NOT EXISTS message_branches ( + message_id TEXT, + branch_id TEXT, + PRIMARY KEY (message_id, branch_id), + FOREIGN KEY (message_id) REFERENCES messages(id), + FOREIGN KEY (branch_id) REFERENCES branches(id) + ); + """ + ) + await self.conn.commit() + + async def __aenter__(self): + if self.conn is None: + await self._init() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def close(self): + if self.conn: + await self.conn.close() + self.conn = None + + async def log( + self, + run_id: UUID, + key: str, + value: str, + ): + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + await self.conn.execute( + f""" + INSERT INTO {self.project_name}_{self.log_table} (timestamp, run_id, key, value) + VALUES (datetime('now'), ?, ?, ?) + """, + (str(run_id), key, value), + ) + await self.conn.commit() + + async def info_log( + self, + run_id: UUID, + run_type: RunType, + user_id: UUID, + ): + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + await self.conn.execute( + f""" + INSERT INTO {self.project_name}_{self.log_info_table} (timestamp, run_id, run_type, user_id) + VALUES (datetime('now'), ?, ?, ?) + ON CONFLICT(run_id) DO UPDATE SET + timestamp = datetime('now'), + run_type = excluded.run_type, + user_id = excluded.user_id + """, + (str(run_id), run_type, str(user_id)), + ) + await self.conn.commit() + + async def get_info_logs( + self, + offset: int = 0, + limit: int = 100, + run_type_filter: Optional[RunType] = None, + user_ids: Optional[list[UUID]] = None, + ) -> list[RunInfoLog]: + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + cursor = await self.conn.cursor() + query = "SELECT run_id, run_type, timestamp, user_id" + query += f" FROM {self.project_name}_{self.log_info_table}" + conditions = [] + params = [] + if run_type_filter: + conditions.append("run_type = ?") + params.append(run_type_filter) + if user_ids: + conditions.append(f"user_id IN ({','.join(['?']*len(user_ids))})") + params.extend([str(user_id) for user_id in user_ids]) + if conditions: + query += " WHERE " + " AND ".join(conditions) + query += " ORDER BY timestamp DESC LIMIT ? OFFSET ?" 
+ params.extend([limit, offset]) + await cursor.execute(query, params) + rows = await cursor.fetchall() + return [ + RunInfoLog( + run_id=UUID(row[0]), + run_type=row[1], + timestamp=datetime.fromisoformat(row[2]), + user_id=UUID(row[3]), + ) + for row in rows + ] + + async def create_conversation(self) -> str: + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + conversation_id = str(uuid.uuid4()) + created_at = datetime.utcnow().timestamp() + + await self.conn.execute( + "INSERT INTO conversations (id, created_at) VALUES (?, ?)", + (conversation_id, created_at), + ) + await self.conn.commit() + return conversation_id + + async def get_conversations_overview( + self, + conversation_ids: Optional[list[UUID]] = None, + offset: int = 0, + limit: int = -1, + ) -> dict[str, Union[list[dict], int]]: + """Get an overview of conversations, optionally filtered by conversation IDs, with pagination.""" + query = """ + WITH conversation_overview AS ( + SELECT c.id, c.created_at + FROM conversations c + {where_clause} + ), + counted_overview AS ( + SELECT *, COUNT(*) OVER() AS total_entries + FROM conversation_overview + ) + SELECT * FROM counted_overview + ORDER BY created_at DESC + LIMIT ? OFFSET ? + """ + + where_clause = ( + f"WHERE c.id IN ({','.join(['?' for _ in conversation_ids])})" + if conversation_ids + else "" + ) + query = query.format(where_clause=where_clause) + + params: list = [] + if conversation_ids: + params.extend(conversation_ids) + params.extend((limit if limit != -1 else -1, offset)) + + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + async with self.conn.execute(query, params) as cursor: + results = await cursor.fetchall() + + if not results: + logger.info("No conversations found.") + return {"results": [], "total_entries": 0} + + conversations = [ + { + "conversation_id": row[0], + "created_at": row[1], + } + for row in results + ] + + total_entries = results[0][-1] if results else 0 + + return {"results": conversations, "total_entries": total_entries} + + async def add_message( + self, + conversation_id: str, + content: Message, + parent_id: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> str: + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + message_id = str(uuid.uuid4()) + created_at = datetime.utcnow().timestamp() + + await self.conn.execute( + "INSERT INTO messages (id, conversation_id, parent_id, content, created_at, metadata) VALUES (?, ?, ?, ?, ?, ?)", + ( + message_id, + conversation_id, + parent_id, + content.json(), + created_at, + json.dumps(metadata or {}), + ), + ) + + if parent_id is not None: + await self.conn.execute( + """ + INSERT INTO message_branches (message_id, branch_id) + SELECT ?, branch_id FROM message_branches WHERE message_id = ? + """, + (message_id, parent_id), + ) + else: + # For messages with no parent, use the most recent branch, or create a new one + async with self.conn.execute( + """ + SELECT id FROM branches + WHERE conversation_id = ? 
+ ORDER BY created_at DESC + LIMIT 1 + """, + (conversation_id,), + ) as cursor: + row = await cursor.fetchone() + if row is not None: + branch_id = row[0] + else: + # Create a new branch if none exists + branch_id = str(uuid.uuid4()) + await self.conn.execute( + """ + INSERT INTO branches (id, conversation_id, branch_point_id) VALUES (?, ?, NULL) + """, + (branch_id, conversation_id), + ) + await self.conn.execute( + """ + INSERT INTO message_branches (message_id, branch_id) VALUES (?, ?) + """, + (message_id, branch_id), + ) + + await self.conn.commit() + return message_id + + async def edit_message( + self, message_id: str, new_content: str + ) -> Tuple[str, str]: + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + # Get the original message details + async with self.conn.execute( + "SELECT conversation_id, parent_id, content FROM messages WHERE id = ?", + (message_id,), + ) as cursor: + row = await cursor.fetchone() + if row is None: + raise ValueError(f"Message {message_id} not found") + conversation_id, parent_id, old_content_json = row + # Parse the old content to get the original Message object + old_message = Message.parse_raw(old_content_json) + + # Create a new Message object with the updated content + edited_message = Message( + role=old_message.role, + content=new_content, + name=old_message.name, + function_call=old_message.function_call, + tool_calls=old_message.tool_calls, + ) + + # Create a new branch + new_branch_id = str(uuid.uuid4()) + created_at = datetime.utcnow().timestamp() + await self.conn.execute( + "INSERT INTO branches (id, conversation_id, branch_point_id, created_at) VALUES (?, ?, ?, ?)", + (new_branch_id, conversation_id, message_id, created_at), + ) + + # Add the edited message with the same parent_id + new_message_id = str(uuid.uuid4()) + message_created_at = datetime.utcnow().timestamp() + await self.conn.execute( + "INSERT INTO messages (id, conversation_id, parent_id, content, created_at, metadata) VALUES (?, ?, ?, ?, ?, ?)", + ( + new_message_id, + conversation_id, + parent_id, + edited_message.json(), + message_created_at, + json.dumps({"edited": True}), + ), + ) + # Link the new message to the new branch + await self.conn.execute( + "INSERT INTO message_branches (message_id, branch_id) VALUES (?, ?)", + (new_message_id, new_branch_id), + ) + + # Link ancestor messages (excluding the original message) to the new branch + await self.conn.execute( + """ + WITH RECURSIVE ancestors(id) AS ( + SELECT parent_id FROM messages WHERE id = ? + UNION ALL + SELECT m.parent_id FROM messages m JOIN ancestors a ON m.id = a.id WHERE m.parent_id IS NOT NULL + ) + INSERT OR IGNORE INTO message_branches (message_id, branch_id) + SELECT id, ? FROM ancestors WHERE id IS NOT NULL + """, + (message_id, new_branch_id), + ) + + # Update the parent_id of the edited message's descendants in the new branch + await self.conn.execute( + """ + WITH RECURSIVE descendants(id) AS ( + SELECT id FROM messages WHERE parent_id = ? + UNION ALL + SELECT m.id FROM messages m JOIN descendants d ON m.parent_id = d.id + ) + UPDATE messages + SET parent_id = ? + WHERE id IN (SELECT id FROM descendants) + """, + (message_id, new_message_id), + ) + + await self.conn.commit() + return new_message_id, new_branch_id + + async def get_conversation( + self, conversation_id: str, branch_id: Optional[str] = None + ) -> Tuple[str, list[Message]]: + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." 
+ ) + + if branch_id is None: + # Get the most recent branch by created_at timestamp + async with self.conn.execute( + """ + SELECT id FROM branches + WHERE conversation_id = ? + ORDER BY created_at DESC + LIMIT 1 + """, + (conversation_id,), + ) as cursor: + row = await cursor.fetchone() + branch_id = row[0] if row else None + + if branch_id is None: + return [] # No branches found for the conversation + + # Get all messages for this branch + async with self.conn.execute( + """ + WITH RECURSIVE branch_messages(id, content, parent_id, depth, created_at) AS ( + SELECT m.id, m.content, m.parent_id, 0, m.created_at + FROM messages m + JOIN message_branches mb ON m.id = mb.message_id + WHERE mb.branch_id = ? AND m.parent_id IS NULL + UNION + SELECT m.id, m.content, m.parent_id, bm.depth + 1, m.created_at + FROM messages m + JOIN message_branches mb ON m.id = mb.message_id + JOIN branch_messages bm ON m.parent_id = bm.id + WHERE mb.branch_id = ? + ) + SELECT id, content, parent_id FROM branch_messages + ORDER BY created_at ASC + """, + (branch_id, branch_id), + ) as cursor: + rows = await cursor.fetchall() + return [(row[0], Message.parse_raw(row[1])) for row in rows] + + async def get_branches_overview(self, conversation_id: str) -> list[dict]: + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + async with self.conn.execute( + """ + SELECT b.id, b.branch_point_id, m.content, b.created_at + FROM branches b + LEFT JOIN messages m ON b.branch_point_id = m.id + WHERE b.conversation_id = ? + ORDER BY b.created_at + """, + (conversation_id,), + ) as cursor: + rows = await cursor.fetchall() + return [ + { + "branch_id": row[0], + "branch_point_id": row[1], + "content": row[2], + "created_at": row[3], + } + for row in rows + ] + + async def get_next_branch(self, current_branch_id: str) -> Optional[str]: + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + async with self.conn.execute( + """ + SELECT id FROM branches + WHERE conversation_id = (SELECT conversation_id FROM branches WHERE id = ?) + AND created_at > (SELECT created_at FROM branches WHERE id = ?) + ORDER BY created_at + LIMIT 1 + """, + (current_branch_id, current_branch_id), + ) as cursor: + row = await cursor.fetchone() + return row[0] if row else None + + async def get_prev_branch(self, current_branch_id: str) -> Optional[str]: + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + async with self.conn.execute( + """ + SELECT id FROM branches + WHERE conversation_id = (SELECT conversation_id FROM branches WHERE id = ?) + AND created_at < (SELECT created_at FROM branches WHERE id = ?) + ORDER BY created_at DESC + LIMIT 1 + """, + (current_branch_id, current_branch_id), + ) as cursor: + row = await cursor.fetchone() + return row[0] if row else None + + async def branch_at_message(self, message_id: str) -> str: + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." 
+ ) + + # Get the conversation_id of the message + async with self.conn.execute( + "SELECT conversation_id FROM messages WHERE id = ?", + (message_id,), + ) as cursor: + row = await cursor.fetchone() + if row is None: + raise ValueError(f"Message {message_id} not found") + conversation_id = row[0] + + # Check if the message is already a branch point + async with self.conn.execute( + "SELECT id FROM branches WHERE branch_point_id = ?", + (message_id,), + ) as cursor: + row = await cursor.fetchone() + if row is not None: + return row[0] # Return the existing branch ID + + # Create a new branch starting from message_id + new_branch_id = str(uuid.uuid4()) + await self.conn.execute( + "INSERT INTO branches (id, conversation_id, branch_point_id) VALUES (?, ?, ?)", + (new_branch_id, conversation_id, message_id), + ) + + # Link ancestor messages to the new branch + await self.conn.execute( + """ + WITH RECURSIVE ancestors(id) AS ( + SELECT id FROM messages WHERE id = ? + UNION ALL + SELECT m.parent_id FROM messages m JOIN ancestors a ON m.id = a.id WHERE m.parent_id IS NOT NULL + ) + INSERT OR IGNORE INTO message_branches (message_id, branch_id) + SELECT id, ? FROM ancestors + """, + (message_id, new_branch_id), + ) + + await self.conn.commit() + return new_branch_id + + async def delete_conversation(self, conversation_id: str): + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + # Begin a transaction + async with self.conn.execute("BEGIN TRANSACTION"): + # Delete all message branches associated with the conversation + await self.conn.execute( + "DELETE FROM message_branches WHERE message_id IN (SELECT id FROM messages WHERE conversation_id = ?)", + (conversation_id,), + ) + # Delete all branches associated with the conversation + await self.conn.execute( + "DELETE FROM branches WHERE conversation_id = ?", + (conversation_id,), + ) + # Delete all messages associated with the conversation + await self.conn.execute( + "DELETE FROM messages WHERE conversation_id = ?", + (conversation_id,), + ) + # Finally, delete the conversation itself + await self.conn.execute( + "DELETE FROM conversations WHERE id = ?", (conversation_id,) + ) + # Commit the transaction + await self.conn.commit() + + async def get_logs( + self, + run_ids: list[UUID], + limit_per_run: int = 10, + ) -> list: + if not run_ids: + raise ValueError("No run ids provided.") + if not self.conn: + raise ValueError( + "Initialize the connection pool before attempting to log." + ) + + cursor = await self.conn.cursor() + placeholders = ",".join(["?" 
for _ in run_ids]) + query = f""" + SELECT run_id, key, value, timestamp + FROM {self.project_name}_{self.log_table} + WHERE run_id IN ({placeholders}) + ORDER BY timestamp DESC + """ + + params = [str(run_id) for run_id in run_ids] + + await cursor.execute(query, params) + rows = await cursor.fetchall() + + # Post-process the results to limit per run_id and ensure only requested run_ids are included + result = [] + run_id_count = {str(run_id): 0 for run_id in run_ids} + for row in rows: + row_dict = dict(zip([d[0] for d in cursor.description], row)) + row_run_id = row_dict["run_id"] + if ( + row_run_id in run_id_count + and run_id_count[row_run_id] < limit_per_run + ): + row_dict["run_id"] = UUID(row_dict["run_id"]) + result.append(row_dict) + run_id_count[row_run_id] += 1 + return result + + +# class SqlitePersistentLoggingProvider: +# _instance = None +# _is_configured = False +# _config: Optional[PersistentLoggingConfig] = None + +# PERSISTENT_PROVIDERS = { +# "r2r": SqlitePersistentLoggingProvider, +# # TODO - Mark this as deprecated +# "local": SqlitePersistentLoggingProvider, +# } + +# @classmethod +# def get_persistent_logger(cls): +# return cls.PERSISTENT_PROVIDERS[cls._config.provider](cls._config) + +# @classmethod +# def configure(cls, logging_config: PersistentLoggingConfig): +# if logging_config.provider == "local": +# logger.warning( +# "Local logging provider is deprecated. Please use 'r2r' instead." +# ) +# if not cls._is_configured: +# cls._config = logging_config +# cls._is_configured = True +# else: +# raise Exception("SqlitePersistentLoggingProvider is already configured.") + +# @classmethod +# async def log( +# cls, +# run_id: UUID, +# key: str, +# value: str, +# ): +# try: +# async with cls.get_persistent_logger() as provider: +# await provider.log(run_id, key, value) +# except Exception as e: +# logger.error(f"Error logging data {(run_id, key, value)}: {e}") + +# @classmethod +# async def info_log( +# cls, +# run_id: UUID, +# run_type: RunType, +# user_id: UUID, +# ): +# try: +# async with cls.get_persistent_logger() as provider: +# await provider.info_log(run_id, run_type, user_id) +# except Exception as e: +# logger.error( +# f"Error logging info data {(run_id, run_type, user_id)}: {e}" +# ) + +# @classmethod +# async def get_info_logs( +# cls, +# offset: int = 0, +# limit: int = 100, +# run_type_filter: Optional[RunType] = None, +# user_ids: Optional[list[UUID]] = None, +# ) -> list[RunInfoLog]: +# async with cls.get_persistent_logger() as provider: +# return await provider.get_info_logs( +# offset=offset, +# limit=limit, +# run_type_filter=run_type_filter, +# user_ids=user_ids, +# ) + +# @classmethod +# async def get_logs( +# cls, +# run_ids: list[UUID], +# limit_per_run: int = 10, +# ) -> list: +# async with cls.get_persistent_logger() as provider: +# return await provider.get_logs(run_ids, limit_per_run) + +# @classmethod +# async def create_conversation(cls) -> str: +# async with cls.get_persistent_logger() as provider: +# return await provider.create_conversation() + +# @classmethod +# async def get_conversations_overview( +# cls, +# conversation_ids: Optional[list[UUID]] = None, +# offset: int = 0, +# limit: int = 100, +# ) -> list[dict]: +# async with cls.get_persistent_logger() as provider: +# return await provider.get_conversations_overview( +# conversation_ids=conversation_ids, +# offset=offset, +# limit=limit, +# ) + +# @classmethod +# async def add_message( +# cls, +# conversation_id: str, +# content: Message, +# parent_id: Optional[str] = None, +# 
metadata: Optional[dict] = None, +# ) -> str: +# async with cls.get_persistent_logger() as provider: +# return await provider.add_message( +# conversation_id, content, parent_id, metadata +# ) + +# @classmethod +# async def edit_message( +# cls, message_id: str, new_content: str +# ) -> Tuple[str, str]: +# async with cls.get_persistent_logger() as provider: +# return await provider.edit_message(message_id, new_content) + +# @classmethod +# async def get_conversation( +# cls, conversation_id: str, branch_id: Optional[str] = None +# ) -> list[dict]: +# async with cls.get_persistent_logger() as provider: +# return await provider.get_conversation(conversation_id, branch_id) + +# @classmethod +# async def get_branches_overview(cls, conversation_id: str) -> list[dict]: +# async with cls.get_persistent_logger() as provider: +# return await provider.get_branches_overview(conversation_id) + +# @classmethod +# async def get_next_branch(cls, current_branch_id: str) -> Optional[str]: +# async with cls.get_persistent_logger() as provider: +# return await provider.get_next_branch(current_branch_id) + +# @classmethod +# async def get_prev_branch(cls, current_branch_id: str) -> Optional[str]: +# async with cls.get_persistent_logger() as provider: +# return await provider.get_prev_branch(current_branch_id) + +# @classmethod +# async def branch_at_message(cls, message_id: str) -> str: +# async with cls.get_persistent_logger() as provider: +# return await provider.branch_at_message(message_id) + +# @classmethod +# async def delete_conversation(cls, conversation_id: str): +# async with cls.get_persistent_logger() as provider: +# await provider.delete_conversation(conversation_id) + +# @classmethod +# async def close(cls): +# async with cls.get_persistent_logger() as provider: +# await provider.close() diff --git a/py/shared/abstractions/document.py b/py/shared/abstractions/document.py index 2494d04f5..1558cf174 100644 --- a/py/shared/abstractions/document.py +++ b/py/shared/abstractions/document.py @@ -95,7 +95,7 @@ class Document(R2RSerializable): id: UUID = Field(default_factory=uuid4) collection_ids: list[UUID] user_id: UUID - type: DocumentType + document_type: DocumentType metadata: dict class Config: @@ -152,7 +152,7 @@ class DocumentInfo(R2RSerializable): id: UUID collection_ids: list[UUID] user_id: UUID - type: DocumentType + document_type: DocumentType metadata: dict title: Optional[str] = None version: str @@ -171,7 +171,7 @@ def convert_to_db_entry(self): "document_id": self.id, "collection_ids": self.collection_ids, "user_id": self.user_id, - "type": self.type, + "type": self.document_type, "metadata": json.dumps(self.metadata), "title": self.title or "N/A", "version": self.version, diff --git a/py/tests/conftest.py b/py/tests/conftest.py index a72cca8b5..1444081e6 100644 --- a/py/tests/conftest.py +++ b/py/tests/conftest.py @@ -13,7 +13,7 @@ CompletionConfig, DatabaseConfig, EmbeddingConfig, - LoggingConfig, + PersistentLoggingConfig, SqlitePersistentLoggingProvider, Vector, VectorEntry, @@ -114,7 +114,7 @@ async def postgres_db_provider( id=UUID("9fbe403b-c11c-5aae-8ade-ef22980c3ad1"), collection_ids=[UUID("122fdf6a-e116-546b-a8f6-e4cb2e2c0a09")], user_id=UUID("00000000-0000-0000-0000-000000000003"), - type=DocumentType.PDF, + document_type=DocumentType.PDF, metadata={}, title="Test Document for KG", version="1.0", @@ -214,7 +214,7 @@ async def local_logging_provider(app_config): unique_id = str(uuid.uuid4()) logging_path = f"test_{unique_id}.sqlite" provider = SqlitePersistentLoggingProvider( - 
LoggingConfig(logging_path=logging_path, app=app_config) + PersistentLoggingConfig(logging_path=logging_path, app=app_config) ) await provider._init() yield provider diff --git a/py/tests/core/pipelines/test_pipeline_logic.py b/py/tests/core/pipelines/test_pipeline_logic.py index 2268d7b27..91c78fd36 100644 --- a/py/tests/core/pipelines/test_pipeline_logic.py +++ b/py/tests/core/pipelines/test_pipeline_logic.py @@ -3,13 +3,12 @@ import pytest -from core import AsyncPipe, AsyncPipeline, PipeType +from core import AsyncPipe, AsyncPipeline class MultiplierPipe(AsyncPipe): def __init__(self, multiplier=1, delay=0, name="multiplier_pipe"): super().__init__( - type=PipeType.OTHER, config=self.PipeConfig(name=name), ) self.multiplier = multiplier @@ -38,7 +37,6 @@ async def _run_logic( class FanOutPipe(AsyncPipe): def __init__(self, multiplier=1, delay=0, name="fan_out_pipe"): super().__init__( - type=PipeType.OTHER, config=self.PipeConfig(name=name), ) self.multiplier = multiplier @@ -64,7 +62,6 @@ async def _run_logic( class FanInPipe(AsyncPipe): def __init__(self, delay=0, name="fan_in_pipe"): super().__init__( - type=PipeType.OTHER, config=self.PipeConfig(name=name), ) self.delay = delay diff --git a/py/tests/core/providers/database/relational/test_collection_db.py b/py/tests/core/providers/database/relational/test_collection_db.py index 7eef24edb..924c46d49 100644 --- a/py/tests/core/providers/database/relational/test_collection_db.py +++ b/py/tests/core/providers/database/relational/test_collection_db.py @@ -113,7 +113,7 @@ async def test_assign_and_remove_document_from_collection( id=document_id, collection_ids=[], user_id=UUID("00000000-0000-0000-0000-000000000002"), - type=DocumentType.PDF, + document_type=DocumentType.PDF, metadata={}, version="v1", size_in_bytes=0, diff --git a/py/tests/core/providers/database/relational/test_document_db.py b/py/tests/core/providers/database/relational/test_document_db.py index 4c7eef705..89f5822ca 100644 --- a/py/tests/core/providers/database/relational/test_document_db.py +++ b/py/tests/core/providers/database/relational/test_document_db.py @@ -24,7 +24,7 @@ async def test_upsert_documents_overview(temporary_postgres_db_provider): id=UUID("00000000-0000-0000-0000-000000000001"), collection_ids=[UUID("00000000-0000-0000-0000-000000000002")], user_id=UUID("00000000-0000-0000-0000-000000000003"), - type=DocumentType.PDF, + document_type=DocumentType.PDF, metadata={}, title="Test Document", version="1.0", @@ -45,7 +45,7 @@ async def test_upsert_documents_overview(temporary_postgres_db_provider): assert inserted_document.id == document_info.id assert inserted_document.collection_ids == document_info.collection_ids assert inserted_document.user_id == document_info.user_id - assert inserted_document.type == document_info.type + assert inserted_document.type == document_info.document_type assert inserted_document.metadata == document_info.metadata assert inserted_document.title == document_info.title assert inserted_document.version == document_info.version @@ -78,7 +78,7 @@ async def test_delete_from_documents_overview(temporary_postgres_db_provider): id=UUID("00000000-0000-0000-0000-000000000001"), collection_ids=[UUID("00000000-0000-0000-0000-000000000002")], user_id=UUID("00000000-0000-0000-0000-000000000003"), - type=DocumentType.PDF, + document_type=DocumentType.PDF, metadata={}, title="Test Document", version="1.0", @@ -107,7 +107,7 @@ async def test_get_documents_overview(temporary_postgres_db_provider): 
id=UUID("00000000-0000-0000-0000-000000000001"), collection_ids=[UUID("00000000-0000-0000-0000-000000000002")], user_id=UUID("00000000-0000-0000-0000-000000000003"), - type=DocumentType.PDF, + document_type=DocumentType.PDF, metadata={}, title="Test Document 1", version="1.0", @@ -119,7 +119,7 @@ async def test_get_documents_overview(temporary_postgres_db_provider): id=UUID("00000000-0000-0000-0000-000000000004"), collection_ids=[UUID("00000000-0000-0000-0000-000000000002")], user_id=UUID("00000000-0000-0000-0000-000000000003"), - type=DocumentType.DOCX, + document_type=DocumentType.DOCX, metadata={}, title="Test Document 2", version="1.0", diff --git a/py/tests/core/providers/logging/test_logging_provider.py b/py/tests/core/providers/logging/test_logging_provider.py index dbd30fffa..5e2cee8c6 100644 --- a/py/tests/core/providers/logging/test_logging_provider.py +++ b/py/tests/core/providers/logging/test_logging_provider.py @@ -6,7 +6,7 @@ import pytest from core import ( - LoggingConfig, + PersistentLoggingConfig, SqlitePersistentLoggingProvider, generate_run_id, )