-
Notifications
You must be signed in to change notification settings - Fork 1
/
gpt_xml_helpers.py
209 lines (169 loc) · 6.56 KB
/
gpt_xml_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import openai
import tiktoken
import io
import os
import time
import xml.etree.ElementTree as ET
from typing import Dict
from pathlib import Path
"""
Adapted, largely unchanged, from Weixin-Liang et al.
https://github.com/Weixin-Liang/LLM-scientific-feedback/blob/main/main.py
"""
##############################################
# XML PROCESSING/PARSING
##############################################
def extract_element_text(element):
    """
    Recursively flatten an XML element into a single string.

    The element's own text (or a single space when it has none) comes first.
    Then, for each child in document order, the child's flattened text and the
    child's tail text are each appended with one leading space.
    """
    pieces = [element.text or " "]
    for sub in element:
        pieces.append(" " + extract_element_text(sub))
        # Tail text is the text that follows the child inside *this* element.
        if sub.tail:
            pieces.append(" " + sub.tail)
    return "".join(pieces)
def get_article_title(root):
    """
    Return the text of the first <article-title> element in the document.

    The full text, including text inside inline markup such as <italic>, is
    joined and stripped. Using ``.text`` alone would return None or a truncated
    title whenever the title starts with or contains a child element.

    Returns:
        The title string, or the placeholder "Article Title" when no
        <article-title> exists or it contains no text.
    """
    article_title = root.find(".//article-title")
    if article_title is not None:
        title_text = "".join(article_title.itertext()).strip()
        if title_text:
            return title_text
    return "Article Title"  # not found / empty
def get_abstract(root):
    """
    Return the abstract text of the paper.

    Lookup order:
      1. The first <abstract> element that has direct <p> children. All of its
         paragraphs are returned, newline-joined, with the full text of each
         (including inline markup like <italic>/<xref>).
      2. A <sec> whose <title> is exactly "Abstract", flattened with
         extract_element_text.
      3. The placeholder string "Abstract" when neither is present.
    """
    for abstract in root.findall(".//abstract"):
        paragraphs = abstract.findall("p")
        if paragraphs:
            texts = ["".join(p.itertext()).strip() for p in paragraphs]
            return "\n".join(text for text in texts if text)
    abstract = root.find(".//sec[title='Abstract']")
    if abstract is not None:
        return extract_element_text(abstract)
    return "Abstract"  # not found
def get_section_text(root, section_title="Introduction"):
    """
    Extract the text content of the first section with the given title.

    Warning: if the section has subsections, they are nested <sec> elements.
    Their text is included because the whole matching <sec> is flattened.

    :param root: The root element of an XML document.
    :param section_title: The title of the section to extract. The comparison
        is case-insensitive and ignores surrounding whitespace.
    :return: The text content of the section as a string, or "" if no
        section matches.
    """
    wanted = section_title.strip().casefold()
    for sec in root.findall(".//sec"):
        title_elem = sec.find("title")
        if title_elem is None:
            continue
        # itertext() handles empty titles (``.text`` is None) and titles
        # wrapped in inline markup, e.g. <title><bold>Introduction</bold></title>.
        title_text = "".join(title_elem.itertext()).strip()
        if title_text.casefold() == wanted:
            return extract_element_text(sec)
    # If no matching section is found, return an empty string
    return ""
def get_figure_and_table_captions(root):
    """
    Extract all figure and table captions and return them as one newline-joined string.

    Figures are taken only from the direct <fig> children of the first <sec>
    whose <title> is "Figures". Tables are every <table-wrap> anywhere in the
    tree. Each entry has the form "<caption title> <caption paragraph>", and
    either part is empty when missing.
    """

    def _stripped_text(elem):
        # Full text including inline markup (<italic>, <xref>, ...); ``.text``
        # alone would stop at the first child element.
        if elem is None:
            return ""
        return "".join(elem.itertext()).strip()

    def _caption_of(container):
        # Figures and tables share the same <caption><title/><p/></caption> layout.
        title_text = _stripped_text(container.find("caption/title"))
        caption_text = _stripped_text(container.find("caption/p"))
        return f"{title_text} {caption_text}"

    captions = []

    # Figure captions from the dedicated "Figures" section.
    figures = root.find('.//sec[title="Figures"]')
    if figures is not None:
        captions.extend(_caption_of(child) for child in figures if child.tag == "fig")

    # Table captions; findall always returns a (possibly empty) list, never None.
    captions.extend(_caption_of(table_wrap) for table_wrap in root.findall(".//table-wrap"))

    return "\n".join(captions)
def get_main_content(root):
    """
    Get the main content of the paper, excluding the "Figures" section (usually no abstract too).

    Every <sec> in the tree, including nested subsections, contributes one chunk:
    a "Section Title: ..." header (when the section has a <title>) followed by
    that section's own text and a newline. When a section is flattened, its
    nested <sec> children are skipped because the loop visits them separately.
    Recursing into them would emit every subsection's text twice.

    Args:
        root: root of the xml file
    Returns:
        main_content_str: string of the main content of the paper
    """

    def _own_text(element):
        # Same flattening rules as extract_element_text (own text or " ", then
        # " " + child text and " " + child tail). The one difference: nested
        # <sec> children contribute only their tail, not their content.
        parts = [element.text or " "]
        for child in element:
            if child.tag != "sec":
                parts.append(" " + _own_text(child))
            if child.tail:
                parts.append(" " + child.tail)
        return "".join(parts)

    chunks = []
    for sec in root.findall(".//sec"):
        title = sec.find("title")
        if title is not None:
            title_text = title.text or ""
            # Exclude Figures section (captions are handled separately)
            if title_text == "Figures":
                continue
            # Yes, the title is repeated inside the section text below.
            chunks.append(f"\nSection Title: {title_text}\n")
        chunks.append(_own_text(sec))
        chunks.append("\n")
    return "".join(chunks)
##############################################
# GPT Wrapper
# Put the API key in "key.txt" (read from the current working directory), or pass api_key= explicitly
##############################################
class GPTWrapper:
    """
    Thin wrapper around the legacy (openai<1.0) ``openai.ChatCompletion`` API.

    Stores the model name and a matching tiktoken tokenizer. Construction
    sets the module-global ``openai.api_key``.
    """

    def __init__(self, model_name="gpt-3.5-turbo-1106", api_key=None):
        """
        Args:
            model_name: chat model to query; also selects the tokenizer.
            api_key: OpenAI API key. If falsy, it is read from "key.txt" in the
                current working directory (ValueError if that file is missing).
        """
        self.model_name = model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(self.model_name)
        except KeyError:
            # tiktoken raises KeyError for model names it does not recognise
            # (e.g. models newer than the installed tiktoken). Fall back to
            # cl100k_base, the gpt-3.5/gpt-4 chat encoding, so that token
            # counts stay a close estimate instead of crashing the constructor.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        # `or` short-circuits, so key.txt is only read when no key was passed.
        openai.api_key = api_key or self._load_api_key

    @property
    def _load_api_key(self):
        """
        Read the API key from "key.txt" in the current working directory.

        Raises:
            ValueError: if "key.txt" does not exist.
        """
        try:
            # Context manager closes the file handle (the original leaked it).
            with open("key.txt", encoding="utf-8") as key_file:
                return key_file.read().strip()
        except FileNotFoundError as err:
            raise ValueError("API key file not found. Please provide a valid API key.") from err

    def make_query_args(self, user_str, n_query=1):
        """
        Build the keyword arguments for ``openai.ChatCompletion.create``.

        Args:
            user_str: content of the user message.
            n_query: number of completions to request (the ``n`` parameter).
        Returns:
            dict with ``model``, ``messages`` (fixed system prompt followed by
            the user message) and ``n``.
        """
        system_message = {
            "role": "system",
            "content": "You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible.",
        }
        user_message = {"role": "user", "content": user_str}
        query_args = {
            "model": self.model_name,
            "messages": [system_message, user_message],
            "n": n_query,
        }
        return query_args

    def compute_num_tokens(self, user_str: str) -> int:
        """Return how many tokens ``user_str`` encodes to with this wrapper's tokenizer."""
        return len(self.tokenizer.encode(user_str))

    def send_query(self, user_str, n_query=1):
        """
        Send ``user_str`` to the model and return the first completion's text.

        Only ``choices[0]`` is returned, even when ``n_query > 1``. On
        ``openai.error.OpenAIError`` the error is printed, and the string
        "Error in processing the query." is returned instead of raising.
        """
        # Counts only the user message, not the fixed system prompt.
        num_tokens = self.compute_num_tokens(user_str)
        print(f"# tokens sent to GPT: {num_tokens}")
        query_args = self.make_query_args(user_str, n_query)
        try:
            completion = openai.ChatCompletion.create(**query_args)
        except openai.error.OpenAIError as e:
            print(f"Error in send_query: {e}")
            return "Error in processing the query."
        return completion.choices[0]["message"]["content"]
# example usage
# wrapper = GPTWrapper(model_name="gpt-4")
def truncate(input_text: str, max_tokens: int, wrapper) -> str:
    """
    Truncate ``input_text`` to at most ``max_tokens`` tokens of ``wrapper.tokenizer``.

    Text that already fits within the budget is returned unchanged. When
    tokens are dropped and the kept text leaves a ``` code fence open (an odd
    number of fences), a newline followed by ``` is appended so the prompt
    stays well-formed.

    Args:
        input_text: text to truncate.
        max_tokens: token budget (negative values are treated as 0).
        wrapper: any object exposing ``tokenizer.encode`` / ``tokenizer.decode``,
            e.g. a GPTWrapper.
    Returns:
        The (possibly truncated) text.
    """
    tokens = wrapper.tokenizer.encode(input_text)
    if len(tokens) <= max_tokens:
        return input_text
    truncated_text = wrapper.tokenizer.decode(tokens[:max(max_tokens, 0)])
    # Add back the closing ``` only if the cut left a fence open. Appending it
    # unconditionally adds a stray fence when nothing was cut or fences balance.
    if truncated_text.count("```") % 2 == 1:
        truncated_text += "\n```"
    return truncated_text