fix: docs
danielhjz committed Feb 2, 2024
1 parent d9ce152 commit 319a93f
Showing 2 changed files with 50 additions and 5 deletions.
53 changes: 48 additions & 5 deletions docs/trainer.md
@@ -6,6 +6,7 @@
![trainer](./imgs/trainer.png)
## Quick Start

### Finetune
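The following uses `LLMFinetune` (which corresponds to SFT for large language models on the Qianfan platform) as an example to show how to train with `Trainer`.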

```python
@@ -23,14 +24,56 @@ ds: Dataset = Dataset.load(qianfan_dataset_id=111, is_download_to_local=False)
# Create an LLMFinetune trainer; at minimum, pass train_type and dataset.
# Note: fine-tune tasks require an annotated, non-ranked conversation dataset.
trainer = LLMFinetune(
-    train_type="ERNIE-Bot-turbo-0725",
+    train_type="ERNIE-Speed",
    dataset=ds,
)

trainer.run()
```

## Custom training parameters
### PostPretrain
Besides fine-tuning a model with `LLMFinetune`, we can also run post-pretraining with `PostPreTrain`:

```python
from qianfan.trainer import PostPreTrain, LLMFinetune
from qianfan.trainer.configs import TrainConfig
from qianfan.trainer.consts import PeftType
from qianfan.dataset import Dataset

# Generic text dataset
ds = Dataset.load(qianfan_dataset_id="ds-ag138", is_download_to_local=False)

# Post-pretrain
trainer = PostPreTrain(
    train_type="ERNIE-Speed",
    dataset=ds,
)
trainer.run()
# At this point the information of the completed PostPretrain task is available:
print(trainer.output)


# SFT dataset
sft_ds = Dataset.load(qianfan_dataset_id="ds-47j7ztjxfz60wb8x", is_download_to_local=False)
ppt_sft_trainer = LLMFinetune(
    train_type="ERNIE-Speed",
    dataset=sft_ds,
    train_config=TrainConfig(
        epoch=1,
        learning_rate=0.00003,
        max_seq_len=4096,
        peft_type=PeftType.ALL,
    ),
    name="qianfantrainer01",
    previous_trainer=trainer,
)

ppt_sft_trainer.run()
# Get the final model, which can be deployed for inference:
print(ppt_sft_trainer.output)
```
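
The `previous_trainer` argument in the example chains the two stages, so the SFT run builds on the post-pretrained result. As a rough mental model only, that hand-off can be sketched without the SDK. `StageSketch`, its fields, and the `"base-model"` label are hypothetical names made up for this illustration; they are not part of qianfan, and the real trainers may work differently internally.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class StageSketch:
    """Hypothetical stand-in for a training stage; not a qianfan class."""

    name: str
    previous: Optional["StageSketch"] = None
    output: Dict[str, Any] = field(default_factory=dict)

    def run(self) -> "StageSketch":
        # A chained stage needs its predecessor to have produced an output.
        if self.previous is not None and not self.previous.output:
            raise RuntimeError(f"run {self.previous.name!r} before {self.name!r}")
        # Start from the predecessor's model, or from an assumed base model.
        base = self.previous.output["model"] if self.previous else "base-model"
        self.output = {"model": f"{base}+{self.name}", "base": base}
        return self


ppt = StageSketch("post_pretrain").run()
sft = StageSketch("sft", previous=ppt).run()
print(sft.output["model"])  # base-model+post_pretrain+sft
```

In this sketch, running the second stage before the first raises an error, which mirrors why the example calls `trainer.run()` before constructing and running `ppt_sft_trainer`.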

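### Custom training parameters
To customize training parameters, pass a `TrainConfig` suited to the model to specify the parameters used during training. Note that different models support different parameters; the API documentation is authoritative.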
```python
import os
@@ -43,7 +86,7 @@ from qianfan.trainer import LLMFinetune
from qianfan.trainer.configs import TrainConfig

trainer = LLMFinetune(
-    train_type="ERNIE-Bot-turbo-0516",
+    train_type="ERNIE-Speed",
    dataset=ds,
    train_config=TrainConfig(
        epochs=1,  # Number of epochs: controls how many training passes are run.
@@ -54,7 +97,7 @@ trainer = LLMFinetune(
trainer.run()
```
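
Because supported parameters vary by model, it can help to check a configuration before submitting a job. The following is a minimal sketch of such a check in plain Python. The `SUPPORTED_PARAMS` table, the model names `model-a` and `model-b`, and the helper `unsupported_params` are all hypothetical; the real per-model limits come from the API documentation, not from this table.

```python
from typing import Any, Dict, Set

# Hypothetical support table for illustration only.
SUPPORTED_PARAMS: Dict[str, Set[str]] = {
    "model-a": {"epochs", "learning_rate", "batch_size"},
    "model-b": {"epochs", "learning_rate", "max_seq_len"},
}


def unsupported_params(model: str, params: Dict[str, Any]) -> Set[str]:
    """Return the parameter names that the given model does not accept."""
    if model not in SUPPORTED_PARAMS:
        raise KeyError(f"unknown model: {model}")
    return set(params) - SUPPORTED_PARAMS[model]


config = {"epochs": 1, "learning_rate": 0.00003, "max_seq_len": 4096}
print(unsupported_params("model-a", config))  # {'max_seq_len'}
print(unsupported_params("model-b", config))  # set()
```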

-## Event callbacks
+### Event callbacks

To monitor the state of each node in every stage during training, use an event callback function.
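
The callback idea can be sketched independently of the SDK: a handler object receives one event per state change and decides what to do with it. `EventSketch`, `HandlerSketch`, `RecordingHandler`, and `run_stages` are hypothetical names for this illustration; the SDK's own `Event` and `EventHandler` types may carry different fields and method signatures.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EventSketch:
    """Hypothetical event record: which stage changed, and to what state."""

    stage: str
    state: str


class HandlerSketch:
    """Hypothetical base class: subclasses override dispatch()."""

    def dispatch(self, event: EventSketch) -> None:
        raise NotImplementedError


class RecordingHandler(HandlerSketch):
    def __init__(self) -> None:
        self.seen: List[str] = []

    def dispatch(self, event: EventSketch) -> None:
        # Keep one entry per event so progress can be inspected later.
        self.seen.append(f"{event.stage}:{event.state}")


def run_stages(stages: List[str], handler: HandlerSketch) -> None:
    """Emit a started and a finished event for each stage, in order."""
    for stage in stages:
        handler.dispatch(EventSketch(stage, "started"))
        handler.dispatch(EventSketch(stage, "finished"))


handler = RecordingHandler()
run_stages(["load_data", "train"], handler)
print(handler.seen)
```

Keeping the handler separate from the training loop means monitoring logic (logging, alerting, early exit) can change without touching the code that runs the stages.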

@@ -80,7 +123,7 @@ class MyEventHandler(EventHandler):

eh = MyEventHandler()
trainer = LLMFinetune(
-    train_type="Llama-2-13b",
+    train_type="ERNIE-Speed",
    dataset=ds,
    train_config=TrainConfig(
        epochs=1,
2 changes: 2 additions & 0 deletions python/qianfan/trainer/__init__.py
@@ -20,6 +20,7 @@
)
from qianfan.trainer.event import Event, EventHandler
from qianfan.trainer.finetune import LLMFinetune, Trainer
from qianfan.trainer.post_pretrain import PostPreTrain

__all__ = [
"LLMFinetune",
@@ -31,4 +32,5 @@
"LoadDataSetAction",
"DeployAction",
"ModelPublishAction",
"PostPreTrain",
]
