From 9b25a5e6445b363569ca3f30c11b732d94b63ff8 Mon Sep 17 00:00:00 2001 From: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com> Date: Thu, 31 Oct 2024 13:33:53 -0700 Subject: [PATCH] Implement Permissions for Conversations (#1545) * Implement permissions for conversations * make single a list --- py/core/main/api/management_router.py | 51 ++++++++- py/core/main/services/management_service.py | 17 ++- py/core/providers/logger/r2r_logger.py | 113 +++++++++++++++---- py/shared/api/models/management/responses.py | 2 + 4 files changed, 158 insertions(+), 25 deletions(-) diff --git a/py/core/main/api/management_router.py b/py/core/main/api/management_router.py index 628cf3457..55e3c14fd 100644 --- a/py/core/main/api/management_router.py +++ b/py/core/main/api/management_router.py @@ -769,15 +769,28 @@ async def documents_in_collection_app( @self.base_endpoint async def conversations_overview_app( conversation_ids: list[str] = Query([]), + user_ids: Optional[list[str]] = Query(None), offset: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), + limit: int = Query(100, ge=-1, le=1000), auth_user=Depends(self.service.providers.auth.auth_wrapper), ) -> WrappedConversationsOverviewResponse: conversation_uuids = [ UUID(conversation_id) for conversation_id in conversation_ids ] + + if auth_user.is_superuser: + user_ids = [UUID(uid) for uid in user_ids] if user_ids else None # type: ignore + else: + if user_ids: + raise R2RException( + message="Non-superusers cannot query other users' conversations", + status_code=403, + ) + user_ids = [auth_user.id] + conversations_overview_response = ( await self.service.conversations_overview( + user_ids=user_ids, conversation_ids=conversation_uuids, offset=offset, limit=limit, @@ -797,6 +810,17 @@ async def get_conversation( branch_id: str = Query(None, description="Branch ID"), auth_user=Depends(self.service.providers.auth.auth_wrapper), ) -> WrappedConversationResponse: + + if not auth_user.is_superuser: + has_access = await self.service.verify_conversation_access( + conversation_id, auth_user.id + ) + if not has_access: + raise R2RException( + message="You do not have access to this conversation", + status_code=403, + ) + result = await self.service.get_conversation( conversation_id, branch_id, @@ -808,7 +832,9 @@ async def get_conversation( async def create_conversation( auth_user=Depends(self.service.providers.auth.auth_wrapper), ) -> dict: - return await self.service.create_conversation() + return await self.service.create_conversation( + user_id=auth_user.id if auth_user else None + ) @self.router.post("/add_message/{conversation_id}") @self.base_endpoint @@ -821,6 +847,16 @@ async def add_message( metadata: Optional[dict] = Body(None, description="Metadata"), auth_user=Depends(self.service.providers.auth.auth_wrapper), ) -> dict: + if not auth_user.is_superuser: + has_access = await self.service.verify_conversation_access( + conversation_id, auth_user.id + ) + if not has_access: + raise R2RException( + message="You do not have access to this conversation", + status_code=403, + ) + message_id = await self.service.add_message( conversation_id, message, parent_id, metadata ) @@ -833,6 +869,8 @@ async def edit_message( message: str = Body(..., description="New content"), auth_user=Depends(self.service.providers.auth.auth_wrapper), ) -> dict: + # TODO: Add a check to see if the user has access to the message + new_message_id, new_branch_id = await self.service.edit_message( message_id, message ) @@ -847,6 +885,15 @@ async def branches_overview( conversation_id: str = Path(..., description="Conversation ID"), auth_user=Depends(self.service.providers.auth.auth_wrapper), ) -> dict: + if not auth_user.is_superuser: + has_access = await self.service.verify_conversation_access( + conversation_id, auth_user.id + ) + if not has_access: + raise R2RException( + message="You do not have access to this conversation's branches", + status_code=403, + ) branches = await self.service.branches_overview(conversation_id) return {"branches": branches} diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index 0a01250aa..9f12f97a7 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -674,20 +674,33 @@ async def get_conversation( conversation_id, branch_id ) + async def verify_conversation_access( + self, conversation_id: str, user_id: UUID + ) -> bool: + return await self.logging_connection.verify_conversation_access( + conversation_id, user_id + ) + @telemetry_event("CreateConversation") - async def create_conversation(self, auth_user=None) -> str: - return await self.logging_connection.create_conversation() + async def create_conversation( + self, user_id: Optional[UUID] = None, auth_user=None + ) -> str: + return await self.logging_connection.create_conversation( + user_id=user_id + ) @telemetry_event("ConversationsOverview") async def conversations_overview( self, conversation_ids: Optional[list[UUID]] = None, + user_ids: Optional[UUID | list[UUID]] = None, offset: int = 0, limit: int = 100, auth_user=None, ) -> dict[str, Union[list[dict], int]]: return await self.logging_connection.get_conversations_overview( conversation_ids=conversation_ids, + user_ids=user_ids, offset=offset, limit=limit, ) diff --git a/py/core/providers/logger/r2r_logger.py b/py/core/providers/logger/r2r_logger.py index 637636366..f24343ac4 100644 --- a/py/core/providers/logger/r2r_logger.py +++ b/py/core/providers/logger/r2r_logger.py @@ -3,7 +3,7 @@ import os import uuid from datetime import datetime -from typing import Optional, Tuple, Union +from typing import Optional, Tuple from uuid import UUID from core.base import Message @@ -14,8 +14,6 @@ RunType, ) -from ..database.postgres import PostgresDBProvider - logger = logging.getLogger() @@ -69,7 +67,9 @@ async def initialize(self): """ CREATE TABLE IF NOT EXISTS conversations ( id TEXT PRIMARY KEY, + user_id UUID, created_at REAL + name TEXT ); CREATE TABLE IF NOT EXISTS messages ( @@ -101,6 +101,24 @@ async def initialize(self): ); """ ) + + async with self.conn.execute( + "PRAGMA table_info(conversations);" + ) as cursor: + columns = await cursor.fetchall() + column_names = [col[1] for col in columns] + + # Add 'user_id' column if it doesn't exist + if "user_id" not in column_names: + await self.conn.execute( + "ALTER TABLE conversations ADD COLUMN user_id TEXT;" + ) + # Add 'name' column if it doesn't exist + if "name" not in column_names: + await self.conn.execute( + "ALTER TABLE conversations ADD COLUMN name TEXT;" + ) + await self.conn.commit() async def __aenter__(self): @@ -199,7 +217,11 @@ async def get_info_logs( for row in rows ] - async def create_conversation(self) -> str: + async def create_conversation( + self, + user_id: Optional[UUID] = None, + name: Optional[str] = None, + ) -> str: if not self.conn: raise ValueError( "Initialize the connection pool before attempting to log." @@ -209,24 +231,56 @@ async def create_conversation(self) -> str: created_at = datetime.utcnow().timestamp() await self.conn.execute( - "INSERT INTO conversations (id, created_at) VALUES (?, ?)", - (conversation_id, created_at), + """ + INSERT INTO conversations (id, user_id, created_at, name) + VALUES (?, ?, ?, ?) + """, + ( + conversation_id, + str(user_id) if user_id else None, + created_at, + name, + ), ) await self.conn.commit() return conversation_id + async def verify_conversation_access( + self, conversation_id: str, user_id: UUID + ) -> bool: + + if not self.conn: + raise ValueError("Connection pool not initialized.") + + async with self.conn.execute( + """ + SELECT 1 FROM conversations + WHERE id = ? AND (user_id IS NULL OR user_id = ?) + """, + (conversation_id, str(user_id)), + ) as cursor: + return await cursor.fetchone() is not None + async def get_conversations_overview( self, conversation_ids: Optional[list[UUID]] = None, + user_ids: Optional[UUID | list[UUID]] = None, offset: int = 0, limit: int = -1, - ) -> dict[str, Union[list[dict], int]]: - """Get an overview of conversations, optionally filtered by conversation IDs, with pagination.""" + ) -> dict[str, list[dict] | int]: + """ + Get conversations overview with pagination. + If user_ids is None, returns all conversations (superuser access) + If user_ids is a single UUID, returns conversations for that user + If user_ids is a list of UUIDs, returns conversations for those users + """ query = """ WITH conversation_overview AS ( - SELECT c.id, c.created_at + SELECT c.id, c.created_at, c.user_id, c.name FROM conversations c - {where_clause} + WHERE 1=1 + {user_where_clause} + {conversation_where_clause} ), counted_overview AS ( SELECT *, COUNT(*) OVER() AS total_entries @@ -237,34 +291,51 @@ async def get_conversations_overview( LIMIT ? OFFSET ? """ - where_clause = ( - f"WHERE c.id IN ({','.join(['?' for _ in conversation_ids])})" - if conversation_ids - else "" - ) - query = query.format(where_clause=where_clause) + params = [] + + if user_ids is None: + user_where_clause = "" + elif isinstance(user_ids, UUID): + user_where_clause = "AND c.user_id = ?" + params.append(str(user_ids)) + else: + user_where_clause = ( + f"AND c.user_id IN ({','.join(['?' for _ in user_ids])})" + ) + params.extend(str(uid) for uid in user_ids) - params: list = [] if conversation_ids: - params.extend(conversation_ids) - params.extend((limit if limit != -1 else -1, offset)) + conversation_where_clause = ( + f"AND c.id IN ({','.join(['?' for _ in conversation_ids])})" + ) + params.extend(str(cid) for cid in conversation_ids) + else: + conversation_where_clause = "" + + params.extend([str(limit) if limit != -1 else "-1", str(offset)]) + + query = query.format( + user_where_clause=user_where_clause, + conversation_where_clause=conversation_where_clause, + ) if not self.conn: raise ValueError( - "Initialize the connection pool before attempting to log." + "Initialize the connection pool before attempting to query." ) async with self.conn.execute(query, params) as cursor: results = await cursor.fetchall() if not results: - logger.info("No conversations found.") return {"results": [], "total_entries": 0} conversations = [ { "conversation_id": row[0], "created_at": row[1], + "user_id": UUID(row[2]) if row[2] else None, + "name": row[3] or None, } for row in results ] diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py index e8c0bd8e0..0bbbbe5db 100644 --- a/py/shared/api/models/management/responses.py +++ b/py/shared/api/models/management/responses.py @@ -131,6 +131,8 @@ class CollectionOverviewResponse(BaseModel): class ConversationOverviewResponse(BaseModel): conversation_id: UUID created_at: datetime + user_id: Optional[UUID] = None + name: Optional[str] = None class VerificationResult(BaseModel):