diff --git a/examples/game/app.py b/examples/game/app.py index 2b62f65d8..a174cafda 100644 --- a/examples/game/app.py +++ b/examples/game/app.py @@ -13,6 +13,7 @@ send_player_input, get_chat_msg, get_suggests, + ResetException, ) import gradio as gr @@ -67,7 +68,11 @@ def start_game(): args.game_config = GAME_CONFIG from main import main - main(args) + while True: + try: + main(args) + except ResetException: + print("重置成功") with gr.Blocks() as demo: # Users can select the interested exp @@ -106,12 +111,21 @@ def start_game(): with gr.Column(): export_button = gr.Button("导出完整游戏记录") export_output = gr.File(label="下载完整游戏记录", visible=False) + reset_button = gr.Button( + value="重置", + ) def send_message(msg): send_player_input(msg) send_chat_msg(msg, "你") return "" + def send_reset_message(): + global glb_history_chat + glb_history_chat = [] + send_player_input("**Reset**") + return "" + def update_suggest(): msg, samples = get_suggests() if msg is not None: @@ -131,6 +145,7 @@ def update_suggest(): outputs = [chatbot, user_chat_bot_suggest] send_button.click(send_message, user_chat_input, user_chat_input) + reset_button.click(send_reset_message) export_button.click(export_chat_history, [], export_output) user_chat_input.submit(send_message, user_chat_input, user_chat_input) demo.load(get_chat, inputs=None, outputs=chatbot, every=0.5) diff --git a/examples/game/ruled_user.py b/examples/game/ruled_user.py index a0647d903..0e67a06e7 100644 --- a/examples/game/ruled_user.py +++ b/examples/game/ruled_user.py @@ -89,7 +89,7 @@ def reply( f"【请重试】", "⚠️", ) - except Exception as e: + except UnicodeDecodeError as e: send_chat_msg(f"【无效输入】 {e}\n 【请重试】", "⚠️") kwargs = {} diff --git a/examples/game/utils.py b/examples/game/utils.py index fa4444cfd..02033a73e 100644 --- a/examples/game/utils.py +++ b/examples/game/utils.py @@ -140,9 +140,15 @@ def send_pretty_msg(msg): def get_player_input(name=None): + global glb_queue_chat_msg, glb_queue_chat_input, glb_queue_chat_suggests if get_use_web_ui(): print("wait queue input") content = glb_queue_chat_input.get(block=True)[1] + if content == "**Reset**": + glb_queue_chat_msg = Queue() + glb_queue_chat_input = Queue() + glb_queue_chat_suggests = Queue() + raise ResetException else: content = input(f"{name}: ") return content @@ -220,3 +226,7 @@ def end_query_answer(): class CheckpointArgs: load_checkpoint: str = None save_checkpoint: str = "./checkpoints/cp-" + + +class ResetException(Exception): + pass