Skip to content

Latest commit

 

History

History
 
 

clue

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 

CLUE Benchmark

目录

CLUE 自成立以来发布了多项 NLP 评测基准,包括分类榜单,阅读理解榜单和自然语言推断榜单等,在学术界、工业界产生了深远影响。是目前应用最广泛的中文语言测评指标之一。详细可参考 CLUE论文

本项目基于 PaddlePaddle 在 CLUE 数据集上对领先的开源预训练模型模型进行了充分评测,为开发者在预训练模型选择上提供参考,同时开发者基于本项目可以轻松一键复现模型效果,也可以参加 CLUE 竞赛取得好成绩。

CLUE 评测结果

使用多种中文预训练模型微调在 CLUE 的各验证集上有如下结果:

Arch Model AVG AFQMC TNEWS IFLYTEK CMNLI OCNLI CLUEWSC2020 CSL CMRC2018 CHID C3
24L1024H ERNIE 1.0-Large-zh-cw 79.03 75.97 59.65 62.91 85.09 81.73 93.09 84.53 74.22/91.88 88.57 84.54
ERNIE 2.0-Large-zh 77.03 76.41 59.67 62.29 83.82 79.69 89.14 84.10 71.48/90.35 85.52 78.12
HFL/RoBERTa-wwm-ext-large 76.61 76.00 59.33 62.02 83.88 78.81 90.79 83.67 70.58/89.82 85.72 75.26
20L1024H ERNIE 3.0-Xbase-zh 78.39 76.16 59.55 61.87 84.40 81.73 88.82 83.60 75.99/93.00 86.78 84.98
12L768H ERNIE 3.0-Base-zh 76.05 75.93 58.26 61.56 83.02 80.10 86.18 82.63 70.71/90.41 84.26 77.88
ERNIE 1.0-Base-zh-cw 76.47 76.07 57.86 59.91 83.41 79.58 89.91 83.42 72.88/90.78 84.68 76.98
ERNIE-Gram-zh 75.72 75.28 57.88 60.87 82.90 79.08 88.82 82.83 71.82/90.38 84.04 73.69
ERNIE 2.0-Base-zh 74.32 75.65 58.25 61.64 82.62 78.71 81.91 82.33 66.08/87.46 82.78 73.19
Langboat/Mengzi-BERT-Base 74.69 75.35 57.76 61.64 82.41 77.93 88.16 82.20 67.04/88.35 83.74 70.70
ERNIE 1.0-Base-zh 74.17 74.84 58.91 62.25 81.68 76.58 85.20 82.77 67.32/87.83 82.47 69.68
HFL/RoBERTa-wwm-ext 74.11 74.60 58.08 61.23 81.11 76.92 88.49 80.77 68.39/88.50 83.43 68.03
BERT-Base-Chinese 72.57 74.63 57.13 61.29 80.97 75.22 81.91 81.90 65.30/86.53 82.01 65.38
UER/Chinese-RoBERTa-Base 71.78 72.89 57.62 61.14 80.01 75.56 81.58 80.80 63.87/84.95 81.52 62.76
8L512H UER/Chinese-RoBERTa-Medium 67.06 70.64 56.10 58.29 77.35 71.90 68.09 78.63 57.63/78.91 75.13 56.84
6L768H ERNIE 3.0-Medium-zh 72.49 73.37 57.00 60.67 80.64 76.88 79.28 81.60 65.83/87.30 79.91 69.73
HLF/RBT6, Chinese 70.06 73.45 56.82 59.64 79.36 73.32 76.64 80.67 62.72/84.77 78.17 59.85
TinyBERT6, Chinese 69.62 72.22 55.70 54.48 79.12 74.07 77.63 80.17 63.03/83.75 77.64 62.11
RoFormerV2 Small 68.52 72.47 56.53 60.72 76.37 72.95 75.00 81.07 62.97/83.64 67.66 59.41
UER/Chinese-RoBERTa-L6-H768 67.09 70.13 56.54 60.48 77.49 72.00 72.04 77.33 53.74/75.52 76.73 54.40
6L384H ERNIE 3.0-Mini-zh 66.90 71.85 55.24 54.48 77.19 73.08 71.05 79.30 58.53/81.97 69.71 58.60
4L768H HFL/RBT4, Chinese 67.42 72.41 56.50 58.95 77.34 70.78 71.05 78.23 59.30/81.93 73.18 56.45
4L512H UER/Chinese-RoBERTa-Small 63.25 69.21 55.41 57.552 73.64 69.80 66.78 74.83 46.75/69.69 67.59 50.92
4L384H ERNIE 3.0-Micro-zh 64.21 71.15 55.05 53.83 74.81 70.41 69.08 76.50 53.77/77.82 62.26 55.53
4L312H ERNIE 3.0-Nano-zh 62.97 70.51 54.57 48.36 74.97 70.61 68.75 75.93 52.00/76.35 58.91 55.11
TinyBERT4, Chinese 60.82 69.07 54.02 39.71 73.94 69.59 70.07 75.07 46.04/69.34 58.53 52.18
4L256H UER/Chinese-RoBERTa-Mini 53.40 69.32 54.22 41.63 69.40 67.36 65.13 70.07 5.96/17.13 51.19 39.68
3L1024H HFL/RBTL3, Chinese 66.63 71.11 56.14 59.56 76.41 71.29 69.74 76.93 58.50/80.90 71.03 55.56
3L768H HFL/RBT3, Chinese 65.72 70.95 55.53 59.18 76.20 70.71 67.11 76.63 55.73/78.63 70.26 54.93
2L128H UER/Chinese-RoBERTa-Tiny 44.45 69.02 51.47 20.28 59.95 57.73 63.82 67.43 3.08/14.33 23.57 28.12

AFQMC(语义相似度)、TNEWS(文本分类)、IFLYTEK(长文本分类)、CMNLI(自然语言推理)、OCNLI(自然语言推理)、CLUEWSC2020(代词消歧)、CSL(论文关键词识别)、CHID(成语阅读理解填空) 和 C3(中文多选阅读理解) 任务使用的评估指标均是 Accuracy。CMRC2018(阅读理解) 的评估指标是 EM (Exact Match)/F1,计算每个模型效果的平均值时,取 EM 为最终指标。

其中前 7 项属于分类任务,后面 3 项属于阅读理解任务,这两种任务的训练过程在下面将会分开介绍。

NOTE:具体评测方式如下

  1. 以上所有任务均基于 Grid Search 方式进行超参寻优。分类任务训练每间隔 100 steps 评估验证集效果,阅读理解任务每隔一个 epoch 评估验证集效果,取验证集最优效果作为表格中的汇报指标。

  2. 分类任务 Grid Search 超参范围: batch_size: 16, 32, 64; learning rates: 1e-5, 2e-5, 3e-5, 5e-5;因为 CLUEWSC2020 数据集较小,所以模型在该数据集上的效果对 batch_size 较敏感,所以对 CLUEWSC2020 评测时额外增加了 batch_size = 8 的超参搜索; 因为CLUEWSC2020 和 IFLYTEK 数据集对 dropout 概率值较为敏感,所以对 CLUEWSC2020 和 IFLYTEK 数据集评测时额外增加了 dropout_prob = 0.0 的超参搜索。

  3. 阅读理解任务 Grid Search 超参范围:batch_size: 24, 32; learning rates: 1e-5, 2e-5, 3e-5。阅读理解任务均使用多卡训练,其中 Grid Search 中的 batch_size 是指多张卡上的 batch_size 总和。

  4. 以上每个下游任务的固定超参配置如下表所示:

TASK AFQMC TNEWS IFLYTEK CMNLI OCNLI CLUEWSC2020 CSL CMRC2018 CHID C3
epoch 3 3 3 2 5 50 5 2 3 8
max_seq_length 128 128 128 128 128 128 256 512 64 512
warmup_proportion 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.06 0.1
num_cards 1 1 1 1 1 1 1 2 4 4

不同预训练模型在下游任务上做 Grid Search 之后的最优超参(learning_rate、batch_size)如下:

Model AFQMC TNEWS IFLYTEK CMNLI OCNLI CLUEWSC2020 CSL CMRC2018 CHID C3
ERNIE 1.0-Large-zh-cw 2e-5,64 3e-5,32 5e-5,16 2e-5,16 2e-5,32 1e-5,32 1e-5,16 2e-5,24 1e-5,24 2e-5,32
ERNIE 3.0-Xbase-zh 2e-5,16 3e-5,32 3e-5,32 3e-5,64 3e-5,64 2e-5,32 1e-5,16 3e-5,24 2e-5,24 3e-5,24
ERNIE 2.0-Large-zh 1e-5,32 3e-5,64 3e-5,32 2e-5,32 1e-5,16 3e-5,32 1e-5,64 2e-5,24 2e-5,24 3e-5,32
HFL/RoBERTa-wwm-ext-large 1e-5,32 3e-5,32 2e-5,32 1e-5,16 1e-5,16 2e-5,16 2e-5,16 3e-5,32 1e-5,24 2e-5,24
ERNIE 3.0-Base-zh 3e-5,16 3e-5,32 5e-5,32 3e-5,32 2e-5,64 2e-5,16 2e-5,32 2e-5,24 3e-5,24 3e-5,32
ERNIE 1.0-Base-zh-cw 2e-5,16 3e-5,32 5e-5,16 2e-5,16 3e-5,32 2e-5,16 2e-5,32 3e-5,24 2e-5,32 3e-5,24
ERNIE-Gram-zh 1e-5,16 5e-5,16 5e-5,16 2e-5,32 2e-5,64 3e-5,16 3e-5,64 3e-5,32 2e-5,24 2e-5,24
ERNIE 2.0-Base-zh 3e-5,64 3e-5,64 5e-5,16 5e-5,64 5e-5,32 5e-5,16 2e-5,16 2e-5,32 3e-5,24 3e-5,32
Langboat/Mengzi-Bert-Base 3e-5,32 5e-5,32 5e-5,16 2e-5,16 2e-5,16 3e-5,8 1e-5,16 3e-5,24 3e-5,24 2e-5,32
ERNIE 1.0-Base-zh 3e-5,16 3e-5,32 5e-5,16 5e-5,32 3e-5,16 2e-5,8 2e-5,16 3e-5,32 3e-5,24 3e-5,24
HFL/RoBERTa-wwm-ext 3e-5,32 3e-5,64 5e-5,16 3e-5,32 2e-5,32 3e-5,32 2e-5,32 3e-5,32 2e-5,32 3e-5,24
BERT-Base-Chinese 2e-5,16 5e-5,16 5e-5,16 5e-5,64 3e-5,16 3e-5,16 1e-5,16 3e-5,24 2e-5,32 3e-5,24
UER/Chinese-RoBERTa-Base 2e-5,16 5e-5,16 5e-5,16 2e-5,16 3e-5,16 3e-5,8 2e-5,16 3e-5,24 3e-5,32 3e-5,32
UER/Chinese-RoBERTa-Medium 3e-5,32 5e-5,64 5e-5,16 5e-5,32 3e-5,32 3e-5,16 5e-5,32 3e-5,24 3e-5,24 3e-5,32
ERNIE 3.0-Medium-zh 3e-5,32 3e-5,64 5e-5,32 2e-5,32 1e-5,64 3e-5,16 2e-5,32 3e-5,24 2e-5,24 1e-5,24
TinyBERT6, Chinese 1e-5,16 3e-5,32 5e-5,16 5e-5,32 3e-5,64 3e-5,16 3e-5,16 3e-5,32 3e-5,24 2e-5,24
RoFormerV2 Small 5e-5,16 2e-5,16 5e-5,16 5e-5,32 2e-5,16 3e-5,8 3e-5,16 3e-5,24 3e-5,24 3e-5,24
HLF/RBT6, Chinese 3e-5,16 5e-5,16 5e-5,16 5e-5,64 3e-5,16 3e-5,8 5e-5,64 2e-5,24 3e-5,32 2e-5,32
UER/Chinese-RoBERTa-L6-H768 2e-5,16 3e-5,16 5e-5,16 5e-5,16 5e-5,32 2e-5,32 3e-5,16 3e-5,32 3e-5,24 3e-5,24
ERNIE 3.0-Mini-zh 5e-5,64 5e-5,64 5e-5,16 5e-5,32 2e-5,16 2e-5,8 2e-5,16 3e-5,24 3e-5,24 3e-5,24
HFL/RBT4, Chinese 5e-5,16 5e-5,16 5e-5,16 5e-5,16 2e-5,16 2e-5,8 2e-5,16 3e-5,32 3e-5,24 3e-5,32
UER/Chinese-RoBERTa-Small 2e-5,32 5e-5,32 5e-5,16 5e-5,16 5e-5,16 2e-5,64 5e-5,32 3e-5,24 3e-5,24 3e-5,24
ERNIE 3.0-Micro-zh 3e-5,16 5e-5,32 5e-5,16 5e-5,16 2e-5,32 5e-5,16 3e-5,64 3e-5,24 3e-5,32 3e-5,24
ERNIE 3.0-Nano-zh 2e-5,32 5e-5,16 5e-5,16 5e-5,16 3e-5,16 1e-5,8 3e-5,32 3e-5,24 3e-5,24 2e-5,24
TinyBERT4, Chinese 3e-5,32 5e-5,16 5e-5,16 5e-5,16 3e-5,16 1e-5,16 5e-5,16 3e-5,24 3e-5,24 2e-5,24
UER/Chinese-RoBERTa-Mini 3e-5,16 5e-5,16 5e-5,16 5e-5,16 5e-5,32 3e-5,8 5e-5,32 3e-5,24 3e-5,32 3e-5,32
HFL/RBTL3, Chinese 5e-5,32 5e-5,16 5e-5,16 5e-5,32 2e-5,16 5e-5,8 2e-5,16 3e-5,24 2e-5,24 3e-5,24
HFL/RBT3, Chinese 5e-5,64 5e-5,32 5e-5,16 5e-5,16 2e-5,16 3e-5,16 5e-5,16 3e-5,32 3e-5,24 3e-5,32
UER/Chinese-RoBERTa-Tiny 5e-5,64 5e-5,16 5e-5,16 5e-5,16 5e-5,16 5e-5,8 5e-5,16 3e-5,24 3e-5,24 3e-5,24

其中,ERNIE 3.0-Base-zhERNIE 3.0-Medium-zhERNIE-Gram-zhERNIE 1.0-Base-zhERNIE 3.0-Mini-zhERNIE 3.0-Micro-zhERNIE 3.0-Nano-zhHFL/RBT3, ChineseHFL/RBTL3, ChineseHFL/RBT6, ChineseTinyBERT<sub>4</sub>, ChineseUER/Chinese-RoBERTa-BaseUER/Chinese-RoBERTa-MiniUER/Chinese-RoBERTa-Small 在 CLUEWSC2020 处的 dropout_prob 为 0.0,ERNIE 3.0-Base-zhHLF/RBT6, ChineseLangboat/Mengzi-BERT-BaseERNIE-Gram-zhERNIE 1.0-Base-zhTinyBERT6, ChineseUER/Chinese-RoBERTa-L6-H768ERNIE 3.0-Mini-zhERNIE 3.0-Micro-zhERNIE 3.0-Nano-zhHFL/RBT3, ChineseHFL/RBT4, ChineseHFL/RBT6, ChineseTinyBERT<sub>4</sub>, ChineseUER/Chinese-RoBERTa-MediumUER/Chinese-RoBERTa-BaseUER/Chinese-RoBERTa-MiniUER/Chinese-RoBERTa-TinyUER/Chinese-RoBERTa-Small 在 IFLYTEK 处的 dropout_prob 为 0.0。

一键复现模型效果

这一节将会对分类、阅读理解任务分别展示如何一键复现本文的评测结果。

启动 CLUE 分类任务

以 CLUE 的 TNEWS 任务为例,启动 CLUE 任务进行 Fine-tuning 的方式如下:

export CUDA_VISIBLE_DEVICES=0
export TASK_NAME=TNEWS
export LR=3e-5
export BS=32
export EPOCH=6
export MAX_SEQ_LEN=128
export MODEL_PATH=ernie-3.0-medium-zh

cd classification
mkdir ernie-3.0-medium-zh
python -u ./run_clue_classifier.py \
    --model_name_or_path ${MODEL_PATH} \
    --task_name ${TASK_NAME} \
    --max_seq_length ${MAX_SEQ_LEN} \
    --batch_size ${BS}   \
    --learning_rate ${LR} \
    --num_train_epochs ${EPOCH} \
    --logging_steps 100 \
    --seed 42  \
    --save_steps  100 \
    --warmup_proportion 0.1 \
    --weight_decay 0.01 \
    --adam_epsilon 1e-8 \
    --output_dir ${MODEL_PATH}/models/${TASK_NAME}/${LR}_${BS}/ \
    --device gpu  \
    --dropout 0.1 \
    --gradient_accumulation_steps 1 \
    --save_best_model True \
    --do_train \

另外,如需评估,传入参数 --do_eval 即可,如果只对读入的 checkpoint 进行评估不训练,则不需传入 --do_train

其中参数释义如下:

  • model_name_or_path 指示了 Fine-tuning 使用的具体预训练模型,可以是 PaddleNLP 提供的预训练模型,可以选择Transformer预训练模型汇总 中相对应的中文预训练权重。注意 CLUE 任务应选择中文预训练权重。
  • task_name 表示 Fine-tuning 的分类任务,当前支持 AFQMC、TNEWS、IFLYTEK、OCNLI、CMNLI、CSL、CLUEWSC2020。
  • max_seq_length 表示最大句子长度,超过该长度将被截断。
  • batch_size 表示每次迭代每张卡上的样本数目。
  • learning_rate 表示基础学习率大小,将于 learning rate scheduler 产生的值相乘作为当前学习率。
  • num_train_epochs 表示训练轮数。
  • logging_steps 表示日志打印间隔。
  • save_steps 表示模型保存及评估间隔。
  • save_best_model 是否保存在评估集上效果最好的模型,默认为 True
  • output_dir 表示模型保存路径。
  • device 表示训练使用的设备, 'gpu' 表示使用GPU, 'xpu' 表示使用百度昆仑卡, 'cpu' 表示使用 CPU。

Fine-tuning 过程将按照 logging_stepssave_steps 的设置打印出如下日志:

global step 100/20010, epoch: 0, batch: 99, rank_id: 0, loss: 2.734340, lr: 0.0000014993, speed: 8.7969 step/s
eval loss: 2.720359, acc: 0.0827, eval done total : 25.712125062942505 s
global step 200/20010, epoch: 0, batch: 199, rank_id: 0, loss: 2.608563, lr: 0.0000029985, speed: 2.5921 step/s
eval loss: 2.652753, acc: 0.0945, eval done total : 25.64827537536621 s
global step 300/20010, epoch: 0, batch: 299, rank_id: 0, loss: 2.555283, lr: 0.0000044978, speed: 2.6032 step/s
eval loss: 2.572999, acc: 0.112, eval done total : 25.67190170288086 s
global step 400/20010, epoch: 0, batch: 399, rank_id: 0, loss: 2.631579, lr: 0.0000059970, speed: 2.6238 step/s
eval loss: 2.476962, acc: 0.1697, eval done total : 25.794789791107178 s

使用 Trainer 启动 CLUE 分类任务

PaddleNLP 提供了 Trainer API,本示例新增了run_clue_classifier_trainer.py脚本供用户使用。

export CUDA_VISIBLE_DEVICES=0
export TASK_NAME=TNEWS
export LR=3e-5
export BS=32
export EPOCH=6
export MAX_SEQ_LEN=128
export MODEL_PATH=ernie-3.0-medium-zh

cd classification
mkdir ernie-3.0-medium-zh

python -u ./run_clue_classifier_trainer.py \
    --model_name_or_path ${MODEL_PATH} \
    --dataset "clue ${TASK_NAME}" \
    --max_seq_length ${MAX_SEQ_LEN} \
    --per_device_train_batch_size ${BS}   \
    --per_device_eval_batch_size ${BS}   \
    --learning_rate ${LR} \
    --num_train_epochs ${EPOCH} \
    --logging_steps 100 \
    --seed 42  \
    --save_steps 100 \
    --warmup_ratio 0.1 \
    --weight_decay 0.01 \
    --adam_epsilon 1e-8 \
    --output_dir ${MODEL_PATH}/models/${TASK_NAME}/${LR}_${BS}/ \
    --device gpu  \
    --do_train \
    --do_eval \
    --metric_for_best_model "eval_accuracy" \
    --load_best_model_at_end \
    --save_total_limit 3 \

大部分参数含义如上文所述,这里简要介绍一些新参数:

  • dataset, 同上文task_name,此处为小写字母。表示 Fine-tuning 的分类任务,当前支持 afamc、tnews、iflytek、ocnli、cmnli、csl、cluewsc2020。
  • per_device_train_batch_size 同上文batch_size。训练时,每次迭代每张卡上的样本数目。
  • per_device_eval_batch_size 同上文batch_size。评估时,每次迭代每张卡上的样本数目。
  • warmup_ratio 同上文warmup_proportion,warmup步数占总步数的比例。
  • metric_for_best_model 评估时,最优评估指标。
  • load_best_model_at_end 训练结束时,时候加载评估结果最好的 ckpt。
  • save_total_limit 保存的ckpt数量的最大限制

启动 CLUE 阅读理解任务

以 CLUE 的 C3 任务为例,多卡启动 CLUE 任务进行 Fine-tuning 的方式如下:

cd mrc

MODEL_PATH=ernie-3.0-medium-zh
BATCH_SIZE=6
LR=2e-5

python -m paddle.distributed.launch --gpus "0,1,2,3" run_c3.py \
    --model_name_or_path ${MODEL_PATH} \
    --batch_size ${BATCH_SIZE} \
    --learning_rate ${LR} \
    --max_seq_length 512 \
    --num_train_epochs 8 \
    --do_train \
    --warmup_proportion 0.1 \
    --gradient_accumulation_steps 3 \

需要注意的是,如果显存无法容纳所传入的 batch_size,可以通过传入 gradient_accumulation_steps 参数来模拟该 batch_size

批量启动 Grid Search

环境依赖

Grid Search 需要在 GPU 环境下进行,需要注意的是 C3 任务需要显存大于 16 GB,最好是在显存 32 GB的环境下启动。

Grid Search 中的 GPU 调度需要依赖 pynvml 库,pynvml 库提供了 GPU 管理的 Python 接口。可启动以下命令进行安装 pynvml:

pip install pynvml

一键启动方法

运行下面一句命令即可启动 Grid Search 任务。前期需要注意数据集是否正常下载,否则训练任务不会正式启动。 脚本默认不保存模型,如需保存每个超参数下最好的模型,需要修改 Python 脚本中的 --save_best_models 参数为 True。

cd grid_search_tools

# 这里 ernie-3.0-base-zh 是模型名,也可以传用户自定义的模型目录
# 自定义的模型目录需要有 model_config.json, model_state.pdparams, tokenizer_config.json 和 vocab.txt 四个文件
python grid_seach.py ernie-3.0-base-zh

确认模型所有任务训练完成后,可以调用脚本 extract_result.sh 一键抽取 Grid Search 结果,打印出每个任务的最佳结果和对应的超参数,例如:

bash extract_result.sh ernie-3.0-base-zh
AFQMC	TNEWS	IFLYTEK	CMNLI	OCNLI	CLUEWSC2020	CSL	CMRC2018	CHID	C3
75.93	58.26	61.56	83.02	80.10	86.18	82.63	70.71/90.41	84.26	77.88
====================================================================
Best hyper-parameters list:
====================================================================
TASK	result	(lr, batch_size, dropout_p)
AFQMC	75.93	(3e-05,16,0.1)
TNEWS	58.26	(3e-05,32,0.1)
IFLYTEK	61.56	(5e-05,32,0.0)
CMNLI	83.02	(3e-05,32,0.1)
OCNLI	80.10	(2e-05,64,0.1)
CLUEWSC2020	86.18	(2e-05,16,0.0)
CSL	82.63	(2e-05,32,0.1)
CMRC2018	70.71/90.41	(2e-05,24,0.1)
CHID	84.26	(3e-05,24,0.1)
C3	77.88	(3e-05,32,0.1)

另外,如遇意外情况(如机器重启)导致训练中断,可以直接再次启动 grid_search.py 脚本,之前已完成(输出完整日志)的任务则会直接跳过。

Grid Search 脚本说明

本节介绍 grid_search_tools 目录下各个脚本的功能:

  • grid_search.py Grid Search 任务入口脚本,该脚本负责调度 GPU 资源,可自动将 7 个分类任务、3 个阅读理解下所有超参数对应的任务完成,训练完成后会自动调用抽取结果的脚本 extract_result.sh 打印出所有任务的最佳结果和对应的超参。
  • warmup_dataset_and_model.py 首次运行时,该脚本完成模型下载(如需)、数据集下载,阅读理解任务数据预处理、预处理文件缓存等工作,再次运行则会检查这些文件是否存在,存在则跳过。该脚本由 grid_search.py 在 Grid Search 训练前自动调用,预处理 cache 文件生成后,后面所有训练任务即可加载缓存文件避免重复进行数据预处理。如果该阶段任务失败,大多需要检查网络,解决之后需重启 grid_search.py,直到训练正常开始。该脚本也可手动调用,需要 1 个参数,模型名称或目录。该脚本在使用 Intel(R) Xeon(R) Gold 6271C CPU 且 --num_proc默认为 4 的情况下需约 30 分钟左右完成,可以更改 run_mrc.sh 中的 --num_proc 参数以改变生成 cache 的进程数。需要注意的是,若改变 num_proc,之前的缓存则不能再使用,该脚本会重新处理数据并生成新的 cache,cache 相关内容可查看datasets.Dataset.map文档
  • extract_result.sh 从日志抽取每个任务的最佳结果和对应的最佳超参并打印,grid_search.py 在完成训练任务后会自动调用,也可手动调用,需要 1 个参数:模型名称或目录。手动调用前需要确认训练均全部完成,并且保证该目录下有分类和阅读理解所有任务的日志。
  • run_mrc.sh 阅读理解任务的启动脚本。
  • run_cls.sh 分类任务的启动脚本。

参加 CLUE 竞赛

对各个任务运行预测脚本,汇总多个结果文件压缩之后,即可提交至 CLUE 官网进行评测。

下面 2 小节会分别介绍分类、阅读理解任务产生预测结果的方法。

分类任务

以 TNEWS 为例,可以直接使用脚本 classification/run_clue_classifier.py 对单个任务进行预测,注意脚本启动时需要传入参数 --do_predict。假设 TNEWS 模型所在路径为 ${TNEWS_MODEL},运行如下脚本可得到模型在测试集上的预测结果,预测结果会写入地址 ${OUTPUT_DIR}/tnews_predict.json

cd classification
OUTPUT_DIR=results
mkdir ${OUTPUT_DIR}

python run_clue_classifier.py \
    --task_name TNEWS \
    --model_name_or_path ${TNEWS_MODEL}  \
    --output_dir ${OUTPUT_DIR} \
    --do_predict \

阅读理解任务

以 C3 为例,直接使用 mrc/run_c3.py对该任务进行预测,注意脚本启动时需要传入参数 --do_predict。假设 C3 模型所在路径为 ${C3_MODEL},运行如下脚本可得到模型在测试集上的预测结果,预测结果会写入地址 ${OUTPUT_DIR}/c311_predict.json

cd mrc
OUTPUT_DIR=results
mkdir ${OUTPUT_DIR}

python run_c3.py \
    --model_name_or_path ${C3_MODEL} \
    --output_dir ${OUTPUT_DIR} \
    --do_predict \