Skip to content

Commit

Permalink
add reset
Browse files Browse the repository at this point in the history
  • Loading branch information
rayrayraykk committed Jan 18, 2024
1 parent 8b14a7c commit 60eb45f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
17 changes: 16 additions & 1 deletion examples/game/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
send_player_input,
get_chat_msg,
get_suggests,
ResetException,
)

import gradio as gr
Expand Down Expand Up @@ -67,7 +68,11 @@ def start_game():
args.game_config = GAME_CONFIG
from main import main

main(args)
while True:
try:
main(args)
except ResetException:
print("重置成功")

with gr.Blocks() as demo:
# Users can select the interested exp
Expand Down Expand Up @@ -106,12 +111,21 @@ def start_game():
with gr.Column():
export_button = gr.Button("导出完整游戏记录")
export_output = gr.File(label="下载完整游戏记录", visible=False)
reset_button = gr.Button(
value="重置",
)

def send_message(msg):
send_player_input(msg)
send_chat_msg(msg, "你")
return ""

def send_reset_message():
global glb_history_chat
glb_history_chat = []
send_player_input("**Reset**")
return ""

def update_suggest():
msg, samples = get_suggests()
if msg is not None:
Expand All @@ -131,6 +145,7 @@ def update_suggest():

outputs = [chatbot, user_chat_bot_suggest]
send_button.click(send_message, user_chat_input, user_chat_input)
reset_button.click(send_reset_message)
export_button.click(export_chat_history, [], export_output)
user_chat_input.submit(send_message, user_chat_input, user_chat_input)
demo.load(get_chat, inputs=None, outputs=chatbot, every=0.5)
Expand Down
2 changes: 1 addition & 1 deletion examples/game/ruled_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def reply(
f"【请重试】",
"⚠️",
)
except Exception as e:
except UnicodeDecodeError as e:
send_chat_msg(f"【无效输入】 {e}\n 【请重试】", "⚠️")

kwargs = {}
Expand Down
10 changes: 10 additions & 0 deletions examples/game/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,15 @@ def send_pretty_msg(msg):


def get_player_input(name=None):
global glb_queue_chat_msg, glb_queue_chat_input, glb_queue_chat_suggests
if get_use_web_ui():
print("wait queue input")
content = glb_queue_chat_input.get(block=True)[1]
if content == "**Reset**":
glb_queue_chat_msg = Queue()
glb_queue_chat_input = Queue()
glb_queue_chat_suggests = Queue()
raise ResetException
else:
content = input(f"{name}: ")
return content
Expand Down Expand Up @@ -220,3 +226,7 @@ def end_query_answer():
class CheckpointArgs:
load_checkpoint: str = None
save_checkpoint: str = "./checkpoints/cp-"


class ResetException(Exception):
pass

0 comments on commit 60eb45f

Please sign in to comment.