import json
import logging
import os
import re
import string
import time

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def normalize_answer(s):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
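
# For reference, normalize_answer maps e.g.
#   "The  Quick, Brown Fox!"  ->  "quick brown fox"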

def remove_citations(sent):
    """Strip inline citation markers such as " [1]" or "[2]" from a sentence."""
    return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "")
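
# For reference, remove_citations maps e.g.
#   "The sky is blue [1][2]."  ->  "The sky is blue."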

def get_max_memory():
    """Get the maximum memory available on each GPU for loading models (leaving ~6GB of headroom)."""
    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
    max_memory = f'{free_in_GB - 6}GB'
    n_gpus = torch.cuda.device_count()
    max_memory = {i: max_memory for i in range(n_gpus)}
    return max_memory
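
# For reference, on a machine with two GPUs and roughly 24GB free on the current device
# this returns something like {0: '18GB', 1: '18GB'} (the headroom is computed from the
# current device only and replicated for every GPU).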

def make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=None):
    # Fill a single document into doc_prompt. Placeholders:
    # - {ID}: doc id (starting from 1)
    # - {T}: title
    # - {P}: text
    # use_shorter: None, "summary", or "extraction" -- if set, use that field instead of the full text
    text = doc['text']
    if use_shorter is not None:
        text = doc[use_shorter]
    return doc_prompt.replace("{T}", doc["title"]).replace("{P}", text).replace("{ID}", str(doc_id + 1))
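
# Illustrative call (the template string below is made up for the example):
#   make_doc_prompt({"title": "Aurora", "text": "An aurora is a natural light display."},
#                   0, "Document [{ID}](Title: {T}): {P}\n")
#   -> "Document [1](Title: Aurora): An aurora is a natural light display.\n"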

def get_shorter_text(item, docs, ndoc, key):
    # Collect up to ndoc documents whose `key` field (e.g. "summary" or "extraction")
    # is present and not marked irrelevant.
    doc_list = []
    for item_id, item in enumerate(docs):
        if key not in item:
            if len(doc_list) == 0:
                # If there aren't any documents yet, at least provide one (using the full text)
                item[key] = item['text']
                doc_list.append(item)
            logger.warning(f"No {key} found in document {item_id}. Either the data does not contain {key} or the previous documents are not relevant. This question will only have {len(doc_list)} documents.")
            break
        if "irrelevant" in item[key] or "Irrelevant" in item[key]:
            continue
        doc_list.append(item)
        if len(doc_list) >= ndoc:
            break
    return doc_list
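
# Illustrative call (the data layout mirrors what make_demo passes in):
#   docs = [{"text": "...", "summary": "Irrelevant."},
#           {"text": "...", "summary": "A short summary."}]
#   get_shorter_text(item, docs, ndoc=1, key="summary")  # -> [the second doc]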

def make_demo(item, prompt, ndoc=None, doc_prompt=None, instruction=None, use_shorter=None, test=False):
    # Build one (in-context) example from the prompt template. Placeholders:
    # - {INST}: the instruction
    # - {D}: the documents
    # - {Q}: the question
    # - {A}: the answers
    # ndoc: number of documents to put in context
    # use_shorter: None, "summary", or "extraction"
    # test: if True, leave the answer out so the model has to generate it

    prompt = prompt.replace("{INST}", instruction).replace("{Q}", item['question'])
    if "{D}" in prompt:
        if ndoc == 0:
            prompt = prompt.replace("{D}\n", "")  # if there is no doc we also delete the empty line
        else:
            doc_list = get_shorter_text(item, item["docs"], ndoc, use_shorter) if use_shorter is not None else item["docs"][:ndoc]
            text = "".join([make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=use_shorter) for doc_id, doc in enumerate(doc_list)])
            prompt = prompt.replace("{D}", text)

    if not test:
        answer = "\n" + "\n".join(item["answer"]) if isinstance(item["answer"], list) else item["answer"]
        prompt = prompt.replace("{A}", "").rstrip() + answer
    else:
        prompt = prompt.replace("{A}", "").rstrip()  # remove any trailing space or newline
    return prompt
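
# Illustrative call (the template strings here are made up for the example, not taken
# from the repo's prompt files):
#   tpl = "{INST}\n\n{D}\nQuestion: {Q}\nAnswer:{A}"
#   doc_tpl = "Document [{ID}](Title: {T}): {P}\n"
#   demo = make_demo(eval_item, tpl, ndoc=3, doc_prompt=doc_tpl,
#                    instruction="Answer the question using the documents.", test=True)
#   # `demo` ends right after "Answer:", ready for the model to continue.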

def load_model(model_name_or_path, dtype=torch.float16, int8=False, reserve_memory=10):
    # Load a huggingface model and tokenizer.
    # dtype: torch.float16 or torch.bfloat16
    # int8: whether to use int8 (LLM.int8) quantization
    # reserve_memory: how much memory to reserve on each gpu (in GB); note this is
    #                 currently unused -- get_max_memory() applies a fixed 6GB headroom.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load the model
    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    if int8:
        logger.warning("Use LLM.int8")
    start_time = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        torch_dtype=dtype,
        max_memory=get_max_memory(),
        load_in_8bit=int8,
    )
    logger.info("Finish loading in %.2f sec." % (time.time() - start_time))

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)

    # Fix the OPT bos token problem in HF
    if "opt" in model_name_or_path:
        tokenizer.bos_token = "<s>"
    tokenizer.padding_side = "left"

    return model, tokenizer
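
# Illustrative usage (the checkpoint name is just an example):
#   model, tokenizer = load_model("facebook/opt-1.3b", dtype=torch.float16)
#   inputs = tokenizer("Question: What color is the sky?\nAnswer:", return_tensors="pt").to(model.device)
#   output_ids = model.generate(**inputs, max_new_tokens=32)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))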