Skip to content

JunnYu/GPLinker_pytorch

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

12 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

GPLinker_pytorch

GPLinker_pytorch

介绍

这是pytorch版本的GPLinker代码以及TPLinker_Plus代码。

更新

  • 2022/03/03 添加tplinker_plus+bert-base-chinese权重在duie_v1上的结果。添加duee_v1任务的训练代码,请查看duee_v1目录
  • 2022/03/01 添加tplinker_plus+hfl/chinese-roberta-wwm-ext权重在duie_v1上的结果。
  • 2022/02/25 现已在Dev分支更新最新的huggingface全家桶版本的代码,main分支是之前旧的代码(执行效率慢)

结果

Tips: 在RTX309020epoch的条件下,gplinker需要训练5-6htplinker_plus则需要训练16-17h

dataset method pretrained_model_name_or_path f1 precision recall
duie_v1 gplinker hfl/chinese-roberta-wwm-ext 0.8214065255731926 0.8250077498782166 0.8178366038895478
duie_v1 gplinker bert-base-chinese 0.8198087178424598 0.8146470447994109 0.8250362175688137
duie_v1 tplinker_plus hfl/chinese-roberta-wwm-ext 0.8256425523469291 0.8295114656031908 0.8218095614381671
duie_v1 tplinker_plus bert-base-chinese 0.8216261688290682 0.8076458240569943 0.8360990385881737

Tensorboard日志

gplinker训练日志

tplinker_plus训练日志

依赖

所需的依赖如下:

  • fastcore==1.3.29
  • datasets==1.18.3
  • transformers>=4.16.2
  • accelerate==0.5.1
  • chinesebert==0.2.1

安装依赖requirements.txt

pip install -r requirements.txt

准备数据

http://ai.baidu.com/broad/download?dataset=sked 下载数据。

train_data.jsondev_data.json压缩成spo.zip文件,并且放入data文件夹。

当前data/spo.zip文件是本人提供精简后的数据集,其中train_data.json只有2000条数据,dev_data.json只有200条数据。

运行

accelerate launch train.py \
    --model_type bert \
    --pretrained_model_name_or_path bert-base-chinese \
    --method gplinker \
    --logging_steps 200 \
    --num_train_epochs 20 \
    --learning_rate 3e-5 \
    --num_warmup_steps_or_radios 0.1 \
    --gradient_accumulation_steps 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 32 \
    --seed 42 \
    --save_steps 10804 \
    --output_dir ./outputs \
    --max_length 128 \
    --topk 1 \
    --num_workers 6

其中使用到参数介绍如下:

  • model_type: 表示模型架构类型,像bert-base-chinesehfl/chinese-roberta-wwm-ext模型都是基于bert架构,junnyu/roformer_chinese_char_base是基于roformer架构,可选择["bert", "roformer", "chinesebert"]
  • pretrained_model_name_or_path: 表示加载的预训练模型权重,可以是本地目录,也可以是huggingface.co的路径。
  • method: 表示使用的方法, 可选择["gplinker", "tplinker_plus"]
  • logging_steps: 日志打印的间隔,默认为200
  • num_train_epochs: 训练轮数,默认为20
  • learning_rate: 学习率,默认为3e-5
  • num_warmup_steps_or_radios: warmup步数或者比率,当为浮点类型时候表示的是radio,当为整型时候表示的是step,默认为0.1
  • gradient_accumulation_steps: 梯度累计的步数,默认为1
  • per_device_train_batch_size: 训练的batch_size,默认为16
  • per_device_eval_batch_size: 评估的batch_size,默认为32
  • seed: 随机种子,以便于复现,默认为42
  • save_steps: 保存步数,每隔多少步保存模型。
  • output_dir: 模型输出路径。
  • max_length: 句子的最大长度,当大于这个长度时候,tokenizer会进行截断处理。
  • topk: 保存topk个数模型,默认为1
  • num_workers: dataloadernum_workers参数,linux系统下发现GPU使用率不高的时候可以尝试设置这个参数大于0,而windows下最好设置为0,不然会报错。
  • use_efficient: 是否使用EfficientGlobalPointer,默认为False

Reference

About

GPLinker_pytorch

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published