# abstractive_summarization_unilm_cnndm.py
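"""Fine-tune a UniLM model for abstractive summarization and evaluate it with ROUGE.

This script (adapted from microsoft/nlp-recipes) fine-tunes the "unilm-large-cased"
sequence-to-sequence model on a JSON-lines dataset (the file name suggests
CNN/DailyMail), generates summaries for the test set with beam search, and reports
ROUGE scores. It is written to be launched with one process per GPU via
torch.distributed.launch; an example command is given below the argument parser.
"""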
import datetime
import argparse
import jsonlines
import torch
from utils_nlp.models.transformers.abstractive_summarization_seq2seq import (
    S2SAbsSumProcessor,
    S2SAbstractiveSummarizer,
)
from utils_nlp.eval import compute_rouge_python
parser = argparse.ArgumentParser()
parser.add_argument(
"--local_rank", type=int, default=-1, help="For distributed training: local_rank"
)
parser.add_argument("--fp16", type=bool, default=False)
parser.add_argument("--fp16_opt_level", type=str, default="O2")
args = parser.parse_args()
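# Example launch command (illustrative; adjust --nproc_per_node to the number of
# available GPUs). torch.distributed.launch supplies --local_rank to each process:
#   python -m torch.distributed.launch --nproc_per_node=4 \
#       abstractive_summarization_unilm_cnndm.py --fp16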
QUICK_RUN = True  # set to False to use the full fine-tuning schedule defined below
OUTPUT_FILE = "./nlp_cnndm_finetuning_results.txt"
# model parameters
MODEL_NAME = "unilm-large-cased"
MAX_SEQ_LENGTH = 768
MAX_SOURCE_SEQ_LENGTH = 640
MAX_TARGET_SEQ_LENGTH = 128
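# Note: MAX_SOURCE_SEQ_LENGTH + MAX_TARGET_SEQ_LENGTH = MAX_SEQ_LENGTH (640 + 128 = 768);
# UniLM's seq2seq fine-tuning packs source and target into a single input sequence.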
# fine-tuning parameters
TRAIN_PER_GPU_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 3e-5
if QUICK_RUN:
    TOP_N = 100
    WARMUP_STEPS = 10
    MAX_STEPS = 100
else:
    TOP_N = -1
    WARMUP_STEPS = 500
    MAX_STEPS = 5000
# inference parameters
TEST_PER_GPU_BATCH_SIZE = 8
BEAM_SIZE = 5
FORBID_IGNORE_WORD = "."  # token exempt from the duplicate-n-gram prohibition during beam search
train_ds = "train_ds.jsonl"
test_ds = "test_ds.jsonl"
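# Each line of the JSONL files is one example. The evaluation code below reads a
# "tgt" (reference summary) field from the test file; a matching "src" (source
# document) field is assumed here for the model inputs, e.g.:
#   {"src": "full article text ...", "tgt": "reference summary ..."}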
def main():
    # Initialize the process group only in distributed mode (torch.distributed.launch sets --local_rank).
    if args.local_rank != -1:
        torch.distributed.init_process_group(
            timeout=datetime.timedelta(0, 5400), backend="nccl",
        )

    # Only the first process downloads and caches the pretrained model; other ranks wait here.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    processor = S2SAbsSumProcessor(model_name=MODEL_NAME)
    abs_summarizer = S2SAbstractiveSummarizer(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        max_source_seq_length=MAX_SOURCE_SEQ_LENGTH,
        max_target_seq_length=MAX_TARGET_SEQ_LENGTH,
    )

    # Release the waiting ranks once the first process has cached the model.
    if args.local_rank == 0:
        torch.distributed.barrier()

    train_dataset = processor.s2s_dataset_from_json_or_file(
        train_ds, train_mode=True, local_rank=args.local_rank
    )
    test_dataset = processor.s2s_dataset_from_json_or_file(
        test_ds, train_mode=False, local_rank=args.local_rank
    )
    abs_summarizer.fit(
        train_dataset=train_dataset,
        per_gpu_batch_size=TRAIN_PER_GPU_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        max_steps=MAX_STEPS,
        fp16=args.fp16,
        fp16_opt_level=args.fp16_opt_level,
        local_rank=args.local_rank,
        save_model_to_dir=".",
    )

    # Wait for all processes to finish fine-tuning before running inference.
    if args.local_rank != -1:
        torch.distributed.barrier()
    # Run prediction and evaluation on the primary process only.
    if args.local_rank in [-1, 0]:
        res = abs_summarizer.predict(
            test_dataset=test_dataset,
            per_gpu_batch_size=TEST_PER_GPU_BATCH_SIZE,
            beam_size=BEAM_SIZE,
            forbid_ignore_word=FORBID_IGNORE_WORD,
            fp16=args.fp16,
        )

        # Print a few generated summaries and save all of them to disk.
        for r in res[:5]:
            print(r)
        with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
            for line in res:
                f.write(line + "\n")

        # Collect the reference summaries and report ROUGE scores.
        tgt = []
        with jsonlines.open(test_ds) as reader:
            for item in reader:
                tgt.append(item["tgt"])
        for t in tgt[:5]:
            print(t)

        print(compute_rouge_python(cand=res, ref=tgt))
if __name__ == "__main__":
    main()