Skip to content

Commit

Permalink
Trainer ops (#112)
Browse files Browse the repository at this point in the history
* fix: model service

* feat trainer cookbook
---------

Co-authored-by: zhonghanjun <zhonghanjun@baidu.com>
  • Loading branch information
Dobiichi-Origami and danielhjz authored Nov 30, 2023
1 parent 144a2ca commit cfbe526
Show file tree
Hide file tree
Showing 11 changed files with 496 additions and 71 deletions.
41 changes: 35 additions & 6 deletions cookbook/RAG/baidu_elasticsearch/qianfan_baidu_elasticsearch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@
"本文主要介绍基于Langchain的框架,结合BES的向量数据库的能力,对接千帆平台的模型管理和应用接入的能力,从而构建一个RAG的知识问答场景。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bes,1\n",
"faiss,1\n",
"chroma, 1\n",
"\n",
"postgresql\n",
"milvus, \n",
"pincone2"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -73,7 +88,7 @@
"\n",
"loader = TextLoader(\"example_data/ai-paper.pdf\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=768, chunk_overlap=0, separators=[\"\\n\\n\", \"\\n\", \" \", \"\", \"\", \"\"])\n",
"text_splitter = CharacterTextSplitter(chunk_size=768, chunk_overlap=0, separators=[\"\\n\\n\", \"\\n\", \" \", \"\", \"\", \"\"]) # spaciy\n",
"docs = text_splitter.split_documents(documents)"
]
},
Expand All @@ -92,9 +107,13 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import QianfanEmbeddingsEndpoint\n",
"from langchain.embeddings import QianfanEmbeddingsEndpoint #sdk\n",
"\n",
"embeddings = QianfanEmbeddingsEndpoint()"
"embeddings = QianfanEmbeddingsEndpoint()\n",
"# embeddings-v1\n",
"# bge-large-zh\n",
"# 12月2k\n",
"# "
]
},
{
Expand Down Expand Up @@ -123,6 +142,15 @@
"bes.client.indices.refresh(index=\"your vector index\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Faiss\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand All @@ -141,6 +169,7 @@
"from langchain.chat_models import QianfanChatEndpoint\n",
"\n",
"qianfan_chat_model = QianfanChatEndpoint(model=\"ERNIE-Bot\")\n",
"# sdk prompt load from qianfan\n",
"qa = RetrievalQA.from_chain_type(llm=llm, chain_type=\"refine\", retriever=retriever, return_source_documents=True)\n",
"\n",
"\n",
Expand Down Expand Up @@ -171,18 +200,18 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.17"
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
"hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb"
}
}
},
Expand Down
4 changes: 2 additions & 2 deletions cookbook/eb_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "py311",
"display_name": "base",
"language": "python",
"name": "python3"
},
Expand All @@ -223,7 +223,7 @@
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "f553a591cb5da27fa30e85168a93942a1a24c8d6748197473adb125e5473a5db"
"hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb"
}
}
},
Expand Down
268 changes: 268 additions & 0 deletions cookbook/finetune/trainer-finetune.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/qianfan/dataset/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,10 @@ def release_dataset(self) -> bool:
Returns:
bool: Whether releasing succeeded
"""
info = Data.get_dataset_info(self.id)["result"]["versionInfo"]
status = info["releaseStatus"]
if status == DataReleaseStatus.Finished:
return True
Data.release_dataset(self.id)
while True:
sleep(get_config().RELEASE_STATUS_POLLING_INTERVAL)
Expand Down
8 changes: 7 additions & 1 deletion src/qianfan/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def save(
qianfan_dataset_id: Optional[int] = None,
qianfan_dataset_create_args: Optional[Dict[str, Any]] = None,
schema: Optional[Schema] = None,
replace_source: bool = False,
**kwargs: Any,
) -> bool:
"""
Expand All @@ -379,6 +380,8 @@ def save(
default to None
schema: (Optional[Schema]):
schema used to validate before exporting data, default to None
replace_source: (bool):
if replace the original source, default to False
kwargs (Any): optional arguments
Returns:
Expand Down Expand Up @@ -416,7 +419,10 @@ def save(
kwargs["is_annotated"] = schema.is_annotated

# 开始写入数据
return self._to_source(source, **kwargs) # noqa
res = self._to_source(source, **kwargs) # noqa
if res and replace_source:
self.inner_data_source_cache = source
return res

@classmethod
def create_from_pyobj(
Expand Down
9 changes: 8 additions & 1 deletion src/qianfan/tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from qianfan.trainer.consts import ServiceType
from qianfan.trainer.event import Event, EventHandler
from qianfan.trainer.finetune import LLMFinetune
from qianfan.trainer.model import Model
from qianfan.trainer.model import Model, Service


class MyEventHandler(EventHandler):
Expand Down Expand Up @@ -148,3 +148,10 @@ def test_model_deploy():

resp = svc.exec({"messages": [{"content": "hi", "role": "user"}]})
assert resp["result"] != ""


def test_service():
svc = Service(model="ERNIE-Bot", service_type=ServiceType.Chat)
resp = svc.exec({"messages": [{"content": "hi", "role": "user"}]})
assert resp is not None
assert resp["result"] != ""
13 changes: 5 additions & 8 deletions src/qianfan/trainer/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from qianfan.resources.console import consts as console_consts
from qianfan.trainer.base import (
BaseAction,
EventHandler,
with_event,
)
from qianfan.trainer.configs import DefaultTrainConfigMapping, DeployConfig, TrainConfig
Expand Down Expand Up @@ -55,10 +54,9 @@ class LoadDataSetAction(BaseAction[Dict[str, Any], Dict[str, Any]]):
def __init__(
self,
dataset: Optional[Dataset] = None,
event_handler: Optional[EventHandler] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
super().__init__(event_handler=event_handler)
super().__init__(**kwargs)
self.dataset = dataset

@with_event
Expand All @@ -73,6 +71,7 @@ def exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]:
)
log_debug("[load_dataset_action] prepare train-set")
qf_data_src = cast(QianfanDataSource, self.dataset.inner_data_source_cache)
print("==>")
is_released = qf_data_src.release_dataset()
if not is_released:
raise InvalidArgumentError("dataset must be released")
Expand Down Expand Up @@ -365,16 +364,14 @@ class DeployAction(BaseAction[Dict[str, Any], Dict[str, Any]]):
model_id: Optional[int]
model_version_id: Optional[int]

def __init__(
self, deploy_config: Optional[DeployConfig] = None, **kwargs: Dict[str, Any]
):
def __init__(self, deploy_config: Optional[DeployConfig] = None, **kwargs: Any):
"""
Parameters:
deploy_config (Optional[DeployConfig], optional):
deploy config include replicas and so on. Defaults to None.
"""
super().__init__(kwargs=kwargs)
super().__init__(**kwargs)
self.deploy_config = deploy_config

@with_event
Expand Down
39 changes: 37 additions & 2 deletions src/qianfan/trainer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def __init__(
event_handler (Optional[EventHandler], optional):
event_handler implements for action state track. Defaults to None.
"""
self.id = id if id is not None else utils.uuid()
self.name = name if name is not None else f"actions_{self.id}"
self.id = id if id is not None else utils.generate_letter_num_random_id()
self.name = name if name is not None else f"action_{self.id}"
self.state = ActionState.Preceding
self.event_dispatcher = event_handler

Expand Down Expand Up @@ -192,6 +192,10 @@ def action_event(self, state: ActionState, msg: str = "", data: Any = None) -> N
),
)

@classmethod
def action_type(cls) -> str:
return "base"


def with_event(func: Callable[..., Any]) -> Callable[..., Any]:
"""
Expand Down Expand Up @@ -329,6 +333,22 @@ def stop(self, **kwargs: Dict) -> None:

return super().stop()

def register_event_handler(
self, event_handler: EventHandler, action_id: Optional[str] = None
) -> None:
"""
Register the event handler to specific the action.
Args:
event_handler (EventHandler): The event handler instance.
"""
self.event_dispatcher = event_handler
for id, action in self.actions.items():
if action_id is None and id == action_id:
action.event_dispatcher == event_handler
break
else:
action.event_dispatcher = event_handler


class Trainer(ABC):
"""
Expand Down Expand Up @@ -397,3 +417,18 @@ def get_log(self) -> Any:
Receive the training log during the pipeline execution. [coming soon].
"""
raise NotImplementedError("trainer get_log")

def register_event_handler(
self, event_handler: EventHandler, ppl_id: Optional[str] = None
) -> None:
"""
Register the event handler to specific the ppls.
Args:
event_handler (EventHandler): The event handler instance.
"""
for ppl in self.ppls:
if ppl_id is None and ppl.id == ppl_id:
ppl.register_event_handler(event_handler)
break
else:
ppl.register_event_handler(event_handler)
22 changes: 21 additions & 1 deletion src/qianfan/trainer/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Any, Dict

from qianfan.resources import ChatCompletion, Completion, Embedding, Text2Image


class ActionState(str, Enum):
Expand All @@ -36,7 +39,15 @@ class ActionState(str, Enum):
class FinetuneStatus(str, Enum):
Unknown = "Unknown"
"""未知状态"""
Created = "Created"
DatasetLoading = "DatasetLoading"
"""数据集加载中"""
DatasetLoaded = "DatasetLoaded"
"""数据集加载完成"""
DatasetLoadFailed = "DatasetLoadFailed"
"""数据集加载失败"""
DatasetLoadStopped = "DatasetLoadStopped"
"""数据集停止加载"""
TrainCreated = "TrainCreated"
"""任务创建,初始化"""
Training = "Training"
"""训练中 对应训练任务运行时API状态的Running"""
Expand Down Expand Up @@ -92,3 +103,12 @@ class ServiceType(str, Enum):
"""Corresponding to the `Embedding`"""
Text2Image = "Text2Image"
"""Corresponding to the `Text2Image"""


# service type -> resources class
ServiceTypeResourcesMapping: Dict[ServiceType, Any] = {
ServiceType.Chat: ChatCompletion,
ServiceType.Completion: Completion,
ServiceType.Embedding: Embedding,
ServiceType.Text2Image: Text2Image,
}
13 changes: 10 additions & 3 deletions src/qianfan/trainer/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(
event_handler=event_handler,
**kwargs,
)
self.model_publish = ModelPublishAction()
self.model_publish = ModelPublishAction(event_handler=event_handler)

actions = [
self.load_data_action,
Expand All @@ -131,7 +131,7 @@ def __init__(
if deploy_config is not None:
self.deploy_action = DeployAction(
deploy_config=deploy_config,
**{"event_handler": event_handler, **kwargs},
event_handler=event_handler,
)
actions.append(self.deploy_action)

Expand Down Expand Up @@ -208,8 +208,15 @@ def resume(self, **kwargs: Dict) -> "LLMFinetune":

# mapping for action state -> fine-tune status
fine_tune_action_mapping: Dict[str, Dict[str, Any]] = {
LoadDataSetAction.__class__.__name__: {
ActionState.Preceding: FinetuneStatus.DatasetLoading,
ActionState.Running: FinetuneStatus.DatasetLoading,
ActionState.Done: FinetuneStatus.DatasetLoaded,
ActionState.Error: FinetuneStatus.DatasetLoadFailed,
ActionState.Stopped: FinetuneStatus.DatasetLoadStopped,
},
TrainAction.__class__.__name__: {
ActionState.Preceding: FinetuneStatus.Created,
ActionState.Preceding: FinetuneStatus.TrainCreated,
ActionState.Running: FinetuneStatus.Training,
ActionState.Done: FinetuneStatus.TrainFinished,
ActionState.Error: FinetuneStatus.TrainFailed,
Expand Down
Loading

0 comments on commit cfbe526

Please sign in to comment.