From 5a14a48cc09a381fcb59b0a1f4cad4a0f1bed627 Mon Sep 17 00:00:00 2001 From: "Hung-Han (Henry) Chen" Date: Fri, 3 Nov 2023 15:02:29 +0200 Subject: [PATCH] Support mutiple messages Signed-off-by: Hung-Han (Henry) Chen --- main.py | 81 ++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/main.py b/main.py index 506637e..00d57dc 100644 --- a/main.py +++ b/main.py @@ -31,9 +31,7 @@ DEFAULT_MODEL_HG_REPO_ID = get_env( "DEFAULT_MODEL_HG_REPO_ID", "TheBloke/Llama-2-7B-Chat-GGML" ) -DEFAULT_MODEL_HG_REPO_REVISION = get_env( - "DEFAULT_MODEL_HG_REPO_REVISION", "main" -) +DEFAULT_MODEL_HG_REPO_REVISION = get_env("DEFAULT_MODEL_HG_REPO_REVISION", "main") DEFAULT_MODEL_FILE = get_env("DEFAULT_MODEL_FILE", "llama-2-7b-chat.ggmlv3.q4_0.bin") log.info("DEFAULT_MODEL_HG_REPO_ID: %s", DEFAULT_MODEL_HG_REPO_ID) @@ -70,13 +68,17 @@ def set_loading_model(boolean: bool): app = FastAPI() + # https://github.com/tiangolo/fastapi/issues/3361 @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): exc_str = f"{exc}".replace("\n", " ").replace(" ", " ") log.error("%s: %s", request, exc_str) content = {"status_code": 10422, "message": exc_str, "data": None} - return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) + return JSONResponse( + content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY + ) + @app.on_event("startup") async def startup_event(): @@ -380,35 +382,48 @@ async def chat_completions( user_start = "GPT4 User: " user_end = "<|end_of_turn|>" - user_message = next( - (message for message in reversed(body.messages) if message.role == "user"), None - ) - user_message_content = user_message.content if user_message else "" - assistant_message = next( - (message for message in reversed(body.messages) if message.role == "assistant"), None - ) - assistant_message_content = ( - f"{assistant_start}{assistant_message.content}{assistant_end}" - if assistant_message - else "" - ) - system_message = next( - (message for message in body.messages if message.role == "system"), None - ) - system_message_content = system_message.content if system_message else system - # avoid duplicate user start token in prompt if user message already includes it - if len(user_start) > 0 and user_start in user_message_content: - user_start = "" - # avoid duplicate user end token in prompt if user message already includes it - if len(user_end) > 0 and user_end in user_message_content: - user_end = "" - # avoid duplicate assistant start token in prompt if user message already includes it - if len(assistant_start) > 0 and assistant_start in user_message_content: - assistant_start = "" - # avoid duplicate system_start token in prompt if system_message_content already includes it - if len(system_start) > 0 and system_start in system_message_content: - system_start = "" - prompt = f"{system_start}{system_message_content}{system_end}{assistant_message_content}{user_start}{user_message_content}{user_end}{assistant_start}" + prompt = "" + for message in body.messages: + # Check for system message + if message.role == "system": + system_message_content = message.content if message else "" + + # avoid duplicate system_start token in prompt if system_message_content already includes it + if len(system_start) > 0 and system_start in system_message_content: + system_start = "" + # avoid duplicate system_end token in prompt if system_message_content already includes it + if len(system_end) > 0 and system_end in system_message_content: + system_end = "" + prompt = f"{system_start}{system_message_content}{system_end}" + elif message.role == "user": + user_message_content = message.content if message else "" + + # avoid duplicate user start token in prompt if user_message_content already includes it + if len(user_start) > 0 and user_start in user_message_content: + user_start = "" + # avoid duplicate user end token in prompt if user_message_content already includes it + if len(user_end) > 0 and user_end in user_message_content: + user_end = "" + + prompt = f"{prompt}{user_start}{user_message_content}{user_end}" + elif message.role == "assistant": + assistant_message_content = message.content if message else "" + + # avoid duplicate assistant start token in prompt if user message already includes it + if ( + len(assistant_start) > 0 + and assistant_start in assistant_message_content + ): + assistant_start = "" + # avoid duplicate assistant start token in prompt if user message already includes it + if len(assistant_end) > 0 and assistant_end in assistant_message_content: + assistant_end = "" + + prompt = ( + f"{prompt}{assistant_start}{assistant_message_content}{assistant_end}" + ) + + prompt = f"{prompt}{assistant_start}" model_name = body.model llm = request.app.state.llm if body.stream is True: