Skip to content

Commit

Permalink
Add update to conversation, clean up type errors around conversations
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Dec 21, 2024
1 parent 58ee771 commit 0127711
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 26 deletions.
9 changes: 9 additions & 0 deletions js/sdk/__tests__/ConversationsIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ describe("r2rClient V3 Collections Integration Tests", () => {
expect(response.results.name).toBe("Test Conversation");
});

test("Update a conversation name", async () => {
const response = await client.conversations.update({
id: conversationId,
name: "Updated Name",
});
expect(response.results).toBeDefined();
expect(response.results.name).toBe("Updated Name");
});

test("Delete a conversation", async () => {
const response = await client.conversations.delete({ id: conversationId });
expect(response.results).toBeDefined();
Expand Down
20 changes: 20 additions & 0 deletions js/sdk/src/v3/clients/conversations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,26 @@ export class ConversationsClient {
return this.client.makeRequest("GET", `conversations/${options.id}`);
}

/**
* Update an existing conversation.
* @param id The ID of the conversation to update
* @param name The new name of the conversation
* @returns The updated conversation
*/
@feature("conversations.update")
async update(options: {
id: string;
name: string;
}): Promise<WrappedConversationResponse> {
const data: Record<string, any> = {
name: options.name,
};

return this.client.makeRequest("POST", `conversations/${options.id}`, {
data,
});
}

/**
* Delete a conversation.
* @param id The ID of the conversation to delete
Expand Down
51 changes: 42 additions & 9 deletions py/core/database/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ async def create_conversation(
)

return ConversationResponse(
id=str(result["id"]),
id=result["id"],
created_at=result["created_at_epoch"],
user_id=str(user_id) if user_id else None,
user_id=user_id or None,
name=name or None,
)
except Exception as e:
Expand Down Expand Up @@ -92,7 +92,7 @@ async def get_conversations_overview(
) -> dict[str, Any]:
# Construct conditions
conditions = []
params = []
params: list = []
param_index = 1

if user_ids is not None:
Expand Down Expand Up @@ -226,13 +226,13 @@ async def add_message(
status_code=500, message="Failed to insert message."
)

return MessageResponse(id=str(message_id), message=content)
return MessageResponse(id=message_id, message=content)

async def edit_message(
self,
message_id: UUID,
new_content: str | None = None,
additional_metadata: dict = None,
additional_metadata: dict | None = None,
) -> dict[str, Any]:
# Get the original message
query = f"""
Expand Down Expand Up @@ -348,8 +348,6 @@ async def get_conversation(
)

# Retrieve messages in chronological order
# We'll recursively gather messages based on parent_id = NULL as root.
# Since no branching, we simply order by created_at.
msg_query = f"""
SELECT id, content, metadata
FROM {self._get_table_name("messages")}
Expand All @@ -362,14 +360,49 @@ async def get_conversation(

return [
MessageResponse(
id=str(row["id"]),
id=row["id"],
message=Message(**json.loads(row["content"])),
metadata=json.loads(row["metadata"]),
)
for row in results
]

async def delete_conversation(self, conversation_id: UUID):
async def update_conversation(
self, conversation_id: UUID, name: str
) -> ConversationResponse:
try:
# Check if conversation exists
conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
conv_row = await self.connection_manager.fetchrow_query(
conv_query, [conversation_id]
)
if not conv_row:
raise R2RException(
status_code=404,
message=f"Conversation {conversation_id} not found.",
)

update_query = f"""
UPDATE {self._get_table_name('conversations')}
SET name = $1 WHERE id = $2
RETURNING user_id, extract(epoch from created_at) as created_at_epoch
"""
updated_row = await self.connection_manager.fetchrow_query(
update_query, [name, conversation_id]
)
return ConversationResponse(
id=conversation_id,
created_at=updated_row["created_at_epoch"],
user_id=updated_row["user_id"] or None,
name=name,
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to update conversation: {str(e)}",
) from e

async def delete_conversation(self, conversation_id: UUID) -> None:
# Check if conversation exists
conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
conv_row = await self.connection_manager.fetchrow_query(
Expand Down
97 changes: 90 additions & 7 deletions py/core/main/api/v3/conversations_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,91 @@ async def get_conversation(
This endpoint retrieves detailed information about a single conversation identified by its UUID.
"""
conversation = await self.services.management.get_conversation(
str(id)
conversation_id=id
)
return conversation

@self.router.post(
"/conversations/{id}",
summary="Delete conversation",
dependencies=[Depends(self.rate_limit_dependency)],
openapi_extra={
"x-codeSamples": [
{
"lang": "Python",
"source": textwrap.dedent(
"""
from r2r import R2RClient
client = R2RClient("http://localhost:7272")
# when using auth, do client.login(...)
result = client.conversations.update("123e4567-e89b-12d3-a456-426614174000", "new_name")
"""
),
},
{
"lang": "JavaScript",
"source": textwrap.dedent(
"""
const { r2rClient } = require("r2r-js");
const client = new r2rClient("http://localhost:7272");
function main() {
const response = await client.conversations.update({
id: "123e4567-e89b-12d3-a456-426614174000",
name: "new_name",
});
}
main();
"""
),
},
{
"lang": "CLI",
"source": textwrap.dedent(
"""
r2r conversations delete 123e4567-e89b-12d3-a456-426614174000
"""
),
},
{
"lang": "cURL",
"source": textwrap.dedent(
"""
curl -X PUT "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \\
-H "Authorization: B
"""
),
},
]
},
)
@self.base_endpoint
async def update_conversation(
id: UUID = Path(
...,
description="The unique identifier of the conversation to delete",
),
name: str = Body(
...,
description="The updated name for the conversation",
embed=True,
),
auth_user=Depends(self.providers.auth.auth_wrapper),
) -> WrappedConversationResponse:
"""
Update an existing conversation.
This endpoint updates the name of an existing conversation identified by its UUID.
"""
return await self.services.management.update_conversation(
conversation_id=id,
name=name,
)

@self.router.delete(
"/conversations/{id}",
summary="Delete conversation",
Expand Down Expand Up @@ -347,7 +428,9 @@ async def delete_conversation(
This endpoint deletes a conversation identified by its UUID.
"""
await self.services.management.delete_conversation(str(id))
await self.services.management.delete_conversation(
conversation_id=id
)
return GenericBooleanResponse(success=True) # type: ignore

@self.router.post(
Expand Down Expand Up @@ -421,7 +504,7 @@ async def add_message(
role: str = Body(
..., description="The role of the message to add"
),
parent_id: Optional[str] = Body(
parent_id: Optional[UUID] = Body(
None, description="The ID of the parent message, if any"
),
metadata: Optional[dict[str, str]] = Body(
Expand All @@ -440,10 +523,10 @@ async def add_message(
raise R2RException("Invalid role", status_code=400)
message = Message(role=role, content=content)
return await self.services.management.add_message(
str(id),
message,
parent_id,
metadata,
conversation_id=id,
content=message,
parent_id=parent_id,
metadata=metadata,
)

@self.router.post(
Expand Down
33 changes: 23 additions & 10 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
R2RException,
RunManager,
User,
ConversationResponse,
)
from core.telemetry.telemetry_decorator import telemetry_event

Expand Down Expand Up @@ -804,26 +805,27 @@ async def delete_prompt(self, name: str) -> dict:
@telemetry_event("GetConversation")
async def get_conversation(
self,
conversation_id: str,
conversation_id: UUID,
auth_user=None,
) -> Tuple[str, list[Message], list[dict]]:
return await self.providers.database.conversations_handler.get_conversation( # type: ignore
conversation_id
conversation_id=conversation_id
)

async def verify_conversation_access(
self, conversation_id: str, user_id: UUID
self, conversation_id: UUID, user_id: UUID
) -> bool:
return await self.providers.database.conversations_handler.verify_conversation_access(
conversation_id, user_id
conversation_id=conversation_id,
user_id=user_id,
)

@telemetry_event("CreateConversation")
async def create_conversation(
self,
user_id: Optional[UUID] = None,
name: Optional[str] = None,
) -> dict:
) -> ConversationResponse:
return await self.providers.database.conversations_handler.create_conversation(
user_id=user_id,
name=name,
Expand All @@ -848,14 +850,17 @@ async def conversations_overview(
@telemetry_event("AddMessage")
async def add_message(
self,
conversation_id: str,
conversation_id: UUID,
content: Message,
parent_id: Optional[str] = None,
parent_id: Optional[UUID] = None,
metadata: Optional[dict] = None,
auth_user=None,
) -> str:
return await self.providers.database.conversations_handler.add_message(
conversation_id, content, parent_id, metadata
conversation_id=conversation_id,
content=content,
parent_id=parent_id,
metadata=metadata,
)

@telemetry_event("EditMessage")
Expand All @@ -874,10 +879,18 @@ async def edit_message(
)
)

@telemetry_event("UpdateConversation")
async def update_conversation(
self, conversation_id: UUID, name: str
) -> ConversationResponse:
return await self.providers.database.conversations_handler.update_conversation(
conversation_id=conversation_id, name=name
)

@telemetry_event("DeleteConversation")
async def delete_conversation(self, conversation_id: str, auth_user=None):
async def delete_conversation(self, conversation_id: UUID) -> None:
await self.providers.database.conversations_handler.delete_conversation(
conversation_id
conversation_id=conversation_id
)

async def get_user_max_documents(self, user_id: UUID) -> int:
Expand Down
26 changes: 26 additions & 0 deletions py/sdk/v3/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,32 @@ async def retrieve(
version="v3",
)

async def update(
self,
id: str | UUID,
name: str,
) -> WrappedConversationResponse:
"""
Update an existing conversation.
Args:
id (Union[str, UUID]): The ID of the conversation to update
name (str): The new name of the conversation
Returns:
dict: The updated conversation
"""
data: dict[str, Any] = {
"name": name,
}

return await self.client._make_request(
"POST",
f"conversations/{str(id)}",
json=data,
version="v3",
)

async def delete(
self,
id: str | UUID,
Expand Down

0 comments on commit 0127711

Please sign in to comment.