Skip to content

Commit

Permalink
modify messages
Browse files Browse the repository at this point in the history
  • Loading branch information
phact committed Apr 1, 2024
1 parent 8923508 commit 2c461e6
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 12 deletions.
18 changes: 18 additions & 0 deletions impl/astra_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,7 @@ def get_assistant(self, id):
logger.info(f"parsed assistant from row: {assistant}")
return assistant


def delete_by_pk(self, key, value, table):
query_string = f"""
DELETE FROM {CASSANDRA_KEYSPACE}.{table} WHERE {key} = ?;
Expand All @@ -854,6 +855,23 @@ def delete_by_pk(self, key, value, table):
self.session.execute(bound)
return True


def delete_by_pks(self, keys, values, table):
query_string = f"DELETE FROM {CASSANDRA_KEYSPACE}.{table} WHERE "
i = 0
for key in keys:
query_string += f"{key} = ?"
if i < len(keys) - 1:
query_string += " AND "
i += 1

statement = self.session.prepare(query_string)
statement.consistency_level = ConsistencyLevel.QUORUM
bound = statement.bind(values)
self.session.execute(bound)
return True


def update_run_status(self, id, thread_id, status):
query_string = f"""
UPDATE {CASSANDRA_KEYSPACE}.runs SET status = ? WHERE id = ? and thread_id = ?;
Expand Down
10 changes: 10 additions & 0 deletions impl/model/modify_message_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Optional, Annotated

from pydantic import Field, StrictStr

from openapi_server.models.create_message_request import CreateMessageRequest


class ModifyMessageRequest(CreateMessageRequest):
content: Optional[str] = Field(default=None, min_length=1, strict=True, max_length=32768, description="The content of the message.")
role: Optional[StrictStr] = Field(default=None, description="The role of the entity that is creating the message. Currently only `user` is supported.")
30 changes: 26 additions & 4 deletions impl/routes/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from impl.model.list_messages_stream_response import ListMessagesStreamResponse
from impl.model.message_object import MessageObject
from impl.model.message_stream_response_object import MessageStreamResponseObject
from impl.model.modify_message_request import ModifyMessageRequest
from impl.model.open_ai_file import OpenAIFile
from impl.model.run_object import RunObject
from impl.model.submit_tool_outputs_run_request import SubmitToolOutputsRunRequest
Expand All @@ -43,6 +44,7 @@
from impl.services.inference_utils import get_chat_completion, get_async_chat_completion_response
from openapi_server.models.create_message_request import CreateMessageRequest
from openapi_server.models.create_thread_request import CreateThreadRequest
from openapi_server.models.delete_message_response import DeleteMessageResponse
from openapi_server.models.delete_thread_response import DeleteThreadResponse
from openapi_server.models.list_runs_response import ListRunsResponse
from openapi_server.models.message_content_delta_object import MessageContentDeltaObject
Expand All @@ -51,7 +53,6 @@
from openapi_server.models.message_content_text_object_text import (
MessageContentTextObjectText,
)
from openapi_server.models.modify_message_request import ModifyMessageRequest
from openapi_server.models.modify_thread_request import ModifyThreadRequest
from openapi_server.models.run_object_required_action import RunObjectRequiredAction
from openapi_server.models.run_object_required_action_submit_tool_outputs import RunObjectRequiredActionSubmitToolOutputs
Expand Down Expand Up @@ -184,14 +185,35 @@ async def modify_message(
object="thread.message",
created_at=None,
thread_id=thread_id,
role=None,
content=None,
role=modify_message_request.role,
content=[modify_message_request.content],
assistant_id=None,
run_id=None,
file_ids=None,
file_ids=modify_message_request.file_ids,
metadata=modify_message_request.metadata,
)

@router.delete(
"/threads/{thread_id}/messages/{message_id}",
responses={
200: {"model": DeleteMessageResponse, "description": "OK"},
},
tags=["Assistants"],
summary="Delete a message.",
response_model_by_alias=True,
)
async def delete_message(
thread_id: str = Path(..., description="The ID of the thread to delete."),
message_id: str = Path(..., description="The ID of the message to delete."),
astradb: CassandraClient = Depends(verify_db_client),
) -> DeleteMessageResponse:
astradb.delete_by_pks(table="messages", keys=["id", "thread_id"], values=[message_id, thread_id])
return DeleteMessageResponse(
id=message_id,
object="thread.message.deleted",
deleted=True
)


def extractFunctionArguments(content):
pattern = r"\`\`\`.*({.*})\n\`\`\`"
Expand Down
12 changes: 6 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ boto3 = "^1.29.6"
prometheus-fastapi-instrumentator = "^6.1.0"
google-cloud-aiplatform = "^1.38.0"
google-generativeai = "^0.3.1"
streaming-assistants = "^0.15.0rc3"
streaming-assistants = "^0.15.1"
annotated-types = "^0.6.0"
pydantic-core = "^2.16.3"
pydantic = "^2.6.4"
Expand Down
42 changes: 41 additions & 1 deletion tests/http/test_assistants_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from impl.model.create_assistant_request import CreateAssistantRequest
from impl.model.create_run_request import CreateRunRequest
from impl.model.message_object import MessageObject
from impl.model.modify_message_request import ModifyMessageRequest
from openapi_server.models.assistant_file_object import AssistantFileObject # noqa: F401
from openapi_server.models.assistant_object import AssistantObject # noqa: F401
from openapi_server.models.create_assistant_file_request import CreateAssistantFileRequest # noqa: F401
Expand All @@ -22,7 +23,6 @@
from openapi_server.models.list_run_steps_response import ListRunStepsResponse # noqa: F401
from openapi_server.models.list_runs_response import ListRunsResponse # noqa: F401
from openapi_server.models.message_file_object import MessageFileObject # noqa: F401
from openapi_server.models.modify_message_request import ModifyMessageRequest # noqa: F401
from openapi_server.models.modify_run_request import ModifyRunRequest # noqa: F401
from openapi_server.models.modify_thread_request import ModifyThreadRequest # noqa: F401
from openapi_server.models.run_object import RunObject # noqa: F401
Expand Down Expand Up @@ -469,6 +469,46 @@ def test_modify_message(client: TestClient):
# uncomment below to assert the status code of the HTTP response
assert response.status_code == 200

def test_modify_message_content(client: TestClient):
"""Test case for modify_message
Modifies a message.
"""

message = test_create_message(client)
modify_message_request = {"metadata":{}, "content": "puppies"}

headers = get_headers(MODEL)
response = client.request(
"POST",
"/threads/{thread_id}/messages/{message_id}".format(thread_id=message.thread_id, message_id=message.id),
headers=headers,
json=modify_message_request,
)

logger.info(response)
# uncomment below to assert the status code of the HTTP response
assert response.status_code == 200

def test_delete_message(client: TestClient):
"""Test case for delete_message
Deletes a message.
"""

message = test_create_message(client)

headers = get_headers(MODEL)
response = client.request(
"DELETE",
"/threads/{thread_id}/messages/{message_id}".format(thread_id=message.thread_id, message_id=message.id),
headers=headers,
)

logger.info(response)
# uncomment below to assert the status code of the HTTP response
assert response.status_code == 200


def test_modify_thread(client: TestClient):
"""Test case for modify_thread
Expand Down

0 comments on commit 2c461e6

Please sign in to comment.