
Commit

webui: support image input (multimodal models only)
黄宇扬 committed Oct 23, 2024
1 parent d779940 commit 2823624
Showing 2 changed files with 42 additions and 6 deletions.
3 changes: 2 additions & 1 deletion requirements-server.txt
@@ -1,4 +1,5 @@
 fastapi
 pydantic
 openai
-shortuuid
+shortuuid
+uvicorn
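The new entry is the ASGI server that actually runs the FastAPI app declared above it. As a rough illustration of how these server dependencies fit together (module and route names below are illustrative, not taken from this repository):

# minimal_server.py: a sketch only, not this repo's server code
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

app = FastAPI()

class ChatRequest(BaseModel):
    prompt: str

@app.post("/chat")
def chat(req: ChatRequest):
    # A real server would run the model here; this stub just echoes.
    return {"response": "echo: " + req.prompt}

if __name__ == "__main__":
    # uvicorn is the ASGI server that serves the FastAPI app.
    uvicorn.run(app, host="0.0.0.0", port=8000)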
45 changes: 40 additions & 5 deletions tools/fastllm_pytools/web_demo.py
@@ -30,6 +30,8 @@ def get_model():
 
 if "messages" not in st.session_state:
     st.session_state.messages = []
+if "images" not in st.session_state:
+    st.session_state.images = []
 
 system_prompt = st.sidebar.text_input("system_prompt", "")
 max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 8192, 512, step = 1)
@@ -43,12 +45,35 @@ def get_model():
     st.session_state.messages = []
     st.rerun()
 
+# Add a file-upload widget
+if (uploaded_file := st.file_uploader("上传图片", type=["jpg", "jpeg", "png"])) is not None:
+    from PIL import Image
+    import io
+    import base64
+
+    image = Image.open(uploaded_file).convert('RGB')
+    st.session_state.images = [image]
+
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode()
+
+    # Build an HTML element containing the image, with styling
+    image_html = f"""
+    <div style="display: flex; justify-content: flex-end;">
+        <img src="data:image/png;base64,{img_str}" style="max-width: 300px; max-height: 300px;">
+    </div>
+    """
+
+    # Render the image
+    st.markdown(image_html, unsafe_allow_html=True)
+
 for i, (prompt, response) in enumerate(st.session_state.messages):
     with st.chat_message("user"):
         st.markdown(prompt)
     with st.chat_message("assistant"):
         st.markdown(response)
 
 if prompt := st.chat_input("请开始对话"):
     model = get_model()
     with st.chat_message("user"):
@@ -65,14 +90,24 @@ def get_model():
             messages.append({"role": "assistant", "content": his[1]})
         messages.append({"role": "user", "content": prompt})
 
-        for chunk in model.stream_response(messages,
-                                           max_length = max_new_tokens,
-                                           top_k = top_k,
-                                           top_p = top_p,
-                                           temperature = temperature,
-                                           repeat_penalty = repeat_penalty,
-                                           one_by_one = True):
-            full_response += chunk
-            message_placeholder.markdown(full_response + "▌")
-        message_placeholder.markdown(full_response)
+        if (len(st.session_state.images) > 0):
+            handle = model.launch_stream_response(messages,
+                max_length = max_new_tokens, do_sample = True,
+                top_p = top_p, top_k = top_k, temperature = temperature,
+                repeat_penalty = repeat_penalty, one_by_one = True, images = st.session_state.images)
+            for chunk in model.stream_response_handle(handle):
+                full_response += chunk
+                message_placeholder.markdown(full_response + "▌")
+            message_placeholder.markdown(full_response)
+        else:
+            for chunk in model.stream_response(messages,
+                                               max_length = max_new_tokens,
+                                               top_k = top_k,
+                                               top_p = top_p,
+                                               temperature = temperature,
+                                               repeat_penalty = repeat_penalty,
+                                               one_by_one = True):
+                full_response += chunk
+                message_placeholder.markdown(full_response + "▌")
+            message_placeholder.markdown(full_response)
         st.session_state.messages.append((prompt, full_response))
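Stripped of the Streamlit UI, the multimodal path added above reduces to launching a stream with images attached and then draining the returned handle. A minimal sketch using the call names from this diff; the model-loading line and file paths are assumptions for illustration, not part of this commit:

# Sketch of the image-input streaming path; paths and loading API are assumed.
from PIL import Image
from fastllm_pytools import llm

model = llm.model("path/to/multimodal-model.flm")  # placeholder path
image = Image.open("example.png").convert("RGB")
messages = [{"role": "user", "content": "Describe this image."}]

# launch_stream_response starts generation and returns a handle;
# stream_response_handle then yields the reply chunk by chunk.
handle = model.launch_stream_response(
    messages,
    max_length = 512, do_sample = True,
    top_p = 0.8, top_k = 1, temperature = 1.0,
    repeat_penalty = 1.0, one_by_one = True,
    images = [image])
for chunk in model.stream_response_handle(handle):
    print(chunk, end = "", flush = True)

On the display side, note that the uploader block base64-encodes the PIL image into a data: URI so that st.markdown can right-align it with flexbox CSS; st.image would be simpler but gives less control over alignment.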
