-
Notifications
You must be signed in to change notification settings - Fork 131
/
web_demo.py
154 lines (125 loc) · 6.61 KB
/
web_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python
import gradio as gr
from PIL import Image
import os
import json
from model import is_chinese, generate_input, chat
import torch
import argparse
from transformers import AutoTokenizer
from model import VisualGLMModel, chat
from finetune_XrayGLM import FineTuneVisualGLMModel
from sat.model import AutoModel
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize
def generate_text_with_image(input_text, image, history=None, request_data=None, is_zh=True):
    """Generate the model's answer to ``input_text`` about ``image``.

    Relies on the module-level ``model`` and ``tokenizer`` globals, which
    ``main`` assigns before the UI is launched.

    Args:
        input_text: The user's question.
        image: The image to ask about; forwarded to ``generate_input`` with
            ``image_is_encoded=False``.
        history: Previous conversation rounds. ``None`` (the default) means
            an empty history.
        request_data: Optional overrides for the default generation
            parameters. ``None`` (the default) means no overrides.
        is_zh: Whether the question is Chinese; ``chat`` is called with
            ``english=not is_zh``.

    Returns:
        The answer returned by ``chat``.
    """
    # Use None sentinels instead of `[]` / `dict()` defaults: default objects
    # are created once at definition time and would be shared by every call.
    history = [] if history is None else history
    request_data = {} if request_data is None else request_data

    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2,
    }
    input_para.update(request_data)

    input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']

    # Only max_length, top_p, top_k and temperature are forwarded to chat();
    # the other entries of input_para are not read here.
    with torch.no_grad():
        answer, history, _ = chat(
            None, model, tokenizer, input_text,
            history=history, image=input_image,
            max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'],
            top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'],
            english=not is_zh,
        )
    return answer
def request_model(input_text, temperature, top_p, image_prompt, result_previous):
    """Gradio handler: answer ``input_text`` about the uploaded image.

    Args:
        input_text: Text typed by the user.
        temperature: Sampling temperature from the UI slider.
        top_p: Top-p value from the UI slider.
        image_prompt: Path of the uploaded image, or ``None`` if no image
            has been uploaded.
        result_previous: Current chat history as (question, answer) pairs.

    Returns:
        A ``(textbox_value, history)`` tuple. The text box is cleared
        (``""``) in every case except a missing image, where the user's
        text is kept so it can be resubmitted.
    """
    # Keep only complete rounds; entries with an empty question or answer
    # (such as the initial greeting) are dropped from the history.
    result_text = [
        (ele[0], ele[1])
        for ele in (result_previous or [])
        if ele[0] != "" and ele[1] != ""
    ]
    print(f"history {result_text}")

    is_zh = is_chinese(input_text)
    if image_prompt is None:
        if is_zh:
            result_text.append((input_text, '图片为空!请上传图片并重试。'))
        else:
            result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
        return input_text, result_text
    if input_text == "":
        result_text.append((input_text, 'Text empty! Please enter text and retry.'))
        return "", result_text

    request_para = {"temperature": temperature, "top_p": top_p}
    try:
        # Opening the image inside the try block means an unreadable upload is
        # reported in the chat instead of escaping the handler; the context
        # manager closes the underlying file once generation has finished.
        with Image.open(image_prompt) as image:
            answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
    except Exception as e:
        # The user-facing text says "Timeout" for every failure; the actual
        # error is only printed to the console.
        print(f"error: {e}")
        if is_zh:
            result_text.append((input_text, '超时!请稍等几分钟再重试。'))
        else:
            result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
        return "", result_text

    result_text.append((input_text, answer))
    print(result_text)
    return "", result_text
# Markdown heading for the page, linking to the project repository.
DESCRIPTION = '# <a href="https://github.com/WangRongsheng/XrayGLM">XRAY-GLM</a>'

# Usage hints in English.
MAINTENANCE_NOTICE1 = (
    'Hint 1: If the app report "Something went wrong, connection error out", '
    'please turn off your proxy and retry.\n'
    'Hint 2: If you upload a large size of image like 10MB, '
    'it may take some time to upload and process. Please be patient and wait.'
)

# The same usage hints in Chinese.
MAINTENANCE_NOTICE2 = (
    '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n'
    '提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'
)

# Footer note pointing readers to the upstream repository.
NOTES = (
    'This app is adapted from '
    '<a href="https://github.com/WangRongsheng/XrayGLM">https://github.com/WangRongsheng/XrayGLM</a>. '
    'It would be recommended to check out the repo if you want to see the detail of our model and training process.'
)
def clear_fn(value):
    """Return reset values for the text box, chat history and image.

    ``value`` is accepted but not used. The history is reset to a single
    greeting entry, the text box to an empty string and the image to None.
    """
    greeting = ("", "Hi, What do you want to know about this image?")
    cleared_text = ""
    cleared_image = None
    return cleared_text, [greeting], cleared_image
def clear_fn2(value):
    """Return a fresh chat history containing only the greeting entry.

    ``value`` is accepted but not used.
    """
    greeting = ("", "Hi, What do you want to know about this image?")
    return [greeting]
def main(args):
    """Load the model and tokenizer, build the Gradio UI and launch it.

    Assigns the module-level ``model`` and ``tokenizer`` globals that
    ``generate_text_with_image`` reads.

    Args:
        args: Parsed command-line namespace. Reads ``from_pretrained``
            (checkpoint to load), ``quant`` (quantization setting or
            ``None``) and ``share`` (passed to ``demo.launch``).
    """
    global model, tokenizer

    # Load straight onto the GPU only when CUDA is available and no
    # quantization was requested; otherwise the model is loaded on the CPU.
    load_on_gpu = torch.cuda.is_available() and args.quant is None
    model, _ = AutoModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
            fp16=True,
            skip_init=True,
            use_gpu_initialization=load_on_gpu,
            device='cuda' if load_on_gpu else 'cpu',
        ))
    model = model.eval()

    if args.quant:
        quantize(model.transformer, args.quant)
        # NOTE(review): the move to CUDA is placed under the quantization
        # branch because the unquantized model is already loaded with
        # device='cuda' above; confirm this matches the intended nesting.
        if torch.cuda.is_available():
            model = model.cuda()

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            # Left column: prompt, buttons, image upload and sampling controls.
            with gr.Column(scale=4.5):
                with gr.Group():
                    input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
                    with gr.Row():
                        run_button = gr.Button('Generate')
                        clear_button = gr.Button('Clear')
                    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
                with gr.Group():
                    with gr.Row():
                        gr.Markdown(MAINTENANCE_NOTICE1)
            # Right column: conversation history.
            with gr.Column(scale=5.5):
                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)

        gr.Markdown(NOTES)

        print(gr.__version__)

        # The Generate button and pressing ENTER in the text box run the same
        # handler with the same inputs and outputs.
        chat_inputs = [input_text, temperature, top_p, image_prompt, result_text]
        chat_outputs = [input_text, result_text]
        run_button.click(fn=request_model, inputs=chat_inputs, outputs=chat_outputs)
        input_text.submit(fn=request_model, inputs=chat_inputs, outputs=chat_outputs)

        # The clear handlers ignore their argument; the button is passed only
        # to satisfy the `inputs` parameter.
        clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
        # Uploading or removing an image resets the conversation history.
        image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
        image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

    demo.queue(concurrency_count=10)
    demo.launch(share=args.share)
# Script entry point: parse the command-line options and start the demo.
if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser()
    # Optional quantization setting; only 8 or 4 are accepted.
    cli.add_argument("--quant", type=int, choices=[8, 4], default=None)
    # Passed through to demo.launch(share=...).
    cli.add_argument("--share", action="store_true")
    cli.add_argument("--from_pretrained", type=str, default="checkpoints", help='pretrained ckpt')
    main(cli.parse_args())