-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
130 lines (104 loc) · 4.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import tkinter as tk
from tkinter import font as tkfont
from tkinter import font as tkfont  # NOTE(review): duplicate of the previous import — safe to remove
import os
from peft import PeftConfig, PeftModel
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import pipeline
import pandas as pd

# Anchor the working directory at this file's location so the relative
# model/data paths below resolve no matter where the script is launched from.
current_directory = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_directory)
new_directory = os.getcwd()

# Load the QA context database; the 'context' column holds the passages
# that the BERT pipeline scans for extractive answers.
covid_data = pd.read_csv("Bert/database_small.csv")
database = covid_data['context']

# Load Flan-T5 with its LoRA adapter: the base checkpoint comes from the
# adapter's config, then the fine-tuned PEFT weights are wrapped around it.
config = PeftConfig.from_pretrained("./Flan-T5-Lora/model/flan-t5-covid-lora")
model = T5ForConditionalGeneration.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, "./Flan-T5-Lora/model/flan-t5-covid-lora")
tokenizer = T5Tokenizer.from_pretrained(config.base_model_name_or_path)

# Load the fine-tuned BERT extractive question-answering pipeline.
model_checkpoint = "./Bert/model/bert-COVID-QA"
question_answerer = pipeline("question-answering", model=model_checkpoint)
def get_answer(question):
    """Answer *question* with whichever model the UI radio buttons select.

    A `model_version` value of 1 routes to the LoRA-tuned Flan-T5 generator;
    any other value runs the extractive BERT pipeline over every database
    passage and keeps the highest-scoring span.
    """
    if model_version.get() != 1:
        # BERT: scan every context passage, keep the best-scoring extraction.
        best_answer = ''
        best_score = 0
        for passage in database:
            result = question_answerer(question=question, context=passage)
            if result['score'] > best_score:
                best_answer = result['answer']
                best_score = result['score']
        return best_answer
    # Flan-T5: generate free-form text from the question alone.
    encoded = tokenizer(question, return_tensors="pt").input_ids.to("cpu")
    generated = model.generate(input_ids=encoded)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return "Answer from our model:" + decoded
window=tk.Tk()
window.title("QA System")
window.geometry("1250x1200")
window.configure(bg='#f7f7f7')
# Fonts used throughout the UI.
font_title = tkfont.Font(family="Arial Bold", size=20)
font_input = tkfont.Font(family="Arial", size=16)
font_response = tkfont.Font(family="Arial Bold", size=18)
history = []  # NOTE(review): never referenced again in this view — verify before removing
# Model selector: 1 = Flan-T5 (default), 2 = BERT; read by get_answer().
model_version = tk.IntVar()
model_version.set(1)  # default select flan-t5
lbl_version = tk.Label(window, text="Select Model:", font=(font_input.actual()['family'], font_input.actual()['size'], 'bold'),bg='#f7f7f7')
lbl_version.grid(column=0, row=0, pady=20, padx=20, sticky='w')
radio_flan = tk.Radiobutton(window, text="Flan-T5-base", variable=model_version, value=1, font=font_input, bg='#f7f7f7')
radio_flan.grid(column=0, row=1, pady=20, padx=20, sticky='w')
radio_bert = tk.Radiobutton(window, text="Bert", variable=model_version, value=2, font=font_input, bg='#f7f7f7')
radio_bert.grid(column=1, row=1, pady=20, padx=20, sticky='w')
# 1. Question input box, pre-filled with a grey placeholder prompt.
question_text = "Please input your question..."
txt_question = tk.Text(window, width=85, height=5, font=font_input, borderwidth=2, relief="solid", bg='#ffffff', wrap="word")
txt_question.grid(column=0, row=3, columnspan=2, pady=20, padx=20, sticky='nsew')
txt_question.insert(tk.END, question_text)
txt_question.config(foreground="#888888")  # grey marks the text as placeholder
def clear_default_text(event):
    """Erase the grey placeholder the first time the input box gains focus."""
    current = txt_question.get("1.0", "end-1c")
    if current != question_text:
        return  # user already typed something — leave it alone
    txt_question.delete("1.0", "end-1c")
    txt_question.config(foreground="#000000")
txt_question.bind("<FocusIn>", clear_default_text)
# Scrollbar tied to the question input box.
scrollbar = tk.Scrollbar(window, width=15, command=txt_question.yview)
scrollbar.grid(column=2, row=3, pady=20, sticky='ns')
txt_question['yscrollcommand'] = scrollbar.set
# 2. Output label where the model's answer is displayed.
lbl_response=tk.Label(window, text="", font=font_response, bg='#f7f7f7', wraplength=1000)
lbl_response.grid(column=0, row=4, columnspan=2,pady=20, padx=20, sticky='ew')
lbl_response.configure(justify='left')
# Variable holding the latest answer text shown in lbl_response.
response_text=tk.StringVar()
# Button callback: pass the typed question to the selected model and show the answer.
def clicked():
    """Read the question from the input box, query the selected model
    (Flan-T5 or BERT, per the radio buttons), and display the answer.

    Fixes over the original:
    - `tk.Text.get("1.0", "end")` always includes a trailing newline, so the
      old `if question:` check could never fail; the text is stripped first,
      and the untouched placeholder is also rejected.
    - `tk.Label` has no `see()` method, so the old `lbl_response.see(tk.END)`
      raised AttributeError after every successful answer; it is removed.
    """
    question = txt_question.get("1.0", "end").strip()
    if question and question != question_text:
        answer = get_answer(question)  # routed to Flan-T5 or BERT by get_answer
        response_text.set(answer)  # keep the StringVar in sync
        lbl_response.configure(text=response_text.get())  # update the visible label
        txt_question.delete("1.0", tk.END)  # clear the box for the next question
        window.update_idletasks()  # flush pending geometry/redraw work
    else:
        response_text.set("Please enter a valid question.")
# Submit button wired to clicked(), which invokes the selected model.
btn=tk.Button(window, text="Click Me", font=font_input, bg="#4CAF50", fg='#ffffff', borderwidth=2, relief="solid", command=clicked)
btn.grid(column=2, row=4, pady=20, padx=20, sticky='ew')
# Let the input/output rows and columns stretch with the window size.
window.columnconfigure((0, 1), weight=1)
window.rowconfigure((2, 3, 4, 5), weight=1)
# Expose the question widget as a window attribute for external access.
window.question_input = txt_question
window.mainloop()