-
Notifications
You must be signed in to change notification settings - Fork 0
/
qa_finetuning_hyp_search.py
515 lines (426 loc) · 21.2 KB
/
qa_finetuning_hyp_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
import argparse
import json
import os
from collections import defaultdict, OrderedDict
from pathlib import Path
import numpy as np
from datasets import load_dataset, load_metric
from tqdm import tqdm
from transformers import AutoModelForQuestionAnswering, TrainingArguments, set_seed
from transformers import AutoTokenizer
from transformers import default_data_collator
from trainer_qa import QuestionAnsweringTrainer
# NOTE(review): presumably read by transformers' W&B integration to decide whether
# model checkpoints are uploaded as W&B artifacts; "false" should keep them local.
# It is assigned before `import wandb` below.
os.environ["WANDB_LOG_MODEL"] = "false"
import wandb
# Dataset splits; main() loads each one from `<data_folder>/<split>.json`.
SPLITS = ['train', 'dev', 'test']
# Default random seed (the CLI exposes it as --seed).
SEED = 123
# Set to True for SQuAD v2-style data, where a question may have no answer in its
# context: post-processing may then predict "" and adds `no_answer_probability`.
squad_v2 = False
def postprocess_qa_predictions(
    examples, features,
    raw_predictions,
    tokenizer,
    n_best_size=20, max_answer_length=30,
):
    """Turn start/end logits over tokenized features into text answers.

    The tokenizer may split one example into several features (sliding-window
    fragments of the context). Features are grouped back by their
    ``example_id``, and the best-scoring span across all of an example's
    features becomes its prediction.

    Args:
        examples: Original examples. ``examples["id"]`` must give the id column,
            and each row must have ``"id"`` and ``"context"``.
        features: Tokenized features. Each row must have ``"example_id"``,
            ``"input_ids"`` and ``"offset_mapping"``; offsets are ``None`` for
            tokens outside the context.
        raw_predictions: Pair ``(all_start_logits, all_end_logits)``, indexed by
            feature position.
        tokenizer: Tokenizer whose ``cls_token_id`` marks the "no answer" position.
        n_best_size: How many top start and top end positions are combined per feature.
        max_answer_length: Longest span, in tokens, that may be returned.

    Returns:
        ``(predictions, formatted_predictions)``:
        - an ``OrderedDict`` mapping example id to answer text;
        - the same predictions as a list of dicts in the format the ``squad``
          (or ``squad_v2``, if the module-level ``squad_v2`` flag is set) metric expects.
    """
    all_start_logits, all_end_logits = raw_predictions

    # Group feature indices by the example they were produced from (via example_id).
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    predictions = OrderedDict()
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    for example_index, example in enumerate(tqdm(examples)):
        feature_indices = features_per_example[example_index]
        min_null_score = None  # Only used if squad_v2 is True.
        valid_answers = []
        context = example["context"]

        for feature_index in feature_indices:
            # Read the feature row once; both its offsets and input ids are needed.
            feature = features[feature_index]
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # Token -> (start_char, end_char) in the context; None for non-context tokens.
            offset_mapping = feature["offset_mapping"]

            # "No answer" score of this feature: start and end both on the CLS token.
            # Keep the minimum over all features of the example.
            cls_index = feature["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or feature_null_score < min_null_score:
                min_null_score = feature_null_score

            # Score every combination of the n_best start and n_best end positions.
            start_indexes = np.argsort(start_logits)[-1: -n_best_size - 1: -1].tolist()
            end_indexes = np.argsort(end_logits)[-1: -n_best_size - 1: -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip positions beyond the feature or outside the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or len(offset_mapping[start_index]) < 2
                        or offset_mapping[end_index] is None
                        or len(offset_mapping[end_index]) < 2
                    ):
                        continue
                    # Skip reversed spans and spans longer than max_answer_length tokens.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            # Combined score of the span.
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char:end_char],
                            "offsets": (start_char, end_char),
                            "start_logit": start_logits[start_index],
                            "end_logit": end_logits[end_index],
                        }
                    )

        if valid_answers:
            # Highest combined score wins (the first one on ties).
            best_answer = max(valid_answers, key=lambda answer: answer["score"])
        else:
            # No valid non-null span: fall back to an empty answer with score 0.
            best_answer = {"text": "", "score": 0.0}

        if not squad_v2:
            predictions[example["id"]] = best_answer["text"]
        elif min_null_score is None or best_answer["score"] > min_null_score:
            # SQuAD v2: keep the span only if it beats the null score. An example
            # with no features has no null score, and its best answer is "" anyway.
            predictions[example["id"]] = best_answer["text"]
        else:
            predictions[example["id"]] = ""

    # Adapt the output to the format expected by the (SQuAD) metric.
    if squad_v2:
        formatted_predictions = [{"id": str(k), "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()]
    else:
        # Format: [{'id': str, 'prediction_text': str}, ...]
        formatted_predictions = [{"id": str(k), "prediction_text": v} for k, v in predictions.items()]
    return predictions, formatted_predictions
# --------------------------------------------------------------------------------------------------------
# W&B sweep search space. A 'value' entry pins one setting; a 'values' entry
# lists the candidates the sweep may choose from.
parameters_dict = {
    "epochs": {"value": 1},
    "batch_size": {"values": [32, 64]},
    "learning_rate": {"values": [75e-6, 5e-5, 2e-5, 1e-5]},
    "weight_decay": {"values": [0.0, 0.001, 0.1, 0.25, 0.5]},
    "warmup_ratio": {"values": [0.05]},
}

# Sweep definition: Bayesian search that maximizes the dev F1 logged by
# `trainer.evaluate(metric_key_prefix='best_eval')` in main().
# NOTE(review): the 'train/' prefix presumably comes from how the W&B callback
# groups non-eval_/test_ keys — confirm against the logged run keys.
# 'parameters' is the same dict object as `parameters_dict`; main() adds a
# 'dataset' entry to it.
sweep_config = {
    "method": "bayes",
    "metric": {"name": "train/best_eval_f1", "goal": "maximize"},
    "parameters": parameters_dict,
}
# --------------------------------------------------------------------------------------------------------
def _resolve_max_length(tokenizer, default=512):
    """Return the tokenizer's maximum sequence length, or *default* if it is unusable.

    A missing, non-integer, non-positive or implausibly large
    ``model_max_length`` falls back to *default*. Transformers stores a huge
    placeholder (int(1e30)) there when a checkpoint declares no limit, and
    padding to that length would be impossible.
    """
    max_length = getattr(tokenizer, "model_max_length", None)
    if not isinstance(max_length, int) or max_length <= 0 or max_length > 1_000_000:
        return default
    return max_length


def _prepare_features(examples, tokenizer, inference=False, doc_stride=128):
    """Tokenize a batch of (question, context) pairs into model features.

    Long contexts are split with a sliding window of ``doc_stride`` overlapping
    tokens, so one example may yield several features:
        <s>question</s>context_part_1</s>
        <s>question</s>context_part_2</s>
    Only the context is truncated; the question is repeated in every feature.

    Args:
        examples: A batch from ``Dataset.map(batched=True)`` with ``question`` and
            ``context`` columns, plus ``answer`` (training: a dict with a scalar
            ``answer_start`` and ``text``) or ``id`` (inference).
        tokenizer: Tokenizer used to build the features. NOTE(review): offset
            mappings and ``sequence_ids`` presumably require a "fast" tokenizer.
        inference: False adds ``start_positions``/``end_positions`` labels.
            True keeps ``offset_mapping`` (``None`` outside the context) and adds
            ``example_id`` for post-processing.
        doc_stride: Overlap, in tokens, between consecutive context windows.

    Returns:
        The tokenizer output with the extra columns described above.
    """
    max_length = _resolve_max_length(tokenizer)
    examples["question"] = [q.strip() for q in examples["question"]]
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # For every produced feature, the index of the example it came from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    if not inference:
        # TRAINING FEATURES: label each feature with the token span of its answer.
        offset_mapping = tokenized_examples.pop("offset_mapping")
        start_positions = []
        end_positions = []
        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized_examples["input_ids"][i]
            # Features whose window does not contain the answer are labelled with CLS.
            cls_index = input_ids.index(tokenizer.cls_token_id)
            # 0 = question token, 1 = context token, None = special/padding token.
            sequence_ids = tokenized_examples.sequence_ids(i)
            # Each example carries exactly one answer (a dict with scalar
            # `answer_start` and `text`); unanswerable examples are not handled here.
            answer = examples["answer"][sample_mapping[i]]
            start_char = answer["answer_start"]
            end_char = start_char + len(answer["text"])

            # First and last context token inside this window.
            context_start = 0
            while sequence_ids[context_start] != 1:
                context_start += 1
            context_end = len(input_ids) - 1
            while sequence_ids[context_end] != 1:
                context_end -= 1

            if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
                # The answer is (at least partly) outside this window.
                start_positions.append(cls_index)
                end_positions.append(cls_index)
                continue

            # Walk forward to the token containing start_char. The walk is bounded by
            # context_end because the tokens after the context (separator/padding)
            # typically have small offsets like (0, 0) and would otherwise be skipped
            # over when the answer starts in the last context token.
            token_start_index = context_start
            while token_start_index <= context_end and offsets[token_start_index][0] <= start_char:
                token_start_index += 1
            start_positions.append(token_start_index - 1)

            # Walk backward to the token containing end_char.
            token_end_index = context_end
            while token_end_index >= context_start and offsets[token_end_index][1] >= end_char:
                token_end_index -= 1
            end_positions.append(token_end_index + 1)

        tokenized_examples["start_positions"] = start_positions
        tokenized_examples["end_positions"] = end_positions
        return tokenized_examples

    # INFERENCE/EVALUATION FEATURES: remember each feature's source example, and
    # blank out non-context offsets so predicted tokens map back only to the context.
    context_index = 1
    example_ids = []
    for i, offsets in enumerate(tokenized_examples["offset_mapping"]):
        sequence_ids = tokenized_examples.sequence_ids(i)
        example_ids.append(examples["id"][sample_mapping[i]])
        tokenized_examples["offset_mapping"][i] = [
            offset if sequence_ids[k] == context_index else None
            for k, offset in enumerate(offsets)
        ]
    tokenized_examples["example_id"] = example_ids
    return tokenized_examples


def _squad_references(examples):
    """Build metric references from examples with a single ``answer`` dict each.

    The answer's scalar ``answer_start`` and ``text`` are wrapped in one-element
    lists. Ids are stringified to match the ids in the formatted predictions.
    """
    return [
        {
            "id": str(ex["id"]),
            "answers": {
                "answer_start": [ex["answer"]["answer_start"]],
                "text": [ex["answer"]["text"]],
            },
        }
        for ex in examples
    ]


def main(
    dataset_name: str | None,
    data_folder: Path,
    output_folder: Path,
    model_path: str,
    run_count: int,
    logging_dir: Path,
    seed: int = SEED,
    do_hyperparameter_search: bool = False
):
    """Fine-tune an extractive QA model, or run a W&B hyperparameter sweep.

    Reads ``train``/``dev``/``test`` JSON files (records under the ``data``
    field) from *data_folder* and tokenizes them with *model_path*'s tokenizer.
    Then it does one of two things:
      * sweep mode (``do_hyperparameter_search``): runs ``wandb.agent`` for 10
        runs, each of which trains, then evaluates on dev (``best_eval_*``) and
        on test (``test_*``);
      * otherwise: trains once with fixed hyperparameters, predicts on test and
        writes the SQuAD metrics to ``<output_folder>/test_results.json``.

    Args:
        dataset_name: Name used in output/log paths; defaults to *data_folder*'s name.
        data_folder: Directory holding ``<split>.json`` for every split in SPLITS.
        output_folder: Root for checkpoints; the run writes to
            ``output_folder/dataset_name/model_path/run_count``.
        model_path: Model id or local path passed to ``from_pretrained``.
        run_count: Run index, used only to separate output/log directories.
        logging_dir: Root for W&B/Trainer logs (used only in sweep mode).
        seed: Random seed for ``set_seed`` and for the Trainer.
        do_hyperparameter_search: Run the W&B sweep instead of a single training.
    """
    # -------------------------------------------
    # 0.- PREPARING THE ENVIRONMENT
    # -------------------------------------------
    set_seed(seed)
    dataset_name = dataset_name or Path(data_folder).name
    # Strip a leading "/" so an absolute local model path does not replace the
    # whole output path (Path("a") / "/abs" evaluates to "/abs").
    model_dir = model_path.lstrip("/")
    output_folder = output_folder / dataset_name / model_dir / str(run_count)
    output_folder.mkdir(parents=True, exist_ok=True)
    if do_hyperparameter_search:
        logging_dir = logging_dir / 'hyperparameter_search' / dataset_name / model_dir / str(run_count)
        logging_dir.mkdir(parents=True, exist_ok=True)
        os.environ["WANDB_DIR"] = str(logging_dir)
        sweep_config['parameters']['dataset'] = {'value': dataset_name}

    # -------------------------------------------
    # 1.- LOADING DATASET
    # -------------------------------------------
    data_files = {split: str(data_folder / f'{split}.json') for split in SPLITS}
    dataset = load_dataset('json', data_files=data_files, field='data')
    print(dataset)

    # -------------------------------------------
    # 2.- TOKENIZE THE TEXTS
    # -------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    training_features = dataset["train"].map(
        _prepare_features,
        batched=True,
        fn_kwargs={"tokenizer": tokenizer},
        remove_columns=dataset["train"].column_names,
    )
    eval_features = dataset["dev"].map(
        _prepare_features,
        batched=True,
        fn_kwargs={"tokenizer": tokenizer, "inference": True},
        remove_columns=dataset["dev"].column_names,
    )
    test_features = dataset["test"].map(
        _prepare_features,
        batched=True,
        fn_kwargs={"tokenizer": tokenizer, "inference": True},
        remove_columns=dataset["test"].column_names,
    )

    # -------------------------------------------
    # 3.- FINE-TUNING THE MODEL
    # -------------------------------------------
    def trainer_post_process(examples, features, model_predictions):
        """Post-processing hook for QuestionAnsweringTrainer.

        Args:
            examples: The original (untokenized) evaluation examples.
            features: The tokenized features the model was run on.
            model_predictions: Raw ``(start_logits, end_logits)`` from the model.

        Returns:
            ``(formatted_predictions, references)``, which compute_metrics unpacks.
        """
        _, formatted_preds = postprocess_qa_predictions(
            examples=examples, features=features,
            raw_predictions=model_predictions, tokenizer=tokenizer,
        )
        return formatted_preds, _squad_references(examples)

    # NOTE: datasets.load_metric is deprecated (removed in datasets 3.x);
    # evaluate.load("squad") is the replacement if the dependency is upgraded.
    metric = load_metric("squad")

    def compute_metrics(p: tuple):
        preds, refs = p
        return metric.compute(predictions=preds, references=refs)

    if do_hyperparameter_search:
        def model_init():
            # A fresh model for every sweep run.
            return AutoModelForQuestionAnswering.from_pretrained(model_path)

        def hyperparameter_search_training(config=None):
            with wandb.init(config=config):
                config = wandb.config
                # Cap the per-device batch at 32: halve it and double gradient
                # accumulation until it fits. For the power-of-two sizes in the
                # sweep, per-device batch x accumulation steps == config.batch_size.
                per_device_batch_size = int(config.batch_size)
                gradient_accumulation_steps = 1
                while per_device_batch_size > 32:
                    per_device_batch_size //= 2
                    gradient_accumulation_steps *= 2

                training_args = TrainingArguments(
                    output_dir=str(output_folder),
                    report_to='wandb',  # Turn on Weights & Biases logging
                    num_train_epochs=config.epochs,
                    learning_rate=config.learning_rate,
                    weight_decay=config.weight_decay,
                    per_device_train_batch_size=per_device_batch_size,
                    warmup_ratio=config.warmup_ratio,
                    per_device_eval_batch_size=32,
                    save_strategy='steps',
                    evaluation_strategy='steps',
                    save_steps=240,
                    eval_steps=240,
                    logging_strategy='steps',
                    logging_steps=10,
                    load_best_model_at_end=True,
                    metric_for_best_model="f1",
                    lr_scheduler_type="linear",
                    fp16=True,
                    logging_dir=str(logging_dir),
                    save_total_limit=2,
                    gradient_accumulation_steps=gradient_accumulation_steps,
                    # Without this, Trainer re-seeds with its own default seed.
                    seed=seed,
                )
                trainer = QuestionAnsweringTrainer(
                    model_init=model_init,
                    args=training_args,
                    train_dataset=training_features,
                    eval_dataset=eval_features,
                    data_collator=default_data_collator,
                    tokenizer=tokenizer,
                    eval_examples=dataset["dev"],
                    post_process_function=trainer_post_process,
                    compute_metrics=compute_metrics,
                )
                trainer.train()
                trainer.evaluate(metric_key_prefix='best_eval')  # dev
                trainer.evaluate(
                    eval_dataset=test_features,
                    eval_examples=dataset["test"],
                    metric_key_prefix="test",
                )  # test

        sweep_id = wandb.sweep(sweep_config, project='Applications1-QA')
        wandb.agent(sweep_id, hyperparameter_search_training, count=10)
    else:
        model = AutoModelForQuestionAnswering.from_pretrained(model_path)
        args = TrainingArguments(
            str(output_folder),
            evaluation_strategy="epoch",
            # load_best_model_at_end requires the save and evaluation strategies to match.
            save_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            num_train_epochs=2,
            lr_scheduler_type="linear",
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            save_total_limit=2,
            # Without this, Trainer re-seeds with its own default seed.
            seed=seed,
        )
        trainer = QuestionAnsweringTrainer(
            model,
            args,
            train_dataset=training_features,
            eval_dataset=eval_features,
            data_collator=default_data_collator,
            tokenizer=tokenizer,
            eval_examples=dataset["dev"],
            post_process_function=trainer_post_process,
            compute_metrics=compute_metrics,
        )
        trainer.train()

        # -------------------------------------------
        # 4.- EVALUATION
        # -------------------------------------------
        # NOTE(review): this relies on the local trainer_qa.predict returning raw
        # logits in `.predictions` — confirm against trainer_qa.py.
        raw_predictions = trainer.predict(predict_examples=dataset["test"], predict_dataset=test_features)
        # Trainer hides unused columns; restore them so post-processing can read
        # example_id / offset_mapping.
        test_features.set_format(type=test_features.format["type"], columns=list(test_features.features.keys()))

        _, formatted_predictions = postprocess_qa_predictions(
            examples=dataset["test"], features=test_features,
            raw_predictions=raw_predictions.predictions, tokenizer=tokenizer,
        )
        print(formatted_predictions)
        references = _squad_references(dataset["test"])
        print(references)

        test_results = metric.compute(predictions=formatted_predictions, references=references)
        with open(output_folder / "test_results.json", "w", encoding='utf8') as f:
            json.dump(test_results, f, indent=4, ensure_ascii=False)
if __name__ == '__main__':
    # Command-line entry point: every option maps 1:1 onto a main() keyword argument.
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset_name', type=str)
    cli.add_argument(
        '--data_folder', type=Path,
        default=Path('/tartalo03/users/udixa/qa_applications1/data/COMBINED'),
    )
    cli.add_argument(
        '--output_folder', type=Path,
        default=Path('/gscratch4/users/idelaiglesia004/qa_applications1'),
    )
    cli.add_argument(
        '--logging_dir', type=Path,
        default=Path('/gscratch4/users/idelaiglesia004/qa_applications1/logs'),
    )
    cli.add_argument('--model_path', type=str, default='distilbert-base-uncased')
    cli.add_argument('--seed', type=int, default=SEED)
    cli.add_argument('--run_count', type=int, default=1)
    cli.add_argument('--do_hyperparameter_search', action='store_true')
    args = cli.parse_args()
    main(**vars(args))