feat[trainer]: optimize train_config (#135)
* fix: train_config

* fix: ut

* chore: add more logs for training

* fix: lint

* fix: pefttype
danielhjz authored Dec 14, 2023
1 parent aea6cb5 commit 7f8d9c0
Showing 8 changed files with 350 additions and 51 deletions.
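The user-facing change in this commit: TrainConfig.peft_type now takes the PeftType enum instead of a raw string, and TrainAction validates hyperparameters against per-model limits at construction time, warning on out-of-range values. A minimal sketch of the new configuration style (field values taken from the notebook diff below):

from qianfan.trainer.configs import TrainConfig
from qianfan.trainer.consts import PeftType

# peft_type was previously passed as a raw string ("LoRA");
# it is now the PeftType enum from qianfan.trainer.consts.
config = TrainConfig(
    epoch=1,
    learning_rate=0.0003,
    max_seq_len=4096,
    peft_type=PeftType.LoRA,
)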
3 changes: 2 additions & 1 deletion cookbook/finetune/trainer_finetune.ipynb
@@ -272,14 +272,15 @@
"metadata": {},
"outputs": [],
"source": [
"from qianfan.trainer.consts import PeftType\n",
"\n",
"trainer = LLMFinetune(\n",
" train_type=\"ERNIE-Bot-turbo-0725\",\n",
" train_config=TrainConfig(\n",
" epoch=1,\n",
" learning_rate=0.0003,\n",
" max_seq_len=4096,\n",
" peft_type=\"LoRA\",\n",
" peft_type=PeftType.LoRA,\n",
" ),\n",
" dataset=ds,\n",
")"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "qianfan"
version = "0.2.2"
version = "0.2.3"
description = "文心千帆大模型平台 Python SDK"
authors = []
license = "Apache-2.0"
3 changes: 2 additions & 1 deletion src/qianfan/tests/trainer_test.py
@@ -20,7 +20,7 @@
TrainAction,
)
from qianfan.trainer.configs import DeployConfig, TrainConfig
from qianfan.trainer.consts import ServiceType
from qianfan.trainer.consts import PeftType, ServiceType
from qianfan.trainer.event import Event, EventHandler
from qianfan.trainer.finetune import LLMFinetune
from qianfan.trainer.model import Model, Service
@@ -86,6 +86,7 @@ def test_trainer_sft_run():
learning_rate=0.00002,
max_seq_len=4096,
trainset_rate=20,
peft_type=PeftType.ALL,
)
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"test", console_consts.DataTemplateType.NonSortedConversation
2 changes: 1 addition & 1 deletion src/qianfan/tests/utils/mock_server.py
@@ -634,7 +634,7 @@ def get_finetune_job():
"trainMode": "SFT",
"peftType": "LoRA",
"trainStatus": "FINISH",
"progress": 0,
"progress": 51,
"runTime": 2525,
"trainTime": 732,
"startTime": "2023-12-07 11:40:00",
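Note: the bumped progress value (51) presumably exercises the new progress-gated logging added in actions.py below, which only prints the VDL report link once training progress reaches 50%.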
138 changes: 129 additions & 9 deletions src/qianfan/trainer/actions.py
@@ -24,9 +24,12 @@
BaseAction,
with_event,
)
from qianfan.trainer.configs import DefaultTrainConfigMapping, DeployConfig, TrainConfig
from qianfan.trainer.consts import (
ModelTypeMapping,
from qianfan.trainer.configs import (
DefaultTrainConfigMapping,
DeployConfig,
ModelInfoMapping,
TrainConfig,
TrainLimit,
)
from qianfan.trainer.model import Model
from qianfan.utils import (
@@ -220,16 +223,22 @@ def __init__(
raise InvalidArgumentError("train_type must be specified")
# train from base model
self.train_type = train_type
self.base_model = (
ModelTypeMapping.get(self.train_type)
if base_model is None
else base_model
)
if base_model is None:
model_info = ModelInfoMapping.get(self.train_type)
if model_info is None:
raise InvalidArgumentError(
"base_model_type must be specified caused train_type:"
f" {self.train_type} is not found"
)
self.base_model = model_info.base_model_type
else:
self.base_model = base_model
self.train_config = (
train_config
if train_config is not None
else self.get_default_train_config(train_type)
)
self.validateTrainConfig()
if train_mode is not None:
self.train_mode = train_mode
self.task_name = (
@@ -240,6 +249,112 @@
self.task_description = task_description
self.job_description = job_description

def validateTrainConfig(self) -> None:
"""
validate train_config against the parameter limits in ModelInfoMapping
Raises:
InvalidArgumentError: if train_config is None
"""
if self.train_config is None:
raise InvalidArgumentError("train_config is none")
if self.train_type not in ModelInfoMapping:
log_warn(
f"[train_action] train_type {self.train_type} not found, it may be not"
" supported"
)
else:
train_type_model_info = ModelInfoMapping[self.train_type]
if (
self.train_config.peft_type
not in train_type_model_info.support_peft_types
):
log_warn(
f"[train_action] train_type {self.train_type}, peft_type"
f" {self.train_config.peft_type} not found, it may be not supported"
)
else:
if (
train_type_model_info.specific_peft_types_params_limit is not None
and self.train_config.peft_type
in train_type_model_info.specific_peft_types_params_limit
):
self._validate_train_config(
train_type_model_info.specific_peft_types_params_limit[
self.train_config.peft_type
],
)
else:
self._validate_train_config(
train_type_model_info.common_params_limit
)

def _validate_train_config(self, train_limit: TrainLimit) -> None:
"""
validate train_config against a specific train_limit
Args:
train_limit (TrainLimit): parameter limits for the matched train_type and peft_type
Raises:
InvalidArgumentError: if train_config is None
"""
if self.train_config is None:
raise InvalidArgumentError("train_config to validate is none")
if (
self.train_config.batch_size
and train_limit.batch_size_limit
and not (
train_limit.batch_size_limit[0]
<= self.train_config.batch_size
<= train_limit.batch_size_limit[1]
)
):
log_warn(
f"[train_action] current batch_size: {self.train_config.batch_size},"
f" but suggested batch size in [{train_limit.batch_size_limit[0]},"
f" {train_limit.batch_size_limit[1]}]"
)
if (
self.train_config.epoch
and train_limit.epoch_limit
and not (
train_limit.epoch_limit[0]
<= self.train_config.epoch
<= train_limit.epoch_limit[1]
)
):
log_warn(
f"[train_action] current epoch: {self.train_config.epoch}, but"
f" suggested epoch in [{train_limit.epoch_limit[0]},"
f" {train_limit.epoch_limit[1]}]"
)
if (
self.train_config.max_seq_len
and train_limit.max_seq_len_options
and self.train_config.max_seq_len not in train_limit.max_seq_len_options
):
log_warn(
f"[train_action] current max_seq_len: {self.train_config.max_seq_len},"
f" but supported max_seq_len may be [{train_limit.max_seq_len_options}]"
)
if (
self.train_config.learning_rate
and train_limit.learning_rate_limit
and not (
train_limit.learning_rate_limit[0]
<= self.train_config.learning_rate
<= train_limit.learning_rate_limit[1]
)
):
log_warn(
"[train_action] current learning rate:"
f" {self.train_config.learning_rate}, but suggested learning rate in"
f" [{train_limit.learning_rate_limit[0]},"
f" {train_limit.learning_rate_limit[1]}]"
)

def _exec_incremental(
self, input: Dict[str, Any], **kwargs: Dict
) -> Dict[str, Any]:
@@ -349,10 +464,15 @@ def _wait_model_trained(self, **kwargs: Dict) -> None:
**kwargs,
)
job_status = job_status_resp["result"]["trainStatus"]
job_progress = job_status_resp["result"]["progress"]
log_info(
f"[train_action] fine-tune running... current status: {job_status},"
f" check vdl report in {job_status_resp['result']['vdlLink']}"
f" {job_progress}% "
"check train log in"
f" https://console.bce.baidu.com/qianfan/train/sft/{self.task_id}/{self.job_id}/detail/traininglog"
)
if job_progress >= 50:
log_info(f" check vdl report in {job_status_resp['result']['vdlLink']}")
self.action_event(ActionState.Running, "train running", job_status_resp)
if job_status == console_consts.TrainStatus.Finish:
break
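A sketch of how the new validation surfaces in practice, assuming hypothetical out-of-range values and that TrainAction is constructed with only train_type and train_config (any other constructor arguments are omitted here):

from qianfan.trainer.actions import TrainAction
from qianfan.trainer.configs import TrainConfig
from qianfan.trainer.consts import PeftType

# __init__ now calls validateTrainConfig(), which checks each field against
# the model's TrainLimit and emits log_warn for out-of-range values, e.g.
# "[train_action] current epoch: ..., but suggested epoch in [...]"
# (the exact suggested ranges come from ModelInfoMapping).
action = TrainAction(
    train_type="ERNIE-Bot-turbo-0725",
    train_config=TrainConfig(
        epoch=100,           # hypothetical value, assumed above epoch_limit
        learning_rate=0.5,   # hypothetical value, assumed above learning_rate_limit
        peft_type=PeftType.LoRA,
    ),
)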