From c494755899015108d90757aef9523b63839d5a60 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?=
Date: Tue, 9 Jul 2024 19:08:14 +0800
Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0webui?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md                         |  4 ++
 README_EN.md                      |  4 ++
 tools/fastllm_pytools/web_demo.py | 69 ++++++++++++++++++++++++++++++
 tools/fastllm_pytools/webui.py    | 20 +++++++++
 tools/scripts/web_demo.py         | 71 ++++++++++++++++++++++++++++---
 5 files changed, 162 insertions(+), 6 deletions(-)
 create mode 100644 tools/fastllm_pytools/web_demo.py
 create mode 100644 tools/fastllm_pytools/webui.py

diff --git a/README.md b/README.md
index 9f094313..e14c2151 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,10 @@ python3 -m ftllm.chat -t 16 -p ~/Qwen2-7B-Instruct/ --dtype int8
 # 需要安装依赖: pip install -r requirements-server.txt
 # 这里在8080端口打开了一个模型名为qwen的server
 python3 -m ftllm.server -t 16 -p ~/Qwen2-7B-Instruct/ --port 8080 --model_name qwen
+
+# webui
+# 需要安装依赖: pip install streamlit-chat
+python3 -m ftllm.webui -t 16 -p ~/Qwen2-7B-Instruct/ --port 8080
 ```
 
 以上demo均可使用参数 --help 查看详细参数
diff --git a/README_EN.md b/README_EN.md
index e89a7fb0..8c8b24be 100644
--- a/README_EN.md
+++ b/README_EN.md
@@ -54,6 +54,10 @@ python3 -m ftllm.chat -t 16 -p ~/Qwen2-7B-Instruct/ --dtype int8
 # Requires dependencies: pip install -r requirements-server.txt
 # Opens a server named 'qwen' on port 8080
 python3 -m ftllm.server -t 16 -p ~/Qwen2-7B-Instruct/ --port 8080 --model_name qwen
+
+# webui
+# Requires dependencies: pip install streamlit-chat
+python3 -m ftllm.webui -t 16 -p ~/Qwen2-7B-Instruct/ --port 8080
 ```
 
 Detailed parameters can be viewed using the --help argument for all demos.
diff --git a/tools/fastllm_pytools/web_demo.py b/tools/fastllm_pytools/web_demo.py
new file mode 100644
index 00000000..9b1da1e9
--- /dev/null
+++ b/tools/fastllm_pytools/web_demo.py
@@ -0,0 +1,69 @@
+
+from ftllm import llm
+import sys
+import os
+import argparse
+from util import make_normal_parser
+from util import make_normal_llm_model
+
+def parse_args():
+    parser = make_normal_parser("fastllm webui")
+    parser.add_argument("--port", type = int, default = 8080, help = "API server port")
+    parser.add_argument("--title", type = str, default = "fastllm webui", help = "页面标题")
+    return parser.parse_args()
+
+args = parse_args()
+
+import streamlit as st
+from streamlit_chat import message
+st.set_page_config(
+    page_title = args.title,
+    page_icon = ":robot:"
+)
+
+@st.cache_resource
+def get_model():
+    args = parse_args()
+    model = make_normal_llm_model(args)
+    return model
+
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+
+max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 8192, 512, step = 1)
+top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step = 0.01)
+top_k = st.sidebar.slider("top_k", 1, 100, 1, step = 1)
+temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step = 0.01)
+repeat_penalty = st.sidebar.slider("repeat_penalty", 1.0, 10.0, 1.0, step = 0.05)
+
+buttonClean = st.sidebar.button("清理会话历史", key="clean")
+if buttonClean:
+    st.session_state.messages = []
+    st.rerun()
+
+for i, (prompt, response) in enumerate(st.session_state.messages):
+    with st.chat_message("user"):
+        st.markdown(prompt)
+    with st.chat_message("assistant"):
+        st.markdown(response)
+
+if prompt := st.chat_input("请开始对话"):
+    model = get_model()
+    with st.chat_message("user"):
+        st.markdown(prompt)
+
+    with st.chat_message("assistant"):
+        message_placeholder = st.empty()
+        full_response = ""
+        for chunk in model.stream_response(prompt,
+                                           st.session_state.messages,
+                                           max_length = max_new_tokens,
+                                           top_k = top_k,
+                                           top_p = top_p,
+                                           temperature = temperature,
+                                           repeat_penalty = repeat_penalty,
+                                           one_by_one = True):
+            full_response += chunk
+            message_placeholder.markdown(full_response + "▌")
+        message_placeholder.markdown(full_response)
+        st.session_state.messages.append((prompt, full_response))
diff --git a/tools/fastllm_pytools/webui.py b/tools/fastllm_pytools/webui.py
new file mode 100644
index 00000000..e8ce0ef5
--- /dev/null
+++ b/tools/fastllm_pytools/webui.py
@@ -0,0 +1,20 @@
+try:
+    import streamlit as st
+except:
+    print("Please install streamlit-chat. (pip install streamlit-chat)")
+    exit(0)
+
+import os
+import sys
+
+if __name__ == "__main__":
+    current_path = os.path.dirname(os.path.abspath(__file__))
+    web_demo_path = os.path.join(current_path, 'web_demo.py')
+    port = ""
+    for i in range(len(sys.argv)):
+        if sys.argv[i] == "--port":
+            port = "--server.port " + sys.argv[i + 1]
+        if sys.argv[i] == "--help" or sys.argv[i] == "-h":
+            os.system("python3 " + web_demo_path + " --help")
+            exit(0)
+    os.system("streamlit run " + port + " " + web_demo_path + ' -- ' + ' '.join(sys.argv[1:]))
\ No newline at end of file
diff --git a/tools/scripts/web_demo.py b/tools/scripts/web_demo.py
index a8f78c73..8449f3a3 100644
--- a/tools/scripts/web_demo.py
+++ b/tools/scripts/web_demo.py
@@ -1,21 +1,73 @@
-import streamlit as st
-from streamlit_chat import message
+
 from ftllm import llm
 import sys
+import os
+import argparse
+
+def make_normal_parser(des: str) -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description = des)
+    parser.add_argument('-p', '--path', type = str, required = True, default = '', help = '模型路径,fastllm模型文件或HF模型文件夹')
+    parser.add_argument('-t', '--threads', type = int, default = 4, help = '线程数量')
+    parser.add_argument('-l', '--low', action = 'store_true', help = '是否使用低内存模式')
+    parser.add_argument('--dtype', type = str, default = "float16", help = '权重类型(读取HF模型时有效)')
+    parser.add_argument('--atype', type = str, default = "float32", help = '推理类型,可使用float32或float16')
+    parser.add_argument('--cuda_embedding', action = 'store_true', help = '在cuda上进行embedding')
+    parser.add_argument('--device', type = str, help = '使用的设备')
+    return parser
+
+def parse_args():
+    parser = make_normal_parser("fastllm webui")
+    parser.add_argument("--port", type = int, default = 8080, help = "API server port")
+    parser.add_argument("--title", type = str, default = "fastllm webui", help = "页面标题")
+    return parser.parse_args()
+
+def make_normal_llm_model(args):
+    if (args.device and args.device != ""):
+        try:
+            import ast
+            device_map = ast.literal_eval(args.device)
+            if (isinstance(device_map, list) or isinstance(device_map, dict)):
+                llm.set_device_map(device_map)
+            else:
+                llm.set_device_map(args.device)
+        except:
+            llm.set_device_map(args.device)
+    llm.set_cpu_threads(args.threads)
+    llm.set_cpu_low_mem(args.low)
+    if (args.cuda_embedding):
+        llm.set_cuda_embedding(True)
+    model = llm.model(args.path, dtype = args.dtype, tokenizer_type = "auto")
+    model.set_atype(args.atype)
+    return model
+args = parse_args()
+import streamlit as st
+from streamlit_chat import message
 
 st.set_page_config(
-    page_title="fastllm web demo",
-    page_icon=":robot:"
+    page_title = args.title,
+    page_icon = ":robot:"
 )
 
 @st.cache_resource
 def get_model():
-    model = llm.model(sys.argv[1])
+    args = parse_args()
+    model = make_normal_llm_model(args)
     return model
 
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 8192, 512, step = 1)
+top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step = 0.01)
+top_k = st.sidebar.slider("top_k", 1, 100, 1, step = 1)
+temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step = 0.01)
+repeat_penalty = st.sidebar.slider("repeat_penalty", 1.0, 10.0, 1.0, step = 0.05)
+
+buttonClean = st.sidebar.button("清理会话历史", key="clean")
+if buttonClean:
+    st.session_state.messages = []
+    st.rerun()
+
 for i, (prompt, response) in enumerate(st.session_state.messages):
     with st.chat_message("user"):
         st.markdown(prompt)
@@ -30,7 +82,14 @@ def get_model():
     with st.chat_message("assistant"):
         message_placeholder = st.empty()
         full_response = ""
-        for chunk in model.stream_response(prompt, st.session_state.messages, one_by_one = True):
+        for chunk in model.stream_response(prompt,
+                                           st.session_state.messages,
+                                           max_length = max_new_tokens,
+                                           top_k = top_k,
+                                           top_p = top_p,
+                                           temperature = temperature,
+                                           repeat_penalty = repeat_penalty,
+                                           one_by_one = True):
             full_response += chunk
             message_placeholder.markdown(full_response + "▌")
         message_placeholder.markdown(full_response)