Skip to content

Commit

Permalink
Implement Permissions for Conversations (#1545)
Browse files Browse the repository at this point in the history
* Implement permissions for conversations

* make single a list
  • Loading branch information
NolanTrem authored Oct 31, 2024
1 parent 23372cd commit 9b25a5e
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 25 deletions.
51 changes: 49 additions & 2 deletions py/core/main/api/management_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,15 +769,28 @@ async def documents_in_collection_app(
@self.base_endpoint
async def conversations_overview_app(
conversation_ids: list[str] = Query([]),
user_ids: Optional[list[str]] = Query(None),
offset: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
limit: int = Query(100, ge=-1, le=1000),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedConversationsOverviewResponse:
conversation_uuids = [
UUID(conversation_id) for conversation_id in conversation_ids
]

if auth_user.is_superuser:
user_ids = [UUID(uid) for uid in user_ids] if user_ids else None # type: ignore
else:
if user_ids:
raise R2RException(
message="Non-superusers cannot query other users' conversations",
status_code=403,
)
user_ids = [auth_user.id]

conversations_overview_response = (
await self.service.conversations_overview(
user_ids=user_ids,
conversation_ids=conversation_uuids,
offset=offset,
limit=limit,
Expand All @@ -797,6 +810,17 @@ async def get_conversation(
branch_id: str = Query(None, description="Branch ID"),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedConversationResponse:

if not auth_user.is_superuser:
has_access = await self.service.verify_conversation_access(
conversation_id, auth_user.id
)
if not has_access:
raise R2RException(
message="You do not have access to this conversation",
status_code=403,
)

result = await self.service.get_conversation(
conversation_id,
branch_id,
Expand All @@ -808,7 +832,9 @@ async def get_conversation(
async def create_conversation(
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> dict:
return await self.service.create_conversation()
return await self.service.create_conversation(
user_id=auth_user.id if auth_user else None
)

@self.router.post("/add_message/{conversation_id}")
@self.base_endpoint
Expand All @@ -821,6 +847,16 @@ async def add_message(
metadata: Optional[dict] = Body(None, description="Metadata"),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> dict:
if not auth_user.is_superuser:
has_access = await self.service.verify_conversation_access(
conversation_id, auth_user.id
)
if not has_access:
raise R2RException(
message="You do not have access to this conversation",
status_code=403,
)

message_id = await self.service.add_message(
conversation_id, message, parent_id, metadata
)
Expand All @@ -833,6 +869,8 @@ async def edit_message(
message: str = Body(..., description="New content"),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> dict:
# TODO: Add a check to see if the user has access to the message

new_message_id, new_branch_id = await self.service.edit_message(
message_id, message
)
Expand All @@ -847,6 +885,15 @@ async def branches_overview(
conversation_id: str = Path(..., description="Conversation ID"),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> dict:
if not auth_user.is_superuser:
has_access = await self.service.verify_conversation_access(
conversation_id, auth_user.id
)
if not has_access:
raise R2RException(
message="You do not have access to this conversation's branches",
status_code=403,
)
branches = await self.service.branches_overview(conversation_id)
return {"branches": branches}

Expand Down
17 changes: 15 additions & 2 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,20 +674,33 @@ async def get_conversation(
conversation_id, branch_id
)

async def verify_conversation_access(
self, conversation_id: str, user_id: UUID
) -> bool:
return await self.logging_connection.verify_conversation_access(
conversation_id, user_id
)

@telemetry_event("CreateConversation")
async def create_conversation(self, auth_user=None) -> str:
return await self.logging_connection.create_conversation()
async def create_conversation(
self, user_id: Optional[UUID] = None, auth_user=None
) -> str:
return await self.logging_connection.create_conversation(
user_id=user_id
)

@telemetry_event("ConversationsOverview")
async def conversations_overview(
self,
conversation_ids: Optional[list[UUID]] = None,
user_ids: Optional[UUID | list[UUID]] = None,
offset: int = 0,
limit: int = 100,
auth_user=None,
) -> dict[str, Union[list[dict], int]]:
return await self.logging_connection.get_conversations_overview(
conversation_ids=conversation_ids,
user_ids=user_ids,
offset=offset,
limit=limit,
)
Expand Down
113 changes: 92 additions & 21 deletions py/core/providers/logger/r2r_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import uuid
from datetime import datetime
from typing import Optional, Tuple, Union
from typing import Optional, Tuple
from uuid import UUID

from core.base import Message
Expand All @@ -14,8 +14,6 @@
RunType,
)

from ..database.postgres import PostgresDBProvider

logger = logging.getLogger()


Expand Down Expand Up @@ -69,7 +67,9 @@ async def initialize(self):
"""
CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
user_id UUID,
created_at REAL
name TEXT
);
CREATE TABLE IF NOT EXISTS messages (
Expand Down Expand Up @@ -101,6 +101,24 @@ async def initialize(self):
);
"""
)

async with self.conn.execute(
"PRAGMA table_info(conversations);"
) as cursor:
columns = await cursor.fetchall()
column_names = [col[1] for col in columns]

# Add 'user_id' column if it doesn't exist
if "user_id" not in column_names:
await self.conn.execute(
"ALTER TABLE conversations ADD COLUMN user_id TEXT;"
)
# Add 'name' column if it doesn't exist
if "name" not in column_names:
await self.conn.execute(
"ALTER TABLE conversations ADD COLUMN name TEXT;"
)

await self.conn.commit()

async def __aenter__(self):
Expand Down Expand Up @@ -199,7 +217,11 @@ async def get_info_logs(
for row in rows
]

async def create_conversation(self) -> str:
async def create_conversation(
self,
user_id: Optional[UUID] = None,
name: Optional[str] = None,
) -> str:
if not self.conn:
raise ValueError(
"Initialize the connection pool before attempting to log."
Expand All @@ -209,24 +231,56 @@ async def create_conversation(self) -> str:
created_at = datetime.utcnow().timestamp()

await self.conn.execute(
"INSERT INTO conversations (id, created_at) VALUES (?, ?)",
(conversation_id, created_at),
"""
INSERT INTO conversations (id, user_id, created_at, name)
VALUES (?, ?, ?, ?)
""",
(
conversation_id,
str(user_id) if user_id else None,
created_at,
name,
),
)
await self.conn.commit()
return conversation_id

async def verify_conversation_access(
self, conversation_id: str, user_id: UUID
) -> bool:

if not self.conn:
raise ValueError("Connection pool not initialized.")

async with self.conn.execute(
"""
SELECT 1 FROM conversations
WHERE id = ? AND (user_id IS NULL OR user_id = ?)
""",
(conversation_id, str(user_id)),
) as cursor:
return await cursor.fetchone() is not None

async def get_conversations_overview(
self,
conversation_ids: Optional[list[UUID]] = None,
user_ids: Optional[UUID | list[UUID]] = None,
offset: int = 0,
limit: int = -1,
) -> dict[str, Union[list[dict], int]]:
"""Get an overview of conversations, optionally filtered by conversation IDs, with pagination."""
) -> dict[str, list[dict] | int]:
"""
Get conversations overview with pagination.
If user_ids is None, returns all conversations (superuser access)
If user_ids is a single UUID, returns conversations for that user
If user_ids is a list of UUIDs, returns conversations for those users
"""
query = """
WITH conversation_overview AS (
SELECT c.id, c.created_at
SELECT c.id, c.created_at, c.user_id, c.name
FROM conversations c
{where_clause}
WHERE 1=1
{user_where_clause}
{conversation_where_clause}
),
counted_overview AS (
SELECT *, COUNT(*) OVER() AS total_entries
Expand All @@ -237,34 +291,51 @@ async def get_conversations_overview(
LIMIT ? OFFSET ?
"""

where_clause = (
f"WHERE c.id IN ({','.join(['?' for _ in conversation_ids])})"
if conversation_ids
else ""
)
query = query.format(where_clause=where_clause)
params = []

if user_ids is None:
user_where_clause = ""
elif isinstance(user_ids, UUID):
user_where_clause = "AND c.user_id = ?"
params.append(str(user_ids))
else:
user_where_clause = (
f"AND c.user_id IN ({','.join(['?' for _ in user_ids])})"
)
params.extend(str(uid) for uid in user_ids)

params: list = []
if conversation_ids:
params.extend(conversation_ids)
params.extend((limit if limit != -1 else -1, offset))
conversation_where_clause = (
f"AND c.id IN ({','.join(['?' for _ in conversation_ids])})"
)
params.extend(str(cid) for cid in conversation_ids)
else:
conversation_where_clause = ""

params.extend([str(limit) if limit != -1 else "-1", str(offset)])

query = query.format(
user_where_clause=user_where_clause,
conversation_where_clause=conversation_where_clause,
)

if not self.conn:
raise ValueError(
"Initialize the connection pool before attempting to log."
"Initialize the connection pool before attempting to query."
)

async with self.conn.execute(query, params) as cursor:
results = await cursor.fetchall()

if not results:
logger.info("No conversations found.")
return {"results": [], "total_entries": 0}

conversations = [
{
"conversation_id": row[0],
"created_at": row[1],
"user_id": UUID(row[2]) if row[2] else None,
"name": row[3] or None,
}
for row in results
]
Expand Down
2 changes: 2 additions & 0 deletions py/shared/api/models/management/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class CollectionOverviewResponse(BaseModel):
class ConversationOverviewResponse(BaseModel):
conversation_id: UUID
created_at: datetime
user_id: Optional[UUID] = None
name: Optional[str] = None


class VerificationResult(BaseModel):
Expand Down

0 comments on commit 9b25a5e

Please sign in to comment.