From 774c24d54dfdc98d65e3432efd529b81c1e89b0d Mon Sep 17 00:00:00 2001
From: "Hung-Han (Henry) Chen"
Date: Fri, 3 Nov 2023 14:26:08 +0200
Subject: [PATCH] Fallback repetition_penalty to body.frequency_penalty if
 present

Signed-off-by: Hung-Han (Henry) Chen
---
 get_config.py   | 13 ++++++++++---
 request_body.py |  4 ++--
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/get_config.py b/get_config.py
index 9c08c48..02c1cc3 100644
--- a/get_config.py
+++ b/get_config.py
@@ -8,6 +8,7 @@
 
 THREADS = int(get_env("THREADS", str(get_default_thread())))
 
+
 def get_config(
     body: CompletionRequestBody | ChatCompletionRequestBody,
 ) -> Config:
@@ -28,8 +29,10 @@ def get_config(
     # OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-max_tokens
     MAX_TOKENS = int(get_env("MAX_TOKENS", DEFAULT_MAX_TOKENS))
     CONTEXT_LENGTH = int(get_env("CONTEXT_LENGTH", DEFAULT_CONTEXT_LENGTH))
-    if (MAX_TOKENS > CONTEXT_LENGTH):
-        log.warning("MAX_TOKENS is greater than CONTEXT_LENGTH, setting MAX_TOKENS < CONTEXT_LENGTH")
+    if MAX_TOKENS > CONTEXT_LENGTH:
+        log.warning(
+            "MAX_TOKENS is greater than CONTEXT_LENGTH, setting MAX_TOKENS < CONTEXT_LENGTH"
+        )
 
     # OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-stop
     STOP = get_env_or_none("STOP")
@@ -48,7 +51,11 @@ def get_config(
     top_p = body.top_p if body.top_p else TOP_P
     temperature = body.temperature if body.temperature else TEMPERATURE
     repetition_penalty = (
-        body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY
+        body.frequency_penalty
+        if body.frequency_penalty
+        else (
+            body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY
+        )
     )
     last_n_tokens = body.last_n_tokens if body.last_n_tokens else LAST_N_TOKENS
     seed = body.seed if body.seed else SEED
diff --git a/request_body.py b/request_body.py
index d6e3859..ce2999e 100644
--- a/request_body.py
+++ b/request_body.py
@@ -24,6 +24,7 @@ class CompletionRequestBody(BaseModel):
     # llama.cpp specific parameters
     top_k: Optional[int]
     repetition_penalty: Optional[float]
+    frequency_penalty: Optional[float]
    last_n_tokens: Optional[int]
     seed: Optional[int]
     batch_size: Optional[int]
@@ -32,7 +33,6 @@ class CompletionRequestBody(BaseModel):
     # ignored or currently unsupported
     suffix: Any
     presence_penalty: Any
-    frequency_penalty: Any
     echo: Any
     n: Any
     logprobs: Any
@@ -73,6 +73,7 @@ class ChatCompletionRequestBody(BaseModel):
     # llama.cpp specific parameters
     top_k: Optional[int]
     repetition_penalty: Optional[float]
+    frequency_penalty: Optional[float]
     last_n_tokens: Optional[int]
     seed: Optional[int]
     batch_size: Optional[int]
@@ -83,7 +84,6 @@ class ChatCompletionRequestBody(BaseModel):
     logit_bias: Any
     user: Any
     presence_penalty: Any
-    frequency_penalty: Any
 
     class Config:
         arbitrary_types_allowed = True
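
Reviewer note (not part of the patch): a minimal, standalone sketch of the
precedence the change above gives the penalty value, assuming the same
truthiness-based fallbacks used in get_config.py. The helper name
resolve_repetition_penalty and the 1.1 default are illustrative only; the
real default comes from the REPETITION_PENALTY environment variable.

    from typing import Optional

    def resolve_repetition_penalty(
        frequency_penalty: Optional[float],
        repetition_penalty: Optional[float],
        env_default: float = 1.1,  # hypothetical default; stands in for REPETITION_PENALTY
    ) -> float:
        # Mirrors the nested conditional expression added in get_config.py:
        # body.frequency_penalty wins, then body.repetition_penalty,
        # then the environment-derived default.
        if frequency_penalty:
            return frequency_penalty
        if repetition_penalty:
            return repetition_penalty
        return env_default

    assert resolve_repetition_penalty(1.3, 1.2) == 1.3   # frequency_penalty takes precedence
    assert resolve_repetition_penalty(None, 1.2) == 1.2  # falls back to repetition_penalty
    assert resolve_repetition_penalty(None, None) == 1.1 # falls back to the default

Because the fallback tests truthiness rather than checking "is not None", an
explicit 0.0 in the request body is treated the same as an omitted field.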