
Commit 774c24d
Fall back repetition_penalty to body.frequency_penalty if present
Signed-off-by: Hung-Han (Henry) Chen <chenhungh@gmail.com>
chenhunghan committed Nov 3, 2023
1 parent d84b08e
Showing 2 changed files with 12 additions and 5 deletions.
get_config.py (13 changes: 10 additions & 3 deletions)
@@ -8,6 +8,7 @@
 
 THREADS = int(get_env("THREADS", str(get_default_thread())))
 
+
 def get_config(
     body: CompletionRequestBody | ChatCompletionRequestBody,
 ) -> Config:
@@ -28,8 +29,10 @@ def get_config(
     # OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-max_tokens
     MAX_TOKENS = int(get_env("MAX_TOKENS", DEFAULT_MAX_TOKENS))
     CONTEXT_LENGTH = int(get_env("CONTEXT_LENGTH", DEFAULT_CONTEXT_LENGTH))
-    if (MAX_TOKENS > CONTEXT_LENGTH):
-        log.warning("MAX_TOKENS is greater than CONTEXT_LENGTH, setting MAX_TOKENS < CONTEXT_LENGTH")
+    if MAX_TOKENS > CONTEXT_LENGTH:
+        log.warning(
+            "MAX_TOKENS is greater than CONTEXT_LENGTH, setting MAX_TOKENS < CONTEXT_LENGTH"
+        )
     # OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-stop
     STOP = get_env_or_none("STOP")
 
@@ -48,7 +51,11 @@ def get_config(
     top_p = body.top_p if body.top_p else TOP_P
     temperature = body.temperature if body.temperature else TEMPERATURE
     repetition_penalty = (
-        body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY
+        body.frequency_penalty
+        if body.frequency_penalty
+        else (
+            body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY
+        )
     )
     last_n_tokens = body.last_n_tokens if body.last_n_tokens else LAST_N_TOKENS
     seed = body.seed if body.seed else SEED
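
The resolved value now prefers body.frequency_penalty, then body.repetition_penalty, then the env-derived default. A minimal runnable sketch of that precedence, with a hypothetical default value standing in for the project's real get_env() plumbing:

    from typing import Optional

    REPETITION_PENALTY = 1.1  # hypothetical default; the real one comes from env config

    def resolve_repetition_penalty(
        frequency_penalty: Optional[float],
        repetition_penalty: Optional[float],
    ) -> float:
        # Same precedence as the hunk above: frequency_penalty when set,
        # then repetition_penalty, then the configured default.
        if frequency_penalty:
            return frequency_penalty
        if repetition_penalty:
            return repetition_penalty
        return REPETITION_PENALTY

    assert resolve_repetition_penalty(1.3, 1.5) == 1.3   # frequency_penalty wins
    assert resolve_repetition_penalty(None, 1.5) == 1.5  # falls back to repetition_penalty
    assert resolve_repetition_penalty(None, None) == 1.1 # falls back to the default

As in the diff, the checks are truthiness-based, so an explicit 0.0 falls through to the next fallback rather than being honored as a value.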
request_body.py (4 changes: 2 additions & 2 deletions)
@@ -24,6 +24,7 @@ class CompletionRequestBody(BaseModel):
     # llama.cpp specific parameters
     top_k: Optional[int]
     repetition_penalty: Optional[float]
+    frequency_penalty: Optional[float]
     last_n_tokens: Optional[int]
     seed: Optional[int]
     batch_size: Optional[int]
@@ -32,7 +33,6 @@ class CompletionRequestBody(BaseModel):
     # ignored or currently unsupported
     suffix: Any
     presence_penalty: Any
-    frequency_penalty: Any
     echo: Any
     n: Any
     logprobs: Any
@@ -73,6 +73,7 @@ class ChatCompletionRequestBody(BaseModel):
     # llama.cpp specific parameters
     top_k: Optional[int]
     repetition_penalty: Optional[float]
+    frequency_penalty: Optional[float]
     last_n_tokens: Optional[int]
     seed: Optional[int]
     batch_size: Optional[int]
@@ -83,7 +84,6 @@ class ChatCompletionRequestBody(BaseModel):
     logit_bias: Any
     user: Any
     presence_penalty: Any
-    frequency_penalty: Any
 
     class Config:
         arbitrary_types_allowed = True
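
Moving frequency_penalty out of the ignored block and into the typed llama.cpp parameters means an OpenAI-style client that sends only frequency_penalty now reaches the fallback in get_config(). A small sketch, assuming pydantic v1-style models (consistent with the class Config: arbitrary_types_allowed idiom above; explicit None defaults are added here for brevity and are equivalent to the diff's bare Optional fields under pydantic v1):

    from typing import Optional
    from pydantic import BaseModel

    class ChatCompletionRequestBody(BaseModel):
        # Trimmed stand-in for the full model above; only the two relevant fields.
        repetition_penalty: Optional[float] = None
        frequency_penalty: Optional[float] = None

    body = ChatCompletionRequestBody.parse_obj({"frequency_penalty": 1.3})
    assert body.frequency_penalty == 1.3
    assert body.repetition_penalty is None  # get_config() would now resolve to 1.3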
