Skip to content

Commit

Permalink
PATCH Message Metadata, export as CSV (#1541)
Browse files Browse the repository at this point in the history
* Message feedback WIP

* Allow export of messages to csv

* error message
  • Loading branch information
NolanTrem authored Oct 31, 2024
1 parent 9b25a5e commit 89dd3cd
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 3 deletions.
1 change: 1 addition & 0 deletions js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ let newCollectionId: string;
* - createConversation
* - addMessage
* X updateMessage
* X updateMessageMetadata
* X branchesOverview
* X getNextBranch
* X getPreviousBranch
Expand Down
1 change: 1 addition & 0 deletions js/sdk/__tests__/r2rClientIntegrationUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const baseUrl = "http://localhost:7272";
* - createConversation
* - addMessage
* X updateMessage
* X updateMessageMetadata
* X branchesOverview
* X getNextBranch
* X getPreviousBranch
Expand Down
2 changes: 1 addition & 1 deletion js/sdk/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion js/sdk/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "r2r-js",
"version": "0.3.12",
"version": "0.3.13",
"description": "",
"main": "dist/index.js",
"browser": "dist/index.browser.js",
Expand Down
17 changes: 17 additions & 0 deletions js/sdk/src/r2rClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,23 @@ export class r2rClient {
});
}

/**
 * Update the metadata attached to a message of an existing conversation.
 *
 * Verifies that the client is authenticated, then sends a `PATCH` request
 * to `messages/{message_id}/metadata` carrying the metadata as its data.
 * @param message_id The ID of the message whose metadata is updated.
 * @param metadata The metadata key/value pairs to send.
 * @returns A promise that resolves to the response from the server.
 */
@feature("updateMessageMetadata")
async updateMessageMetadata(
  message_id: string,
  metadata: Record<string, any>,
): Promise<Record<string, any>> {
  this._ensureAuthenticated();

  const endpoint = `messages/${message_id}/metadata`;
  const options = { data: metadata };
  return this._makeRequest("PATCH", endpoint, options);
}

/**
* Get an overview of branches in a conversation.
* @param conversationId The ID of the conversation to get branches for.
Expand Down
31 changes: 30 additions & 1 deletion py/core/main/api/management_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import mimetypes
from datetime import datetime, timezone
from typing import Any, Optional, Set, Union
from typing import Optional, Set, Union
from uuid import UUID

import psutil
Expand Down Expand Up @@ -879,6 +879,35 @@ async def edit_message(
"new_branch_id": new_branch_id,
}

@self.router.patch("/messages/{message_id}/metadata")
@self.base_endpoint
async def update_message_metadata(
    message_id: str = Path(..., description="Message ID"),
    metadata: dict = Body(..., description="Metadata to update"),
    auth_user=Depends(self.service.providers.auth.auth_wrapper),
):
    """Update the metadata stored for one message.

    The request body is merged into the message's existing metadata:
    keys that are new get added, keys that already exist get
    overwritten with the supplied values.
    """
    management_service = self.service
    await management_service.update_message_metadata(message_id, metadata)
    return "ok"

@self.router.get("/export/messages")
@self.base_endpoint
async def export_messages(
    auth_user=Depends(self.service.providers.auth.auth_wrapper),
):
    """Export the stored messages as a streamed CSV response.

    Only superusers may call this endpoint; any other authenticated
    caller is rejected with a 403 error.
    """
    if not auth_user.is_superuser:
        # The guard tests `is_superuser`, so the message names that
        # requirement instead of the vaguer "authorized user".
        raise R2RException(
            "Only a superuser can call the `export/messages` endpoint.",
            403,
        )
    return await self.service.export_messages_to_csv(
        return_type="stream"
    )

@self.router.get("/branches_overview/{conversation_id}")
@self.base_endpoint
async def branches_overview(
Expand Down
17 changes: 17 additions & 0 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
from uuid import UUID
from fastapi.responses import StreamingResponse

import toml

Expand Down Expand Up @@ -726,6 +727,22 @@ async def edit_message(
message_id, new_content
)

@telemetry_event("updateMessageMetadata")
async def update_message_metadata(
    self, message_id: str, metadata: dict, auth_user=None
):
    """Forward a message-metadata update to the logging connection.

    Args:
        message_id: ID of the message whose metadata is updated.
        metadata: Metadata handed through to the logging connection.
        auth_user: Not used by this method.
    """
    logging_connection = self.logging_connection
    await logging_connection.update_message_metadata(message_id, metadata)

@telemetry_event("exportMessagesToCSV")
async def export_messages_to_csv(
    self, chunk_size: int = 1000, return_type: str = "stream"
) -> Union[StreamingResponse, str]:
    """Delegate the CSV export of messages to the logging connection.

    Args:
        chunk_size: Passed through unchanged to the logging connection.
        return_type: Passed through unchanged to the logging connection.

    Returns:
        Whatever the logging connection's export returns.
    """
    logging_connection = self.logging_connection
    exported = await logging_connection.export_messages_to_csv(
        chunk_size, return_type
    )
    return exported

@telemetry_event("BranchesOverview")
async def branches_overview(
self, conversation_id: str, auth_user=None
Expand Down
125 changes: 125 additions & 0 deletions py/core/providers/logger/r2r_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import csv
import io
from datetime import datetime
from typing import Optional, Tuple, Union
from uuid import UUID

from fastapi.responses import StreamingResponse

from core.base import Message
from core.base.logger.base import (
Expand Down Expand Up @@ -501,6 +504,128 @@ async def edit_message(
await self.conn.commit()
return new_message_id, new_branch_id

async def update_message_metadata(
    self, message_id: str, metadata: dict
) -> None:
    """Merge ``metadata`` into the metadata stored for a message.

    The stored JSON metadata is read, overlaid with the supplied keys
    (supplied values win on conflict) and written back inside one
    transaction.

    Args:
        message_id: ID of the message row to update.
        metadata: Keys and values to add to or overwrite in the stored
            metadata.

    Raises:
        ValueError: If the connection has not been initialized or no
            message with ``message_id`` exists.
    """

    if not self.conn:
        raise ValueError(
            "Initialize the connection pool before attempting to log."
        )

    try:
        await self.conn.execute("BEGIN TRANSACTION")

        cursor = await self.conn.execute(
            "SELECT metadata FROM messages WHERE id = ?",
            (message_id,),
        )
        row = await cursor.fetchone()
        if not row:
            raise ValueError(f"Message {message_id} not found")

        raw_metadata = row[0]
        # A stored JSON `null` decodes to None, which cannot be unpacked
        # with `**`; treat it like a missing value.
        current_metadata = (
            json.loads(raw_metadata) if raw_metadata else None
        ) or {}

        updated_metadata_json = json.dumps(
            {**current_metadata, **metadata}
        )

        await self.conn.execute(
            "UPDATE messages SET metadata = ? WHERE id = ?",
            (updated_metadata_json, message_id),
        )

        await self.conn.commit()

    except Exception:
        # Undo the partial transaction, then re-raise with the original
        # traceback intact.
        await self.conn.rollback()
        raise

async def export_messages_to_csv(
    self, chunk_size: int = 1000, return_type: str = "stream"
) -> Union[StreamingResponse, str]:
    """
    Export messages table to CSV format.

    The first CSV record holds the column names; each following record
    is one row of the table.

    Args:
        chunk_size: Number of rows fetched per query while streaming.
            Must be at least 1 when return_type is "stream".
        return_type: "stream" returns a StreamingResponse; any other
            value returns the whole CSV document as a string.

    Returns:
        StreamingResponse or string depending on return_type

    Raises:
        ValueError: If the connection has not been initialized, or if
            streaming is requested with a chunk_size smaller than 1.
    """
    if not self.conn:
        raise ValueError(
            "Initialize the connection pool before attempting to log."
        )

    if return_type != "stream":
        # One query supplies both the header (via cursor.description)
        # and the rows, so no separate header query is needed.
        csv_data = io.StringIO()
        writer = csv.writer(csv_data)
        async with self.conn.execute(
            "SELECT * FROM messages ORDER BY rowid"
        ) as cursor:
            writer.writerow(
                [description[0] for description in cursor.description]
            )
            writer.writerows(await cursor.fetchall())
        return csv_data.getvalue()

    if chunk_size < 1:
        # SQLite treats a negative LIMIT as "no limit" and a negative
        # OFFSET as 0, so the paging loop below would never terminate;
        # a LIMIT of 0 would silently export the header only.
        raise ValueError("chunk_size must be a positive integer.")

    async def generate_csv():
        buffer = io.StringIO()
        writer = csv.writer(buffer)

        def drain() -> str:
            """Return the buffered CSV text and reset the buffer."""
            text = buffer.getvalue()
            buffer.seek(0)
            buffer.truncate()
            return text

        # Write headers
        async with self.conn.execute(
            "SELECT * FROM messages LIMIT 1"
        ) as cursor:
            writer.writerow(
                [description[0] for description in cursor.description]
            )
        yield drain()

        # Stream rows in chunks. Without ORDER BY the row order of each
        # page is undefined, so pages could overlap or skip rows.
        # NOTE(review): assumes `messages` is an ordinary rowid table —
        # confirm against the schema.
        offset = 0
        while True:
            async with self.conn.execute(
                "SELECT * FROM messages ORDER BY rowid LIMIT ? OFFSET ?",
                (chunk_size, offset),
            ) as cursor:
                rows = await cursor.fetchall()
            if not rows:
                break

            writer.writerows(rows)
            yield drain()
            offset += chunk_size

    return StreamingResponse(
        generate_csv(),
        media_type="text/csv",
        headers={
            "Content-Disposition": "attachment; filename=messages.csv"
        },
    )

async def get_conversation(
self, conversation_id: str, branch_id: Optional[str] = None
) -> Tuple[str, list[Message]]:
Expand Down
19 changes: 19 additions & 0 deletions py/sdk/mixins/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,25 @@ async def update_message(
"PUT", f"update_message/{message_id}", data=message
)

async def update_message_metadata(
    self,
    message_id: str,
    metadata: dict[str, Any],
) -> dict:
    """
    Send new metadata for a message to the server.

    Issues a PATCH request to ``messages/{message_id}/metadata`` with the
    metadata as the request data.

    Args:
        message_id (str): The ID of the message to update.
        metadata (dict[str, Any]): The metadata to update.

    Returns:
        dict: The response from the server.
    """
    endpoint = f"messages/{message_id}/metadata"
    response = await self._make_request(  # type: ignore
        "PATCH", endpoint, data=metadata
    )
    return response

async def branches_overview(
self,
conversation_id: Union[str, UUID],
Expand Down

0 comments on commit 89dd3cd

Please sign in to comment.