使用类Transformer模型,对每条对话语句进行分类。
.
├── README.md
├── data/
├── native-pytorch
│ ├── data_loader.py
│ ├── finetune.py
│ ├── test.py
│ ├── train_from_scratch.py
│ └── utils.py
└── use_trainer
├── finetune_use_trainer.py
├── my_dataset.py
├── predict.py
└── trainer_utils.py
data
:包含了KdConv与NaturalConv混合的共6个领域的数据集,其中data.json
的数据信息如下:
体育 | 科技 | 教育 | 音乐 | 旅行 | 电影 | |
---|---|---|---|---|---|---|
# dialogues | 9740 | 4061 | 1265 | 1500 | 1500 | 1500 |
# utterances | 195643 | 81587 | 25376 | 24885 | 24093 | 36618 |
Avg. # utterances per dialogue | 20.1 | 20.1 | 20.1 | 16.6 | 16.1 | 24.4 |
native-pytorch
:使用原生PyTorch训练分类器data_loader.py
:包含dataset与dataloader的实现utils.py
:包含固定随机种子等工具函数finetune.py
:使用原生PyTorch微调预训练模型train_from_scratch.py
:使用Accelerate库从头训练模型test.py
:测试模型准确率、分辨率等指标
use_trainer
:使用🤗Transformers库Trainer训练分类器my_dataset.py
:包含基于🤗Dataset的dataset,使用了除“教育”领域的5类数据finetune_use_trainer.py
:使用trainer微调预训练模型predict.py
:使用trainer测试模型准确率、分辨率等指标trainer_utils.py
:包含一些工具函数