Skip to content

Commit

Permalink
feat: support ernie-lite-v (#703)
Browse files Browse the repository at this point in the history
* feat: support ernie-lite-v

* feat: support Model API prefix consts config

* fix: docs release

* fix: models list patch
  • Loading branch information
danielhjz authored Aug 3, 2024
1 parent 4d12fe9 commit 6245d61
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/doc_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
- name: Install
run: |
pip3 install poetry
poetry lock --no-update
cd python && poetry lock --no-update
make install
- name: Build docs
run: |
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "qianfan"
version = "0.4.3"
version = "0.4.4"
description = "文心千帆大模型平台 Python SDK"
authors = []
license = "Apache-2.0"
Expand Down
1 change: 1 addition & 0 deletions python/qianfan/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Config:
SECRET_KEY: Optional[str] = Field(default=None)
ACCESS_TOKEN: Optional[str] = Field(default=None)
BASE_URL: str = Field(default=DefaultValue.BaseURL)
MODEL_API_PREFIX: str = Field(default=DefaultValue.ModelAPIPrefix)
AUTH_TIMEOUT: float = Field(default=DefaultValue.AuthTimeout)
DISABLE_EB_SDK: bool = Field(default=DefaultValue.DisableErnieBotSDK)
EB_SDK_INSTALLED: bool = Field(default=False)
Expand Down
153 changes: 77 additions & 76 deletions python/qianfan/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,82 +105,6 @@ class Env:
FileEncoding: str = "QIANFAN_FILE_ENCODING"


class DefaultValue:
    """
    Default values used by the qianfan sdk.

    Each attribute is the fallback used when the corresponding setting is
    not provided via environment variables or explicit configuration.
    """

    # Credentials default to empty strings, i.e. "not configured".
    AK: str = ""
    SK: str = ""
    ConsoleAK: str = ""
    ConsoleSK: str = ""
    AccessToken: str = ""
    BaseURL: str = "https://aip.baidubce.com"
    AuthTimeout: float = 5
    DisableErnieBotSDK: bool = True
    IAMSignExpirationSeconds: int = 300
    ConsoleAPIBaseURL: str = "https://qianfan.baidubce.com"
    AccessTokenRefreshMinInterval: float = 3600
    InferResourceRefreshMinInterval: float = 600
    # Retry policy for model inference requests.
    RetryCount: int = 3
    RetryTimeout: float = 300
    RetryBackoffFactor: float = 1
    RetryJitter: float = 1
    RetryMaxWaitInterval: float = 120
    # Retry policy for console (management) API requests.
    ConsoleRetryCount: int = 1
    ConsoleRetryTimeout: float = 60
    ConsoleRetryBackoffFactor: float = 0
    ConsoleRetryJitter: int = 1
    ConsoleRetryMaxWaitInterval: float = 120
    # Console API error codes treated as transient and therefore retried.
    ConsoleRetryErrCodes: Set = {
        APIErrorCode.ServerHighLoad.value,
        APIErrorCode.QPSLimitReached.value,
        APIErrorCode.ConsoleInternalError.value,
    }
    # 0 means "no client-side rate limit".
    QpsLimit: float = 0
    RpmLimit: float = 0
    TpmLimit: int = 0
    DotEnvConfigFile: str = ".env"

    EnablePrivate: bool = False
    AccessCode: str = ""
    # Prompt sent to continue a truncated generation ("继续" = "continue").
    TruncatedContinuePrompt = "继续"
    # Polling intervals (seconds) for various long-running console tasks.
    ImportStatusPollingInterval: float = 2
    ExportStatusPollingInterval: float = 2
    ReleaseStatusPollingInterval: float = 2
    ETLStatusPollingInterval: float = 2
    TrainStatusPollingInterval: float = 30
    TrainerStatusPollingBackoffFactor: float = 3
    TrainerStatusPollingRetryTimes: float = 3
    ModelPublishStatusPollingInterval: float = 30
    BatchRunStatusPollingInterval: float = 30
    DeployStatusPollingInterval: float = 30
    DefaultFinetuneTrainType: str = "ERNIE-Speed"
    V2InferApiDowngrade: bool = False

    # Upper bound on the uncompressed size of a Qianfan dataset that can be
    # downloaded directly to local storage.
    # TODO: later this may be replaced by a limit based on the user's
    # machine memory; currently capped at 2 GB to avoid exhausting memory.
    ExportFileSizeLimit: int = 1024 * 1024 * 1024 * 2
    GetEntityContentFailedRetryTimes: int = 3

    EvaluationOnlinePollingInterval: float = 30
    BosHostRegion: str = "bj"
    # Inference API error codes treated as transient and therefore retried.
    RetryErrCodes: Set = {
        APIErrorCode.ServiceUnavailable.value,
        APIErrorCode.ServerHighLoad.value,
        APIErrorCode.QPSLimitReached.value,
        APIErrorCode.RPMLimitReached.value,
        APIErrorCode.TPMLimitReached.value,
        APIErrorCode.AppNotExist.value,
    }
    SSLVerificationEnabled: bool = True
    Proxy: str = ""
    FileEncoding: str = "utf-8"
    CacheDir: str = str(Path.home() / ".qianfan_cache")
    DisableCache: bool = False


class Consts:
"""
Constant used by qianfan sdk
Expand Down Expand Up @@ -330,6 +254,83 @@ class Consts:
DateTimeFormat = "%Y-%m-%dT%H:%M:%SZ"


class DefaultValue:
    """
    Default values used by the qianfan sdk.

    Each attribute is the fallback used when the corresponding setting is
    not provided via environment variables or explicit configuration.
    """

    # Credentials default to empty strings, i.e. "not configured".
    AK: str = ""
    SK: str = ""
    ConsoleAK: str = ""
    ConsoleSK: str = ""
    AccessToken: str = ""
    BaseURL: str = "https://aip.baidubce.com"
    # URL path prefix for model inference APIs; sourced from Consts so it
    # can be overridden through Config.MODEL_API_PREFIX.
    ModelAPIPrefix: str = Consts.ModelAPIPrefix
    AuthTimeout: float = 5
    DisableErnieBotSDK: bool = True
    IAMSignExpirationSeconds: int = 300
    ConsoleAPIBaseURL: str = "https://qianfan.baidubce.com"
    AccessTokenRefreshMinInterval: float = 3600
    InferResourceRefreshMinInterval: float = 600
    # Retry policy for model inference requests.
    RetryCount: int = 3
    RetryTimeout: float = 300
    RetryBackoffFactor: float = 1
    RetryJitter: float = 1
    RetryMaxWaitInterval: float = 120
    # Retry policy for console (management) API requests.
    ConsoleRetryCount: int = 1
    ConsoleRetryTimeout: float = 60
    ConsoleRetryBackoffFactor: float = 0
    ConsoleRetryJitter: int = 1
    ConsoleRetryMaxWaitInterval: float = 120
    # Console API error codes treated as transient and therefore retried.
    ConsoleRetryErrCodes: Set = {
        APIErrorCode.ServerHighLoad.value,
        APIErrorCode.QPSLimitReached.value,
        APIErrorCode.ConsoleInternalError.value,
    }
    # 0 means "no client-side rate limit".
    QpsLimit: float = 0
    RpmLimit: float = 0
    TpmLimit: int = 0
    DotEnvConfigFile: str = ".env"

    EnablePrivate: bool = False
    AccessCode: str = ""
    # Prompt sent to continue a truncated generation ("继续" = "continue").
    TruncatedContinuePrompt = "继续"
    # Polling intervals (seconds) for various long-running console tasks.
    ImportStatusPollingInterval: float = 2
    ExportStatusPollingInterval: float = 2
    ReleaseStatusPollingInterval: float = 2
    ETLStatusPollingInterval: float = 2
    TrainStatusPollingInterval: float = 30
    TrainerStatusPollingBackoffFactor: float = 3
    TrainerStatusPollingRetryTimes: float = 3
    ModelPublishStatusPollingInterval: float = 30
    BatchRunStatusPollingInterval: float = 30
    DeployStatusPollingInterval: float = 30
    DefaultFinetuneTrainType: str = "ERNIE-Speed"
    V2InferApiDowngrade: bool = False

    # Upper bound on the uncompressed size of a Qianfan dataset that can be
    # downloaded directly to local storage.
    # TODO: later this may be replaced by a limit based on the user's
    # machine memory; currently capped at 2 GB to avoid exhausting memory.
    ExportFileSizeLimit: int = 1024 * 1024 * 1024 * 2
    GetEntityContentFailedRetryTimes: int = 3

    EvaluationOnlinePollingInterval: float = 30
    BosHostRegion: str = "bj"
    # Inference API error codes treated as transient and therefore retried.
    RetryErrCodes: Set = {
        APIErrorCode.ServiceUnavailable.value,
        APIErrorCode.ServerHighLoad.value,
        APIErrorCode.QPSLimitReached.value,
        APIErrorCode.RPMLimitReached.value,
        APIErrorCode.TPMLimitReached.value,
        APIErrorCode.AppNotExist.value,
    }
    SSLVerificationEnabled: bool = True
    Proxy: str = ""
    FileEncoding: str = "utf-8"
    CacheDir: str = str(Path.home() / ".qianfan_cache")
    DisableCache: bool = False


class DefaultLLMModel:
"""
Default LLM model in qianfan sdk
Expand Down
35 changes: 19 additions & 16 deletions python/qianfan/resources/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,22 +1022,25 @@ def get_latest_supported_models(

# get preset services:
for s in svc_list:
[api_type, model_endpoint] = trim_prefix(
s["url"],
"{}{}/".format(
DefaultValue.BaseURL,
Consts.ModelAPIPrefix,
),
).split("/")
model_info = _runtime_models_info.get(api_type)
if model_info is None:
model_info = {}
model_info[s["name"]] = QfLLMInfo(
endpoint="/{}/{}".format(api_type, model_endpoint),
api_type=api_type,
)
_runtime_models_info[api_type] = model_info
_last_update_time = datetime.now(timezone.utc)
try:
[api_type, model_endpoint] = trim_prefix(
s["url"],
"{}{}/".format(
DefaultValue.BaseURL,
get_config().MODEL_API_PREFIX,
),
).split("/")
model_info = _runtime_models_info.get(api_type)
if model_info is None:
model_info = {}
model_info[s["name"]] = QfLLMInfo(
endpoint="/{}/{}".format(api_type, model_endpoint),
api_type=api_type,
)
_runtime_models_info[api_type] = model_info
_last_update_time = datetime.now(timezone.utc)
except Exception:
continue
cache = KvCache()
cache.set(
key=Consts.QianfanLLMModelsListCacheKey,
Expand Down
72 changes: 72 additions & 0 deletions python/qianfan/resources/llm/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,32 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
input_price_per_1k_tokens=0.004,
output_price_per_1k_tokens=0.008,
),
"ERNIE-3.5-8K-0701": QfLLMInfo(
endpoint="/chat/ernie-3.5-8k-0701",
required_keys={"messages"},
optional_keys={
"stream",
"temperature",
"top_p",
"penalty_score",
"user_id",
"system",
"stop",
"enable_system_memory",
"system_memory_id",
"disable_search",
"enable_citation",
"enable_trace",
"max_output_tokens",
"response_format",
"functions",
"tool_choice",
},
max_input_chars=20000,
max_input_tokens=5120,
input_price_per_1k_tokens=0.12,
output_price_per_1k_tokens=0.12,
),
"ERNIE-3.5-8K-0613": QfLLMInfo(
endpoint="/chat/ernie-3.5-8k-0613",
required_keys={"messages"},
Expand Down Expand Up @@ -218,6 +244,27 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
input_price_per_1k_tokens=0.003,
output_price_per_1k_tokens=0.006,
),
"ERNIE-Lite-V": QfLLMInfo(
endpoint="/chat/ernie-lite-v",
required_keys={"messages"},
optional_keys={
"stream",
"temperature",
"top_p",
"penalty_score",
"user_id",
"system",
"stop",
"max_output_tokens",
"min_output_tokens",
"frequency_penalty",
"presence_penalty",
},
max_input_chars=11200,
max_input_tokens=7168,
input_price_per_1k_tokens=0.003,
output_price_per_1k_tokens=0.006,
),
"ERNIE-3.5-8K": QfLLMInfo(
endpoint="/chat/completions",
required_keys={"messages"},
Expand Down Expand Up @@ -564,6 +611,31 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
input_price_per_1k_tokens=0.001,
output_price_per_1k_tokens=0.001,
),
"ERNIE-Novel-8K": QfLLMInfo(
endpoint="/chat/ernie-novel-8k",
required_keys={"messages"},
optional_keys={
"stream",
"temperature",
"top_p",
"penalty_score",
"user_id",
"tools",
"tool_choice",
"system",
"stop",
"enable_system_memory",
"system_memory_id",
"max_output_tokens",
"min_output_tokens",
"frequency_penalty",
"presence_penalty",
},
max_input_chars=24000,
max_input_tokens=6144,
input_price_per_1k_tokens=0.001,
output_price_per_1k_tokens=0.001,
),
"ERNIE-Function-8K": QfLLMInfo(
endpoint="/chat/ernie-func-8k",
required_keys={"messages"},
Expand Down
14 changes: 10 additions & 4 deletions python/qianfan/resources/requestor/openapi_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Callable,
Dict,
Iterator,
List,
Optional,
TypeVar,
Union,
Expand Down Expand Up @@ -314,8 +315,13 @@ def _get_token_count_from_body(self, body: Dict[str, Any]) -> int:
content = message.get("content", None)
if not content:
continue

token_count += self._token_limiter.tokenizer.count_tokens(content)
if isinstance(content, str):
token_count += self._token_limiter.tokenizer.count_tokens(content)
elif isinstance(content, List):
for ct in content:
token_count += self._token_limiter.tokenizer.count_tokens(
ct.get("text", "")
)

if prompt:
assert isinstance(prompt, str)
Expand Down Expand Up @@ -567,7 +573,7 @@ def _llm_api_url(self, endpoint: str) -> str:
"""
return "{}{}{}".format(
get_config().BASE_URL,
Consts.ModelAPIPrefix,
get_config().MODEL_API_PREFIX,
endpoint,
)

Expand Down Expand Up @@ -699,7 +705,7 @@ def _base_llm_request(
req = QfRequest(
method="POST",
url="{}{}".format(
Consts.ModelAPIPrefix,
get_config().MODEL_API_PREFIX,
endpoint,
),
)
Expand Down

0 comments on commit 6245d61

Please sign in to comment.