
Commit

CONTEXT_LENGTH default to 4096, and warning for context, add 422 logger, refactor prompt template (#70)
chenhunghan authored Sep 23, 2023
1 parent 3a6fcaf commit 292250a
Showing 7 changed files with 150 additions and 50 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/smoke_test.yaml
@@ -137,7 +137,7 @@ jobs:
echo "REPLY=$REPLY" >> $GITHUB_ENV
- if: always()
run: |
kubectl logs --tail=20 --selector app.kubernetes.io/name=$LLAMA_HELM_RELEASE_NAME -n $HELM_NAMESPACE
kubectl logs --tail=200 --selector app.kubernetes.io/name=$LLAMA_HELM_RELEASE_NAME -n $HELM_NAMESPACE
gpt-neox-smoke-test:
runs-on: ubuntu-latest
needs: build-image
@@ -205,7 +205,7 @@ jobs:
openai -k "sk-fake" -b http://localhost:$GPT_NEOX_SVC_PORT/v1 -vvvvv api completions.create -m $GPT_NEOX_MODEL_FILE -p "A function adding 1 to 1 in Python."
- if: always()
run: |
kubectl logs --tail=20 --selector app.kubernetes.io/name=$GPT_NEOX_HELM_RELEASE_NAME -n $HELM_NAMESPACE
kubectl logs --tail=200 --selector app.kubernetes.io/name=$GPT_NEOX_HELM_RELEASE_NAME -n $HELM_NAMESPACE
starcoder-smoke-test:
runs-on: ubuntu-latest
needs: build-image
@@ -273,4 +273,4 @@ jobs:
openai -k "sk-fake" -b http://localhost:$STARCODER_SVC_PORT/v1 -vvvvv api completions.create -m $STARCODER_MODEL_FILE -p "def fibonnaci"
- if: always()
run: |
kubectl logs --tail=20 --selector app.kubernetes.io/name=$STARCODER_HELM_RELEASE_NAME -n $HELM_NAMESPACE
kubectl logs --tail=200 --selector app.kubernetes.io/name=$STARCODER_HELM_RELEASE_NAME -n $HELM_NAMESPACE
4 changes: 2 additions & 2 deletions charts/ialacol/Chart.yaml
@@ -1,6 +1,6 @@
apiVersion: v2
appVersion: 0.11.3
appVersion: 0.11.4
description: A Helm chart for ialacol
name: ialacol
type: application
version: 0.11.3
version: 0.11.4
3 changes: 2 additions & 1 deletion const.py
@@ -1,2 +1,3 @@
DEFAULT_MAX_TOKENS = "512"
DEFAULT_CONTEXT_LENGTH = "512"
DEFAULT_CONTEXT_LENGTH = "4096"
DEFAULT_LOG_LEVEL = "INFO"
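A minimal sketch of how these string defaults are consumed downstream (get_env is the repository's helper; its body here is an assumption, a plain environment lookup with a fallback):

    import os

    def get_env(key: str, default: str) -> str:
        # assumed behaviour of the repo's get_env helper: env var wins, else the default
        return os.environ.get(key, default)

    DEFAULT_CONTEXT_LENGTH = "4096"
    # defaults stay strings because environment variables are strings;
    # callers convert where a number is needed
    CONTEXT_LENGTH = int(get_env("CONTEXT_LENGTH", DEFAULT_CONTEXT_LENGTH))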
6 changes: 3 additions & 3 deletions log.py
@@ -1,12 +1,12 @@
import logging

from get_env import get_env
from const import DEFAULT_LOG_LEVEL


LOGGING_LEVEL = get_env("LOGGING_LEVEL", "INFO")
LOGGING_LEVEL = get_env("LOGGING_LEVEL", DEFAULT_LOG_LEVEL)

log = logging.getLogger("uvicorn")
try:
log.setLevel(LOGGING_LEVEL)
except ValueError:
log.setLevel("INFO")
log.setLevel(DEFAULT_LOG_LEVEL)
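For context, the try/except is there because logging.Logger.setLevel raises ValueError for an unknown level name; a standalone illustration (logger name arbitrary):

    import logging

    log = logging.getLogger("example")
    try:
        log.setLevel("VERBOSE")  # not a valid logging level name -> ValueError
    except ValueError:
        log.setLevel("INFO")     # same fallback as DEFAULT_LOG_LEVEL above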
125 changes: 86 additions & 39 deletions main.py
@@ -10,7 +10,9 @@
Union,
Annotated,
)
from fastapi import FastAPI, Depends, HTTPException, Body, Request
from fastapi import FastAPI, Depends, HTTPException, Body, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from ctransformers import LLM, AutoModelForCausalLM, Config
from huggingface_hub import hf_hub_download, snapshot_download
@@ -64,6 +66,13 @@ def set_loading_model(boolean: bool):

app = FastAPI()

# https://github.com/tiangolo/fastapi/issues/3361
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
exc_str = f"{exc}".replace("\n", " ").replace("   ", " ")  # flatten newlines and collapse the indentation in pydantic's error text
log.error("%s: %s", request, exc_str)
content = {"status_code": 10422, "message": exc_str, "data": None}
return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
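For illustration only — the endpoint path, port and missing field below are assumptions, not part of this diff — a request that fails Pydantic validation is now logged server-side and answered with the custom payload:

    import requests  # assumes the server is reachable on localhost:8000

    # "model" is a required field, so omitting it raises RequestValidationError,
    # which the validation_exception_handler above turns into this response.
    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "Hello"}]},
    )
    print(resp.status_code)  # 422
    print(resp.json())       # {"status_code": 10422, "message": "...", "data": None}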

@app.on_event("startup")
async def startup_event():
@@ -271,48 +280,80 @@ async def chat_completions(
log.warning(
"n, logit_bias, user, presence_penalty and frequency_penalty are not supported."
)
default_assistant_start = "### Assistant: "
default_assistant_end = ""
default_user_start = "### Human: "
default_user_end = ""
default_system = ""

if "llama" in body.model:
default_assistant_start = "ASSISTANT: \n"
default_user_start = "USER: "
default_user_end = "\n"
default_system = "SYSTEM: You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
system_start = ""
system = "You are a helpful assistant."
system_end = ""
user_start = ""
user_end = ""
assistant_start = ""
assistant_end = ""

# https://huggingface.co/blog/llama2#how-to-prompt-llama-2
# https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/discussions/3
if "llama-2" in body.model.lower() and "chat" in body.model.lower():
system_start = "<s>[INST] <<SYS>>\n"
system = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
system_end = "<</SYS>>\n\n"
assistant_start = " "
assistant_end = " </s><s>[INST] "
user_start = ""
user_end = " [/INST]"
# For most instruct fine-tuned models, use the Alpaca prompt template.
# Although instruct fine-tuned models are not tuned for chat, they can still generate
# responses as if chatting, and the Alpaca prompt template likely gives better results
# than the default prompt template.
# See https://github.com/tatsu-lab/stanford_alpaca#data-release
if "instruct" in body.model:
default_assistant_start = "### Response:"
default_user_start = "### Instruction: "
default_user_end = "\n\n"
default_system = "Below is an instruction that describes a task. Write a response that appropriately completes the request\n\n"
if "starchat" in body.model:
if "instruct" in body.model.lower():
system_start = ""
system = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_end = ""
assistant_start = "### Response:"
assistant_end = ""
user_start = "### Instruction:\n"
user_end = "\n\n"
if "starchat" in body.model.lower():
# See https://huggingface.co/blog/starchat-alpha and https://huggingface.co/TheBloke/starchat-beta-GGML#prompt-template
default_assistant_start = "<|assistant|>\n"
default_assistant_end = " <|end|>\n"
default_user_start = "<|user|>\n"
default_user_end = " <|end|>\n"
default_system = "<|system|>\nBelow is a dialogue between a human and an AI assistant called StarChat.<|end|>\n"
if "airoboros" in body.model:
system_start = "<|system|>"
system = (
"Below is a dialogue between a human and an AI assistant called StarChat."
)
system_end = " <|end|>\n"
user_start = "<|user|>"
user_end = " <|end|>\n"
assistant_start = "<|assistant|>\n"
assistant_end = " <|end|>\n"
if "airoboros" in body.model.lower():
# e.g. A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. USER: [prompt] ASSISTANT:
# see https://huggingface.co/jondurbin/airoboros-mpt-30b-gpt4-1p4-five-epochs
default_assistant_start = "ASSISTANT: "
default_user_start = "USER: "
default_system = "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user's input."
system_start = ""
system = "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user's input."
system_end = ""
user_start = "USER: "
user_end = ""
assistant_start = "ASSISTANT: "
assistant_end = ""
# If it's an mpt-chat model, we need to add the default prompt
# from https://huggingface.co/TheBloke/mpt-30B-chat-GGML#prompt-template
# and https://huggingface.co/spaces/mosaicml/mpt-30b-chat/blob/main/app.py#L17
if "mpt" in body.model and "chat" in body.model:
default_assistant_start = "<|im_start|>assistant\n"
default_assistant_end = "<|im_end|>\n"
default_user_start = "<|im_start|>user\n"
default_user_end = "<|im_end|>\n"
default_system = "<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>\n"
if "mpt" in body.model.lower() and "chat" in body.model.lower():
system_start = "<|im_start|>system\n"
system = "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
system_end = "<|im_end|>\n"
assistant_start = "<|im_start|>assistant\n"
assistant_end = "<|im_end|>\n"
user_start = "<|im_start|>user\n"
user_end = "<|im_end|>\n"
# orca mini https://huggingface.co/pankajmathur/orca_mini_3b
if "orca" in body.model.lower() and "mini" in body.model.lower():
system_start = "### System:\n"
system = "You are an AI assistant that follows instruction extremely well. Help as much as you can."
system_end = "\n\n"
assistant_start = "### Response:\n"
assistant_end = ""
# v3 e.g. https://huggingface.co/pankajmathur/orca_mini_v3_13b
if "v3" in body.model.lower():
assistant_start = "### Assistant:\n"
user_start = "### User:\n"
user_end = "\n\n"

user_message = next(
(message for message in body.messages if message.role == "user"), None
@@ -322,18 +363,24 @@ async def chat_completions(
(message for message in body.messages if message.role == "assistant"), None
)
assistant_message_content = (
f"{default_assistant_start}{assistant_message.content}{default_assistant_end}"
f"{assistant_start}{assistant_message.content}{assistant_end}"
if assistant_message
else ""
)
system_message = next(
(message for message in body.messages if message.role == "system"), None
)
system_message_content = (
system_message.content if system_message else default_system
)

prompt = f"{system_message_content}{assistant_message_content} {default_user_start}{user_message_content}{default_user_end} {default_assistant_start}"
system_message_content = system_message.content if system_message else system
# avoid duplicate user start token in prompt if user message already includes it
if len(user_start) > 0 and user_start in user_message_content:
user_start = ""
# avoid duplicate user end token in prompt if user message already includes it
if len(user_end) > 0 and user_end in user_message_content:
user_end = ""
# avoid duplicate assistant start token in prompt if user message already includes it
if len(assistant_start) > 0 and assistant_start in user_message_content:
assistant_start = ""
prompt = f"{system_start}{system_message_content}{system_end}{assistant_message_content}{user_start}{user_message_content}{user_end}{assistant_start}"
model_name = body.model
llm = request.app.state.llm
if body.stream is True:
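As a rough walk-through of the refactored template (string values copied from the llama-2 chat branch above; the user content is invented and no system or prior assistant message is supplied), the prompt assembled by the f-string would look approximately like this:

    # values from the llama-2 chat branch above; system text trimmed for brevity
    system_start = "<s>[INST] <<SYS>>\n"
    system = "You are a helpful, respectful and honest assistant. ...\n"
    system_end = "<</SYS>>\n\n"
    user_start, user_end = "", " [/INST]"
    assistant_start = " "
    assistant_message_content = ""  # no prior assistant turn
    user_message_content = "What is the capital of France?"

    prompt = (
        f"{system_start}{system}{system_end}"
        f"{assistant_message_content}"
        f"{user_start}{user_message_content}{user_end}{assistant_start}"
    )
    print(prompt)
    # <s>[INST] <<SYS>>
    # You are a helpful, respectful and honest assistant. ...
    # <</SYS>>
    #
    # What is the capital of France? [/INST]

The duplicate-token checks added right before the f-string keep the server from emitting, say, two user-start prefixes when a client already formats its own messages.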
4 changes: 2 additions & 2 deletions request_body.py
@@ -19,7 +19,7 @@ class CompletionRequestBody(BaseModel):
temperature: Optional[float]
top_p: Optional[float]
stop: Optional[List[str] | str]
stream: bool = Field()
stream: Optional[bool] = Field()
model: str = Field()
# llama.cpp specific parameters
top_k: Optional[int]
@@ -68,7 +68,7 @@ class ChatCompletionRequestBody(BaseModel):
temperature: Optional[float]
top_p: Optional[float]
stop: Optional[List[str] | str]
stream: bool = Field()
stream: Optional[bool] = Field()
model: str = Field()
# llama.cpp specific parameters
top_k: Optional[int]
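The point of making stream Optional is that OpenAI-style clients which never send the field no longer fail validation; a hedged client-side sketch (port, path and model file name are illustrative, not taken from this diff):

    import requests

    # no "stream" key at all: with `stream: bool` this was a 422,
    # with `stream: Optional[bool]` it should now parse with stream left unset
    requests.post(
        "http://localhost:8000/v1/completions",
        json={"model": "llama-2-7b-chat.ggmlv3.q4_0.bin", "prompt": "def fibonacci"},
    )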
52 changes: 52 additions & 0 deletions streamers.py
@@ -3,6 +3,8 @@
from ctransformers import LLM, Config

from log import log
from get_env import get_env
from const import DEFAULT_CONTEXT_LENGTH, DEFAULT_LOG_LEVEL


def completions_streamer(
@@ -37,8 +39,11 @@ def completions_streamer(
stop = config.stop
log.debug("stop: %s", stop)
log.debug("prompt: %s", prompt)
CONTEXT_LENGTH = int(get_env("CONTEXT_LENGTH", DEFAULT_CONTEXT_LENGTH))
LOGGING_LEVEL = get_env("LOGGING_LEVEL", DEFAULT_LOG_LEVEL)

log.debug("Streaming from ctransformer instance!")
total_tokens = 0
for token in llm(
prompt,
stream=True,
@@ -54,6 +59,28 @@ def completions_streamer(
max_new_tokens=max_new_tokens,
stop=stop,
):
if LOGGING_LEVEL == "DEBUG":
# Only track token length if we're in debug mode to avoid overhead
total_tokens = total_tokens + len(token)
# tokens are not necessarily characters, but this is a good enough approximation
if total_tokens > CONTEXT_LENGTH:
log.debug(
"Total token length %s exceeded context length %s",
total_tokens,
CONTEXT_LENGTH,
)
log.debug(
"Consider increasing CONTEXT_LENGTH (currently %s) to match your model's context length",
CONTEXT_LENGTH,
)
log.debug(
"Alternatively, increase REPETITION_PENALTY %s and LAST_N_TOKENS %s and/or adjust temperature %s, top_k %s, top_p %s",
repetition_penalty,
last_n_tokens,
temperature,
top_k,
top_p,
)
log.debug("Streaming token %s", token)
data = json.dumps(
{
@@ -123,8 +150,11 @@ def chat_completions_streamer(
stop = config.stop
log.debug("stop: %s", stop)
log.debug("prompt: %s", prompt)
CONTEXT_LENGTH = int(get_env("CONTEXT_LENGTH", DEFAULT_CONTEXT_LENGTH))
LOGGING_LEVEL = get_env("LOGGING_LEVEL", DEFAULT_LOG_LEVEL)

log.debug("Streaming from ctransformer instance")
total_tokens = 0
for token in llm(
prompt,
stream=True,
@@ -140,6 +170,28 @@ def chat_completions_streamer(
max_new_tokens=max_new_tokens,
stop=stop,
):
if LOGGING_LEVEL == "DEBUG":
# Only track token length if we're in debug mode to avoid overhead
total_tokens = total_tokens + len(token)
# tokens are not necessarily characters, but this is a good enough approximation
if total_tokens > CONTEXT_LENGTH:
log.debug(
"Total token length %s exceeded context length %s",
total_tokens,
CONTEXT_LENGTH,
)
log.debug(
"Consider increasing CONTEXT_LENGTH (currently %s) to match your model's context length",
CONTEXT_LENGTH,
)
log.debug(
"Alternatively, increase REPETITION_PENALTY %s and LAST_N_TOKENS %s and/or adjust temperature %s, top_k %s, top_p %s",
repetition_penalty,
last_n_tokens,
temperature,
top_k,
top_p,
)
log.debug("Streaming token %s", token)
data = json.dumps(
{
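The new length check only runs when LOGGING_LEVEL resolves to DEBUG, so there is no per-token cost in normal operation; a standalone sketch of the approximation it relies on (fake token stream, tiny context length so the branch triggers):

    CONTEXT_LENGTH = 16
    total_tokens = 0
    for token in ["Hello", ",", " I", " am", " a", " language", " model"]:
        total_tokens += len(token)  # character count as a rough proxy for token count
        if total_tokens > CONTEXT_LENGTH:
            print(f"approximate length {total_tokens} exceeded CONTEXT_LENGTH {CONTEXT_LENGTH}")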
