diff --git a/impl/astra_vector.py b/impl/astra_vector.py index 82ed5a5..ec04d45 100644 --- a/impl/astra_vector.py +++ b/impl/astra_vector.py @@ -843,6 +843,7 @@ def get_assistant(self, id): logger.info(f"parsed assistant from row: {assistant}") return assistant + def delete_by_pk(self, key, value, table): query_string = f""" DELETE FROM {CASSANDRA_KEYSPACE}.{table} WHERE {key} = ?; @@ -854,6 +855,23 @@ def delete_by_pk(self, key, value, table): self.session.execute(bound) return True + + def delete_by_pks(self, keys, values, table): + query_string = f"DELETE FROM {CASSANDRA_KEYSPACE}.{table} WHERE " + i = 0 + for key in keys: + query_string += f"{key} = ?" + if i < len(keys) - 1: + query_string += " AND " + i += 1 + + statement = self.session.prepare(query_string) + statement.consistency_level = ConsistencyLevel.QUORUM + bound = statement.bind(values) + self.session.execute(bound) + return True + + def update_run_status(self, id, thread_id, status): query_string = f""" UPDATE {CASSANDRA_KEYSPACE}.runs SET status = ? WHERE id = ? and thread_id = ?; diff --git a/impl/model/modify_message_request.py b/impl/model/modify_message_request.py new file mode 100644 index 0000000..7a230a1 --- /dev/null +++ b/impl/model/modify_message_request.py @@ -0,0 +1,10 @@ +from typing import Optional, Annotated + +from pydantic import Field, StrictStr + +from openapi_server.models.create_message_request import CreateMessageRequest + + +class ModifyMessageRequest(CreateMessageRequest): + content: Optional[str] = Field(default=None, min_length=1, strict=True, max_length=32768, description="The content of the message.") + role: Optional[StrictStr] = Field(default=None, description="The role of the entity that is creating the message. Currently only `user` is supported.") diff --git a/impl/routes/threads.py b/impl/routes/threads.py index e85ede5..27fab63 100644 --- a/impl/routes/threads.py +++ b/impl/routes/threads.py @@ -35,6 +35,7 @@ from impl.model.list_messages_stream_response import ListMessagesStreamResponse from impl.model.message_object import MessageObject from impl.model.message_stream_response_object import MessageStreamResponseObject +from impl.model.modify_message_request import ModifyMessageRequest from impl.model.open_ai_file import OpenAIFile from impl.model.run_object import RunObject from impl.model.submit_tool_outputs_run_request import SubmitToolOutputsRunRequest @@ -43,6 +44,7 @@ from impl.services.inference_utils import get_chat_completion, get_async_chat_completion_response from openapi_server.models.create_message_request import CreateMessageRequest from openapi_server.models.create_thread_request import CreateThreadRequest +from openapi_server.models.delete_message_response import DeleteMessageResponse from openapi_server.models.delete_thread_response import DeleteThreadResponse from openapi_server.models.list_runs_response import ListRunsResponse from openapi_server.models.message_content_delta_object import MessageContentDeltaObject @@ -51,7 +53,6 @@ from openapi_server.models.message_content_text_object_text import ( MessageContentTextObjectText, ) -from openapi_server.models.modify_message_request import ModifyMessageRequest from openapi_server.models.modify_thread_request import ModifyThreadRequest from openapi_server.models.run_object_required_action import RunObjectRequiredAction from openapi_server.models.run_object_required_action_submit_tool_outputs import RunObjectRequiredActionSubmitToolOutputs @@ -184,14 +185,35 @@ async def modify_message( object="thread.message", created_at=None, thread_id=thread_id, - role=None, - content=None, + role=modify_message_request.role, + content=[modify_message_request.content], assistant_id=None, run_id=None, - file_ids=None, + file_ids=modify_message_request.file_ids, metadata=modify_message_request.metadata, ) +@router.delete( + "/threads/{thread_id}/messages/{message_id}", + responses={ + 200: {"model": DeleteMessageResponse, "description": "OK"}, + }, + tags=["Assistants"], + summary="Delete a message.", + response_model_by_alias=True, +) +async def delete_message( + thread_id: str = Path(..., description="The ID of the thread to delete."), + message_id: str = Path(..., description="The ID of the message to delete."), + astradb: CassandraClient = Depends(verify_db_client), +) -> DeleteMessageResponse: + astradb.delete_by_pks(table="messages", keys=["id", "thread_id"], values=[message_id, thread_id]) + return DeleteMessageResponse( + id=message_id, + object="thread.message.deleted", + deleted=True + ) + def extractFunctionArguments(content): pattern = r"\`\`\`.*({.*})\n\`\`\`" diff --git a/poetry.lock b/poetry.lock index 6f7a602..598ed04 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2746,20 +2746,20 @@ full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7 [[package]] name = "streaming-assistants" -version = "0.15.0rc3" +version = "0.15.1" description = "Streaming enabled Assistants API" optional = false -python-versions = ">=3.10,<4.0" +python-versions = "<4.0,>=3.10" files = [ - {file = "streaming_assistants-0.15.0rc3-py3-none-any.whl", hash = "sha256:0919ef6f973492d4c4e0c4beda39daa3502c2de45f33b9720584d6514814d040"}, - {file = "streaming_assistants-0.15.0rc3.tar.gz", hash = "sha256:5e4366e654b101804a08568fb9245322ecc152f032aa0f9d3ba6c2dc8e4ec99c"}, + {file = "streaming_assistants-0.15.1-py3-none-any.whl", hash = "sha256:bfe3f9e3fa915fa99d25ff1c2c5f53ff547b43e2b304e6426dca76e5f30d18b5"}, + {file = "streaming_assistants-0.15.1.tar.gz", hash = "sha256:16f9376902175888d2af0bb5e5b10edd2474735d6e34e0e899962f80a50aea08"}, ] [package.dependencies] boto3 = ">=1.34.31,<2.0.0" google-generativeai = ">=0.3.2,<0.4.0" httpx = ">=0.26.0,<0.27.0" -litellm = ">=1.20.6,<2.0.0" +litellm = ">=1.33.4,<2.0.0" openai = ">=1.14.0,<2.0.0" [[package]] @@ -3563,4 +3563,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10.12,<3.12" -content-hash = "a23017d16f3755037b1b962114b2143ce5c031a663dd696eea28f9199f80aed9" +content-hash = "68321ba48a46f9cb46120239b2f0c67d60a8aeb774787715dcc2eb518f2c6845" diff --git a/pyproject.toml b/pyproject.toml index 97b5048..b9fa0b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ boto3 = "^1.29.6" prometheus-fastapi-instrumentator = "^6.1.0" google-cloud-aiplatform = "^1.38.0" google-generativeai = "^0.3.1" -streaming-assistants = "^0.15.0rc3" +streaming-assistants = "^0.15.1" annotated-types = "^0.6.0" pydantic-core = "^2.16.3" pydantic = "^2.6.4" diff --git a/tests/http/test_assistants_api.py b/tests/http/test_assistants_api.py index 1e9e727..fdcb79d 100644 --- a/tests/http/test_assistants_api.py +++ b/tests/http/test_assistants_api.py @@ -7,6 +7,7 @@ from impl.model.create_assistant_request import CreateAssistantRequest from impl.model.create_run_request import CreateRunRequest from impl.model.message_object import MessageObject +from impl.model.modify_message_request import ModifyMessageRequest from openapi_server.models.assistant_file_object import AssistantFileObject # noqa: F401 from openapi_server.models.assistant_object import AssistantObject # noqa: F401 from openapi_server.models.create_assistant_file_request import CreateAssistantFileRequest # noqa: F401 @@ -22,7 +23,6 @@ from openapi_server.models.list_run_steps_response import ListRunStepsResponse # noqa: F401 from openapi_server.models.list_runs_response import ListRunsResponse # noqa: F401 from openapi_server.models.message_file_object import MessageFileObject # noqa: F401 -from openapi_server.models.modify_message_request import ModifyMessageRequest # noqa: F401 from openapi_server.models.modify_run_request import ModifyRunRequest # noqa: F401 from openapi_server.models.modify_thread_request import ModifyThreadRequest # noqa: F401 from openapi_server.models.run_object import RunObject # noqa: F401 @@ -469,6 +469,46 @@ def test_modify_message(client: TestClient): # uncomment below to assert the status code of the HTTP response assert response.status_code == 200 +def test_modify_message_content(client: TestClient): + """Test case for modify_message + + Modifies a message. + """ + + message = test_create_message(client) + modify_message_request = {"metadata":{}, "content": "puppies"} + + headers = get_headers(MODEL) + response = client.request( + "POST", + "/threads/{thread_id}/messages/{message_id}".format(thread_id=message.thread_id, message_id=message.id), + headers=headers, + json=modify_message_request, + ) + + logger.info(response) + # uncomment below to assert the status code of the HTTP response + assert response.status_code == 200 + +def test_delete_message(client: TestClient): + """Test case for delete_message + + Deletes a message. + """ + + message = test_create_message(client) + + headers = get_headers(MODEL) + response = client.request( + "DELETE", + "/threads/{thread_id}/messages/{message_id}".format(thread_id=message.thread_id, message_id=message.id), + headers=headers, + ) + + logger.info(response) + # uncomment below to assert the status code of the HTTP response + assert response.status_code == 200 + def test_modify_thread(client: TestClient): """Test case for modify_thread