forked from THUDM/GLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
finetune.py
120 lines (108 loc) · 5.43 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Race."""
from collections import OrderedDict
from finetune_glm import finetune
from tasks.superglue.dataset import SuperGlueDataset, MultiChoiceDataset, PROCESSORS, get_output_func
from tasks.superglue.dataset import CLASSIFICATION_DATASETS, MULTI_CHOICE_DATASETS
from tasks.superglue.evaluate import qa_exact_match, qa_f1, multirc_em, squad_exact_match, squad_f1
from tasks.superglue.pvp import PVPS
from tasks.eval_utils import accuracy_func_provider
from tasks.eval_utils import accuracy_metric, f1_macro_metric, f1_metric
from glob import glob
# Default evaluation metrics for each supported task, keyed by lower-case task
# name. Each value is an ordered list of (display name, metric function) pairs;
# metrics_func_provider wraps the selected list in an OrderedDict, so list order
# is the order in which metrics are reported.
DEFAULT_METRICS = dict(
    record=[("EM", qa_exact_match), ("F1", qa_f1)],
    copa=[("accuracy", accuracy_metric)],
    rte=[("accuracy", accuracy_metric)],
    boolq=[("accuracy", accuracy_metric)],
    wic=[("accuracy", accuracy_metric)],
    wsc=[("accuracy", accuracy_metric)],
    cb=[("accuracy", accuracy_metric), ("f1-macro", f1_macro_metric)],
    multirc=[("f1a", f1_metric), ("em", multirc_em), ("acc", accuracy_metric)],
    mnli=[("accuracy", accuracy_metric)],
    sst2=[("accuracy", accuracy_metric)],
    qnli=[("accuracy", accuracy_metric)],
    qqp=[("accuracy", accuracy_metric)],
    mrpc=[("accuracy", accuracy_metric)],
    cola=[("accuracy", accuracy_metric)],
    squad=[("accuracy", accuracy_metric)],
    afqmc=[("accuracy", accuracy_metric)],
    tnews=[("accuracy", accuracy_metric)],
    cluewsc=[("accuracy", accuracy_metric)],
    cmrc=[("accuracy", accuracy_metric)],
    multichoice=[("accuracy", accuracy_metric)],
)
def train_valid_datasets_provider(args, tokenizer, pattern_text=False):
    """Build the training and validation datasets for ``args.task``.

    Both datasets are ``SuperGlueDataset`` instances over ``args.data_dir``,
    using the lower-cased task name and ``args.seq_length``.

    Args:
        args: run arguments; reads ``task``, ``data_dir`` and ``seq_length``.
        tokenizer: tokenizer passed to both datasets.
        pattern_text: forwarded unchanged to both datasets.

    Returns:
        A ``(train_dataset, valid_dataset)`` tuple built from the "train" and
        "dev" splits respectively.
    """
    task_name, data_dir = args.task.lower(), args.data_dir

    def _build(split, **split_kwargs):
        # Every constructor argument is shared between the two splits except the
        # split name and any split-specific keyword flags.
        return SuperGlueDataset(args, task_name, data_dir, args.seq_length, split, tokenizer,
                                pattern_text=pattern_text, **split_kwargs)

    train_split = _build("train")
    # The dev split is constructed with for_train=True; presumably so validation
    # examples are encoded like training examples -- confirm in SuperGlueDataset.
    dev_split = _build("dev", for_train=True)
    return train_split, dev_split
def metrics_func_provider(args, tokenizer, is_test):
    """Provide the metrics callback used for evaluation of ``args.task``.

    The task name is lower-cased once and used consistently for dataset
    selection, output-function lookup and metric lookup, so mixed-case task
    names (e.g. ``"MultiChoice"``) are handled the same way as in ``main``.

    Args:
        args: run arguments; reads ``task``, ``data_dir``, ``seq_length``,
            ``cloze_eval`` and ``wsc_negative``.
        tokenizer: tokenizer passed to the evaluation datasets and to
            ``accuracy_func_provider``.
        is_test: forwarded unchanged to ``accuracy_func_provider``.

    Returns:
        The value returned by ``accuracy_func_provider`` for this task's
        ``DEFAULT_METRICS`` entry.

    Raises:
        KeyError: if the task has no entry in ``DEFAULT_METRICS``.
    """
    task_name = args.task.lower()

    def single_dataset_provider(split):
        # The multichoice task has its own dataset class; every other task is
        # read through SuperGlueDataset.
        if task_name == "multichoice":
            return MultiChoiceDataset(args, split, tokenizer, args.seq_length)
        return SuperGlueDataset(args, task_name, args.data_dir, args.seq_length, split, tokenizer)

    output_func = get_output_func(task_name, args)
    eval_func = None
    # Same condition under which main() trains WSC with the LM forward step; in
    # that mode evaluation uses the language-model classify_evaluate routine.
    if task_name == 'wsc' and args.cloze_eval and not args.wsc_negative:
        from tasks.language_model.finetune import classify_evaluate
        eval_func = classify_evaluate
    metric_dict = OrderedDict(DEFAULT_METRICS[task_name])
    return accuracy_func_provider(single_dataset_provider, metric_dict, args, is_test=is_test, eval_func=eval_func,
                                  output_func=output_func, only_rank0=False, tokenizer=tokenizer)
def main(args):
    """Finetune GLM on the task selected by ``args.task``.

    For every task except ``multichoice``, the task's processor (from
    ``PROCESSORS``) and PVP (from ``PVPS``) are built and used to configure the
    model. For ``multichoice``, ``args.test_data`` is treated as a list of glob
    patterns and is replaced in place by the matching file paths.

    Model configuration:
      * WSC with ``cloze_eval`` and without ``wsc_negative`` trains with the
        ``lm_forward_step`` from ``tasks.language_model.finetune``.
      * Otherwise, a multi-token task gets a ``multiple_choice`` model with
        ``num_labels=1``. A single-token task gets a ``multiple_choice`` model
        (when ``cloze_eval``) or a ``classification`` model, with one label per
        processor label.

    Side effects: may overwrite ``args.test_data``; sets ``args.multi_token`` on
    the non-WSC path.

    Raises:
        ValueError: if ``args.continuous_prompt`` is set for the ``multichoice``
            task, where no PVP (and therefore no ``spell_length``) is available.
    """
    task_name = args.task.lower()
    model_kwargs = {}
    processor, pvp = None, None
    if task_name != "multichoice":
        processor = PROCESSORS[task_name](args)
        pvp = PVPS[task_name](args, None, processor.get_labels(), args.seq_length,
                              pattern_id=args.pattern_id, is_multi_token=args.multi_token,
                              num_prompt_tokens=args.num_prompt_tokens)
    else:
        # Expand each glob pattern (``**`` allowed) into concrete paths. glob()
        # returns matches in arbitrary, filesystem-dependent order, so sort the
        # matches of each pattern to keep the test-file order reproducible;
        # the order of the patterns themselves is preserved.
        args.test_data = [path for pattern in args.test_data
                          for path in sorted(glob(pattern, recursive=True))]
    if args.continuous_prompt:
        if pvp is None:
            # The multichoice branch above builds no PVP; fail with a clear
            # message instead of an attribute/name error on ``pvp``.
            raise ValueError("continuous_prompt is not supported for the multichoice task: "
                             "no PVP is built, so spell_length is unavailable")
        model_kwargs["spell_length"] = pvp.spell_length
    if task_name == 'wsc' and args.cloze_eval and not args.wsc_negative:
        from tasks.language_model.finetune import lm_forward_step
        finetune(args, train_valid_datasets_provider, model_kwargs,
                 end_of_epoch_callback_provider=metrics_func_provider, forward_step=lm_forward_step)
    else:
        if task_name == "multichoice":
            multi_token = True
        elif args.cloze_eval:
            multi_token = pvp.is_multi_token
        else:
            multi_token = task_name in MULTI_CHOICE_DATASETS
        args.multi_token = multi_token
        if not multi_token:
            # Only reachable for non-multichoice tasks, so ``processor`` is set.
            model_kwargs["model_type"] = "multiple_choice" if args.cloze_eval else "classification"
            model_kwargs["multi_token"] = False
            model_kwargs["num_labels"] = len(processor.get_labels())
        else:
            model_kwargs["model_type"] = "multiple_choice"
            model_kwargs["multi_token"] = True
            model_kwargs["num_labels"] = 1
        finetune(args, train_valid_datasets_provider, model_kwargs,
                 end_of_epoch_callback_provider=metrics_func_provider)