Skip to content

Commit

Permalink
Add support for multi-user (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
rayrayraykk authored Jan 19, 2024
1 parent 80562d0 commit cf6802f
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 98 deletions.
101 changes: 65 additions & 36 deletions examples/game/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
import yaml
import datetime
import uuid
import threading
from collections import defaultdict

import agentscope

Expand All @@ -21,7 +24,14 @@

enable_web_ui()

glb_history_chat = []

def init_uid_list():
return []


glb_history_dict = defaultdict(init_uid_list)
glb_signed_user = []
is_init = False
MAX_NUM_DISPLAY_MSG = 20

import base64
Expand Down Expand Up @@ -55,47 +65,59 @@ def format_cover_html(config: dict, bot_avatar_path='assets/bg.png'):
</div>
"""

def export_chat_history():
def export_chat_history(uid):
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
export_filename = f"chat_history_{timestamp}.txt"

with open(export_filename, "w", encoding="utf-8") as file:
for role, message in glb_history_chat:
for role, message in glb_history_dict[uid]:
file.write(f"{role}: {message}\n")

return gr.update(value=export_filename, visible=True)


def get_chat() -> List[List]:
def get_chat(uid) -> List[List]:
"""Load the chat info from the queue, and put it into the history
Returns:
`List[List]`: The parsed history, list of tuple, [(role, msg), ...]
"""
global glb_history_chat
line = get_chat_msg()
global glb_history_dict
line = get_chat_msg(uid=uid)
if line is not None:
glb_history_chat += [line]

return glb_history_chat[-MAX_NUM_DISPLAY_MSG:]
glb_history_dict[uid] += [line]
return glb_history_dict[uid][-MAX_NUM_DISPLAY_MSG:]


if __name__ == "__main__":

def start_game():
def init_game():
global is_init
if not is_init:
TONGYI_CONFIG = {
"type": "tongyi",
"name": "tongyi_model",
"model_name": "qwen-max-1201",
"api_key": os.environ.get("TONGYI_API_KEY"),
}
agentscope.init(model_configs=[TONGYI_CONFIG], logger_level="INFO")
is_init = True

def check_for_new_session(uid):
print(uid)
if uid not in glb_signed_user:
glb_signed_user.append(uid)
game_thread = threading.Thread(target=start_game, args=(uid,))
game_thread.start()

def start_game(uid):
with open("./config/game_config.yaml", "r", encoding="utf-8") as file:
GAME_CONFIG = yaml.safe_load(file)
TONGYI_CONFIG = {
"type": "tongyi",
"name": "tongyi_model",
"model_name": "qwen-max-1201",
"api_key": os.environ.get("TONGYI_API_KEY"),
}

agentscope.init(model_configs=[TONGYI_CONFIG], logger_level="INFO")
args = CheckpointArgs()
args.game_config = GAME_CONFIG
args.uid = uid
from main import main

while True:
Expand All @@ -106,6 +128,7 @@ def start_game():

with gr.Blocks(css='assets/app.css') as demo:
# Users can select the interested exp
uuid = gr.State(uuid.uuid4)

welcome = {
"name": "饮食男女",
Expand Down Expand Up @@ -160,24 +183,24 @@ def start_game():
export_button = gr.Button("导出完整游戏记录")
export_output = gr.File(label="下载完整游戏记录", visible=False)

def send_message(msg, uid):
send_player_input(msg, uid=uid)
send_chat_msg(msg, "你", uid=uid)
return ""

return_welcome_button = gr.Button(
value="↩️返回首页",
visible=False,
)

def send_message(msg):
send_player_input(msg)
send_chat_msg(msg, "你")
return ""

def send_reset_message():
global glb_history_chat
glb_history_chat = []
send_player_input("**Reset**")
def send_reset_message(uid):
global glb_history_dict
glb_history_dict[uid] = init_uid_list()
send_player_input("**Reset**", uid=uid)
return ""

def update_suggest():
msg, samples = get_suggests()
def update_suggest(uid):
msg, samples = get_suggests(uid)
if msg is not None:
return gr.Dataset(
label=msg,
Expand Down Expand Up @@ -228,25 +251,31 @@ def welcome_ui():
outputs = [chatbot, user_chat_input, send_button, new_button, resume_button,return_welcome_button, user_chat_bot_suggest, export, user_chat_bot_cover]

# submit message
send_button.click(send_message, user_chat_input, user_chat_input)
user_chat_input.submit(send_message, user_chat_input, user_chat_input)
send_button.click(send_message, [user_chat_input, uuid], user_chat_input)
user_chat_input.submit(send_message, [user_chat_input, uuid], user_chat_input)

# change ui
new_button.click(game_ui, outputs=outputs)
resume_button.click(game_ui, outputs=outputs)
return_welcome_button.click(welcome_ui, outputs=outputs)

# start game
new_button.click(send_reset_message)
new_button.click(start_game)
resume_button.click(start_game)
new_button.click(send_reset_message, inputs=[uuid])
resume_button.click(check_for_new_session, inputs=[uuid])

# export
export_button.click(export_chat_history, [], export_output)
export_button.click(export_chat_history, [uuid], export_output)

# update chat history
demo.load(get_chat, outputs=chatbot, every=0.5)
demo.load(update_suggest, outputs=user_chat_bot_suggest, every=0.5)
demo.load(init_game)
demo.load(check_for_new_session, inputs=[uuid], every=0.1)
demo.load(get_chat, inputs=[uuid], outputs=chatbot, every=0.5)
demo.load(
update_suggest,
inputs=[uuid],
outputs=user_chat_bot_suggest,
every=0.5,
)

demo.queue()
demo.launch()
11 changes: 8 additions & 3 deletions examples/game/customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

class Customer(StateAgent, DialogAgent):
def __init__(self, game_config: dict, **kwargs: Any):
self.uid = kwargs.pop("uid")
super().__init__(**kwargs)
self.game_config = game_config
self.max_itr_preorder = 5
Expand Down Expand Up @@ -109,6 +110,7 @@ def _default_score(_: str) -> float:
f"【系统】{self.name}: 好感度变化 "
f"{change_symbol}{change_in_friendship} "
f"当前好感度为 {self.friendship}",
uid=self.uid,
)

if (
Expand Down Expand Up @@ -211,7 +213,7 @@ def refine_background(self) -> None:
)

analysis = self.model(messages=prompt)
send_chat_msg(f"聊完之后,{self.name}在想:" + analysis)
send_chat_msg(f"聊完之后,{self.name}在想:" + analysis, uid=self.uid)

update_prompt = self.game_config["update_background"].format_map(
{
Expand All @@ -222,7 +224,10 @@ def refine_background(self) -> None:
)
update_msg = Msg(role="user", name="system", content=update_prompt)
new_background = self.model(messages=[update_msg])
send_chat_msg(f"根据对话,{self.name}的背景更新为:" + new_background)
send_chat_msg(
f"根据对话,{self.name}的背景更新为:" + new_background,
uid=self.uid,
)
self.background = new_background

def _validated_history_messages(self, recent_n: int = 10):
Expand Down Expand Up @@ -253,7 +258,7 @@ def generate_pov_story(self, recent_n: int = 20):
msg = Msg(name="system", role="user", content=pov_prompt)
pov_story = self.model(messages=[msg])
print("*" * 20)
send_chat_msg(pov_story)
send_chat_msg(pov_story, uid=self.uid)
print("*" * 20)

def _gen_plot_related_prompt(self) -> str:
Expand Down
Loading

0 comments on commit cf6802f

Please sign in to comment.