-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpred_multi_with_bert.py
103 lines (92 loc) · 2.48 KB
/
pred_multi_with_bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
from transformers import (
BertTokenizer,
Trainer,
TrainingArguments,
)
from train_multi_with_bert import (
emotion_labels,
read_corpus,
MultiTaskEmotionDataset,
BertForMultiTaskSequenceClassification,
compute_metrics,
)
def main():
    """Evaluate a multi-task BERT emotion classifier on a test corpus.

    Loads a fine-tuned model and tokenizer from ``--model_path``, builds a
    ``MultiTaskEmotionDataset`` carrying both expresser and experiencer
    labels for the same texts, runs ``Trainer.predict`` and prints the
    resulting metrics.
    """
    args = _parse_args()

    tokenizer = BertTokenizer.from_pretrained(args.model_path)

    # Read the same corpus file twice: once for the expresser labels
    # (for_exper=False) and once for the experiencer labels
    # (for_exper=True). The texts are identical both times, so the
    # second read keeps only the labels.
    test_texts, test_expr_labels = read_corpus(
        args.test_jsonl,
        for_exper=False,
        for_speaker=args.for_speaker,
        sep_token=tokenizer.sep_token,
    )
    _, test_exper_labels = read_corpus(
        args.test_jsonl,
        for_exper=True,
        for_speaker=args.for_speaker,
        sep_token=tokenizer.sep_token,
    )

    test_encodings = tokenizer(
        test_texts,
        truncation=True,
        padding=True,
        max_length=args.max_length,
    )
    test_dataset = MultiTaskEmotionDataset(
        test_encodings,
        test_expr_labels,
        test_exper_labels,
        sep_token_id=tokenizer.sep_token_id,
    )

    # Only output_dir matters for prediction; training hyper-parameters
    # are irrelevant here.
    training_args = TrainingArguments(
        output_dir=args.model_path,
    )

    model = BertForMultiTaskSequenceClassification.from_pretrained(
        args.model_path,
        num_labels=len(emotion_labels),
        problem_type='regression',
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
    )
    print(trainer.predict(test_dataset).metrics)


def _parse_args():
    """Parse the command-line options for prediction."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--test_jsonl',
        default='dialogues-test.jsonl',
        help='Path to the test corpus (JSONL).',
    )
    parser.add_argument(
        '--model_path',
        default='./results',
        help='Directory holding the fine-tuned model and tokenizer.',
    )
    # NOTE(review): --for_exper is accepted for backward compatibility
    # but is never read -- this script always evaluates BOTH label sets
    # (see the two read_corpus calls in main).
    parser.add_argument(
        '--for_exper',
        action='store_true',
        help='(unused) kept for command-line compatibility.',
    )
    parser.add_argument(
        '--for_speaker',
        action='store_true',
        help='Forwarded to read_corpus as for_speaker.',
    )
    parser.add_argument(
        '--max_length',
        default=512,
        type=int,
        help='Maximum tokenized sequence length.',
    )
    return parser.parse_args()
# Run prediction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()