forked from ztxz16/fastllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
黄宇扬
committed
Jul 9, 2024
1 parent
f663549
commit c494755
Showing
5 changed files
with
162 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
|
||
from ftllm import llm | ||
import sys | ||
import os | ||
import argparse | ||
from util import make_normal_parser | ||
from util import make_normal_llm_model | ||
|
||
def parse_args(): | ||
parser = make_normal_parser("fastllm webui") | ||
parser.add_argument("--port", type = int, default = 8080, help = "API server port") | ||
parser.add_argument("--title", type = str, default = "fastllm webui", help = "页面标题") | ||
return parser.parse_args() | ||
|
||
args = parse_args() | ||
|
||
import streamlit as st | ||
from streamlit_chat import message | ||
st.set_page_config( | ||
page_title = args.title, | ||
page_icon = ":robot:" | ||
) | ||
|
||
@st.cache_resource | ||
def get_model(): | ||
args = parse_args() | ||
model = make_normal_llm_model(args) | ||
return model | ||
|
||
if "messages" not in st.session_state: | ||
st.session_state.messages = [] | ||
|
||
max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 8192, 512, step = 1) | ||
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step = 0.01) | ||
top_k = st.sidebar.slider("top_k", 1, 100, 1, step = 1) | ||
temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step = 0.01) | ||
repeat_penalty = st.sidebar.slider("repeat_penalty", 1.0, 10.0, 1.0, step = 0.05) | ||
|
||
buttonClean = st.sidebar.button("清理会话历史", key="clean") | ||
if buttonClean: | ||
st.session_state.messages = [] | ||
st.rerun() | ||
|
||
for i, (prompt, response) in enumerate(st.session_state.messages): | ||
with st.chat_message("user"): | ||
st.markdown(prompt) | ||
with st.chat_message("assistant"): | ||
st.markdown(response) | ||
|
||
if prompt := st.chat_input("请开始对话"): | ||
model = get_model() | ||
with st.chat_message("user"): | ||
st.markdown(prompt) | ||
|
||
with st.chat_message("assistant"): | ||
message_placeholder = st.empty() | ||
full_response = "" | ||
for chunk in model.stream_response(prompt, | ||
st.session_state.messages, | ||
max_length = max_new_tokens, | ||
top_k = top_k, | ||
top_p = top_p, | ||
temperature = temperature, | ||
repeat_penalty = repeat_penalty, | ||
one_by_one = True): | ||
full_response += chunk | ||
message_placeholder.markdown(full_response + "▌") | ||
message_placeholder.markdown(full_response) | ||
st.session_state.messages.append((prompt, full_response)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
try: | ||
import streamlit as st | ||
except: | ||
print("Plase install streamlit-chat. (pip install streamlit-chat)") | ||
exit(0) | ||
|
||
import os | ||
import sys | ||
|
||
if __name__ == "__main__": | ||
current_path = os.path.dirname(os.path.abspath(__file__)) | ||
web_demo_path = os.path.join(current_path, 'web_demo.py') | ||
port = "" | ||
for i in range(len(sys.argv)): | ||
if sys.argv[i] == "--port": | ||
port = "--server.port " + sys.argv[i + 1] | ||
if sys.argv[i] == "--help" or sys.argv[i] == "-h": | ||
os.system("python3 " + web_demo_path + " --help") | ||
exit(0) | ||
os.system("streamlit run " + port + " " + web_demo_path + ' -- ' + ' '.join(sys.argv[1:])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters