From 2823624769c64abee6fb379bac9e6c484acf7f61 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?=
Date: Wed, 23 Oct 2024 19:07:47 +0800
Subject: [PATCH] webui: support image input (only available for multimodal models)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 requirements-server.txt           |  3 ++-
 tools/fastllm_pytools/web_demo.py | 45 +++++++++++++++++++++++++++----
 2 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/requirements-server.txt b/requirements-server.txt
index 4e85afa1..ae9932f4 100644
--- a/requirements-server.txt
+++ b/requirements-server.txt
@@ -1,4 +1,5 @@
 fastapi
 pydantic
 openai
-shortuuid
\ No newline at end of file
+shortuuid
+uvicorn
\ No newline at end of file
diff --git a/tools/fastllm_pytools/web_demo.py b/tools/fastllm_pytools/web_demo.py
index bd7500c8..2b48d975 100644
--- a/tools/fastllm_pytools/web_demo.py
+++ b/tools/fastllm_pytools/web_demo.py
@@ -30,6 +30,8 @@ def get_model():
 
 if "messages" not in st.session_state:
     st.session_state.messages = []
+if "images" not in st.session_state:
+    st.session_state.images = []
 
 system_prompt = st.sidebar.text_input("system_prompt", "")
 max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 8192, 512, step = 1)
@@ -43,12 +45,35 @@ def get_model():
     st.session_state.messages = []
     st.rerun()
 
+# Add a file-upload widget
+if (uploaded_file := st.file_uploader("上传图片", type=["jpg", "jpeg", "png"])) is not None:
+    from PIL import Image
+    import io
+    import base64
+
+    image = Image.open(uploaded_file).convert('RGB')
+    st.session_state.images = [image]
+
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode()
+
+    # Build an HTML element that embeds the image, and style it
+    image_html = f"""
+    <div style="display: flex; justify-content: center;">
+        <img src="data:image/png;base64,{img_str}" style="max-width: 50%;">
+    </div>
+    """
+
+    # Display the image
+    st.markdown(image_html, unsafe_allow_html=True)
+
 for i, (prompt, response) in enumerate(st.session_state.messages):
     with st.chat_message("user"):
         st.markdown(prompt)
     with st.chat_message("assistant"):
         st.markdown(response)
-    
+
 if prompt := st.chat_input("请开始对话"):
     model = get_model()
     with st.chat_message("user"):
@@ -65,14 +90,24 @@ def get_model():
         messages.append({"role": "user", "content": his[0]})
         messages.append({"role": "assistant", "content": his[1]})
     messages.append({"role": "user", "content": prompt})
 
-    for chunk in model.stream_response(messages,
+    if len(st.session_state.images) > 0:
+        handle = model.launch_stream_response(messages,
+                    max_length = max_new_tokens, do_sample = True,
+                    top_p = top_p, top_k = top_k, temperature = temperature,
+                    repeat_penalty = repeat_penalty, one_by_one = True, images = st.session_state.images)
+        for chunk in model.stream_response_handle(handle):
+            full_response += chunk
+            message_placeholder.markdown(full_response + "▌")
+        message_placeholder.markdown(full_response)
+    else:
+        for chunk in model.stream_response(messages,
                         max_length = max_new_tokens,
                         top_k = top_k, top_p = top_p,
                         temperature = temperature,
                         repeat_penalty = repeat_penalty, one_by_one = True):
-        full_response += chunk
-        message_placeholder.markdown(full_response + "▌")
-    message_placeholder.markdown(full_response)
+            full_response += chunk
+            message_placeholder.markdown(full_response + "▌")
+        message_placeholder.markdown(full_response)
 
     st.session_state.messages.append((prompt, full_response))
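
For reference, the image branch above uses fastllm's handle-based streaming API: launch_stream_response() starts generation with the uploaded images attached, and stream_response_handle() yields text chunks for incremental display. The sketch below shows the same call pattern outside Streamlit; the ftllm import, model path, image file, and sampling values are illustrative assumptions, not taken from this patch.

# Minimal sketch: exercise the image-input path without the web UI.
# Assumptions: the ftllm Python bindings, a local multimodal model path,
# and an example image file; sampling values are illustrative only.
from PIL import Image
from ftllm import llm

model = llm.model("/path/to/multimodal-model.flm")  # assumed model path

image = Image.open("example.jpg").convert('RGB')    # assumed input image
messages = [{"role": "user", "content": "Describe this image."}]

# Same pattern as the patched demo: launch a streaming handle with the
# images attached, then pull text chunks from it one by one.
handle = model.launch_stream_response(messages,
            max_length = 512, do_sample = True,
            top_p = 0.8, top_k = 1, temperature = 1.0,
            repeat_penalty = 1.0, one_by_one = True,
            images = [image])
for chunk in model.stream_response_handle(handle):
    print(chunk, end = "", flush = True)
print()

Note that the demo stores st.session_state.images as a single-element list, so images = st.session_state.images forwards only the most recently uploaded picture on each turn.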