From f25dd57e671e1f36640f23d60d206246cbf2b033 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 2 Jun 2024 20:00:17 -0500 Subject: [PATCH] make memory configurable, consistently truncate discord messages, fix action prompt --- client/src/models.ts | 2 +- client/src/prompt.tsx | 20 ++++++++++++++++---- taleweave/actions/base.py | 12 ++++++------ taleweave/bot/discord.py | 20 +++++++++++--------- taleweave/main.py | 2 +- taleweave/simulate.py | 26 ++++++++++++-------------- taleweave/state.py | 11 +++++++---- taleweave/systems/digest.py | 3 ++- taleweave/utils/prompt.py | 13 ++++++++----- 9 files changed, 64 insertions(+), 45 deletions(-) diff --git a/client/src/models.ts b/client/src/models.ts index 6144923..b475662 100644 --- a/client/src/models.ts +++ b/client/src/models.ts @@ -69,7 +69,7 @@ export interface StringParameter { export interface NumberParameter { type: 'number'; default?: number; - enum?: Array; + enum?: Array; } export type Parameter = BooleanParameter | NumberParameter | StringParameter; diff --git a/client/src/prompt.tsx b/client/src/prompt.tsx index 117b678..9aeba13 100644 --- a/client/src/prompt.tsx +++ b/client/src/prompt.tsx @@ -129,7 +129,7 @@ export function enumerateSignificantParameterValues(name: string, world: World) } } -export function convertSignificantParameter(name: string, parameter: Parameter, world: Maybe): Parameter { +export function convertSignificantParameter(name: string, parameter: T, world: Maybe): T { if (parameter.type === 'boolean') { return parameter; } @@ -154,15 +154,27 @@ export function formatAction(action: string, parameters: Record `${name}=${value}`).join(',')}`; } +export function getEnumOrDefault(defaultValue: Maybe, enumValues: Maybe>, evenMoreDefault: T): T { + if (doesExist(defaultValue)) { + return defaultValue; + } + + if (doesExist(enumValues)) { + return enumValues[0]; + } + + return evenMoreDefault; +} + export function makeDefaultParameterValues(parameters: Record) { return Object.entries(parameters).reduce((acc, [name, parameter]) => { switch (parameter.type) { case 'boolean': - return { ...acc, [name]: mustDefault(parameter.default, false) }; + return { ...acc, [name]: getEnumOrDefault(parameter.default, [], false) }; case 'number': - return { ...acc, [name]: mustDefault(parameter.default, 0) }; + return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, 0) }; case 'string': - return { ...acc, [name]: mustDefault(parameter.default, '') }; + return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, '') }; default: return acc; } diff --git a/taleweave/actions/base.py b/taleweave/actions/base.py index a852fee..4d757e0 100644 --- a/taleweave/actions/base.py +++ b/taleweave/actions/base.py @@ -5,6 +5,7 @@ broadcast, get_agent_for_character, get_character_agent_for_name, + get_game_config, get_prompt, world_context, ) @@ -22,8 +23,6 @@ logger = getLogger(__name__) -MAX_CONVERSATION_STEPS = 2 - def action_examine(target: str) -> str: """ @@ -173,7 +172,8 @@ def action_ask(character: str, question: str) -> str: character: The name of the character to ask. You cannot ask yourself questions. question: The question to ask them. """ - # capture references to the current character and room, because they will be overwritten + config = get_game_config() + with action_context() as (action_room, action_character): # sanity checks question_character, question_agent = get_character_agent_for_name(character) @@ -216,7 +216,7 @@ def action_ask(character: str, question: str) -> str: end_prompt, echo_function=action_tell.__name__, echo_parameter="message", - max_length=MAX_CONVERSATION_STEPS, + max_length=config.world.character.conversation_limit, ) if result: @@ -233,7 +233,7 @@ def action_tell(character: str, message: str) -> str: character: The name of the character to tell. You cannot talk to yourself. message: The message to tell them. """ - # capture references to the current character and room, because they will be overwritten + config = get_game_config() with action_context() as (action_room, action_character): # sanity checks @@ -268,7 +268,7 @@ def action_tell(character: str, message: str) -> str: end_prompt, echo_function=action_tell.__name__, echo_parameter="message", - max_length=MAX_CONVERSATION_STEPS, + max_length=config.world.character.conversation_limit, ) if result: diff --git a/taleweave/bot/discord.py b/taleweave/bot/discord.py index d6802e1..4fd133a 100644 --- a/taleweave/bot/discord.py +++ b/taleweave/bot/discord.py @@ -323,6 +323,12 @@ async def broadcast_event(message: str | GameEvent): event_messages[event_message.id] = message +def truncate(text: str, length: int = 1000) -> str: + if len(text) > length: + return text[:length] + "..." + return text + + def embed_from_event(event: GameEvent) -> Embed | None: if isinstance(event, GenerateEvent): return embed_from_generate(event) @@ -357,7 +363,7 @@ def embed_from_action(event: ActionEvent): def embed_from_reply(event: ReplyEvent): reply_embed = Embed(title=event.room.name, description=event.speaker.name) - reply_embed.add_field(name="Reply", value=event.text) + reply_embed.add_field(name="Reply", value=truncate(event.text)) return reply_embed @@ -367,12 +373,8 @@ def embed_from_generate(event: GenerateEvent) -> Embed: def embed_from_result(event: ResultEvent): - text = event.result - if len(text) > 1000: - text = text[:1000] + "..." - result_embed = Embed(title=event.room.name, description=event.character.name) - result_embed.add_field(name="Result", value=text) + result_embed.add_field(name="Result", value=truncate(event.result)) return result_embed @@ -384,14 +386,14 @@ def embed_from_player(event: PlayerEvent): title = format_prompt("discord_leave_title", event=event) description = format_prompt("discord_leave_result", event=event) - player_embed = Embed(title=title, description=description) + player_embed = Embed(title=title, description=truncate(description)) return player_embed def embed_from_prompt(event: PromptEvent): # TODO: ping the player prompt_embed = Embed(title=event.room.name, description=event.character.name) - prompt_embed.add_field(name="Prompt", value=event.prompt) + prompt_embed.add_field(name="Prompt", value=truncate(event.prompt)) return prompt_embed @@ -400,5 +402,5 @@ def embed_from_status(event: StatusEvent): title=event.room.name if event.room else "", description=event.character.name if event.character else "", ) - status_embed.add_field(name="Status", value=event.text) + status_embed.add_field(name="Status", value=truncate(event.text)) return status_embed diff --git a/taleweave/main.py b/taleweave/main.py index 8127cfa..bd439df 100644 --- a/taleweave/main.py +++ b/taleweave/main.py @@ -416,7 +416,7 @@ def snapshot_system(world: World, turn: int, data: None = None) -> None: set_dungeon_master(world_builder) # start the sim - logger.debug("simulating world: %s", world) + logger.debug("simulating world: %s", world.name) simulate_world( world, turns=args.turns, diff --git a/taleweave/simulate.py b/taleweave/simulate.py index fbdcccc..706b8c7 100644 --- a/taleweave/simulate.py +++ b/taleweave/simulate.py @@ -44,6 +44,7 @@ set_current_world, set_game_systems, ) +from taleweave.errors import ActionError from taleweave.game_system import GameSystem from taleweave.models.entity import Character, Room, World from taleweave.models.event import ActionEvent, ResultEvent @@ -117,12 +118,9 @@ def result_parser(value, **kwargs): # TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice event = ActionEvent.from_json(value, room, character) else: - # TODO: this path should be removed and throw - # logger.warning( - # "invalid action, emitting as result event - this is a bug somewhere" - # ) - # event = ResultEvent(value, room, character) - raise ValueError("invalid non-JSON action") + raise ActionError( + "Your last reply was not valid JSON. Please try again and reply with a valid function call in JSON format." + ) broadcast(event) @@ -216,14 +214,14 @@ def prompt_character_planning( while not stop_condition(current=i): result = loop_retry( agent, - get_prompt("world_simulate_character_planning"), - context={ - "event_count": event_count, - "events_prompt": events_prompt, - "note_count": note_count, - "notes_prompt": notes_prompt, - "room_summary": summarize_room(room, character), - }, + format_prompt( + "world_simulate_character_planning", + event_count=event_count, + events_prompt=events_prompt, + note_count=note_count, + notes_prompt=notes_prompt, + room_summary=summarize_room(room, character), + ), result_parser=result_parser, stop_condition=stop_condition, toolbox=planner_toolbox, diff --git a/taleweave/state.py b/taleweave/state.py index 04826bd..7be9353 100644 --- a/taleweave/state.py +++ b/taleweave/state.py @@ -7,12 +7,14 @@ from packit.agent import Agent, agent_easy_connect from pydantic import RootModel -from taleweave.context import get_all_character_agents, set_character_agent +from taleweave.context import ( + get_all_character_agents, + get_game_config, + set_character_agent, +) from taleweave.models.entity import World from taleweave.player import LocalPlayer -MEMORY_LIMIT = 25 # 10 - def create_agents( world: World, @@ -69,6 +71,7 @@ def snapshot_world(world: World, turn: int): def restore_memory( data: Sequence[str | Dict[str, str]] ) -> deque[str | AIMessage | HumanMessage | SystemMessage]: + config = get_game_config() memories = [] for memory in data: @@ -85,7 +88,7 @@ def restore_memory( elif memory_type == "ai": memories.append(AIMessage(content=memory_content)) - return deque(memories, maxlen=MEMORY_LIMIT) + return deque(memories, maxlen=config.world.character.memory_limit) def save_world(world, filename): diff --git a/taleweave/systems/digest.py b/taleweave/systems/digest.py index bd3462d..393bfbe 100644 --- a/taleweave/systems/digest.py +++ b/taleweave/systems/digest.py @@ -5,6 +5,7 @@ from taleweave.game_system import FormatPerspective, GameSystem from taleweave.models.entity import Character, Room, World, WorldEntity from taleweave.models.event import ActionEvent, GameEvent +from taleweave.utils.prompt import format_str from taleweave.utils.search import find_containing_room logger = getLogger(__name__) @@ -22,7 +23,7 @@ def create_turn_digest( if prompt_key in library.prompts: try: template = library.prompts[prompt_key] - message = template.format(event=event) + message = format_str(template, event=event) messages.append(message) except Exception: logger.exception("error formatting digest event: %s", event) diff --git a/taleweave/utils/prompt.py b/taleweave/utils/prompt.py index c01141a..2dd988c 100644 --- a/taleweave/utils/prompt.py +++ b/taleweave/utils/prompt.py @@ -3,10 +3,17 @@ from jinja2 import Environment from taleweave.context import get_prompt_library +from taleweave.utils.string import and_list, or_list from taleweave.utils.world import describe_entity, name_entity logger = getLogger(__name__) +jinja_env = Environment() +jinja_env.filters["describe"] = describe_entity +jinja_env.filters["name"] = name_entity +jinja_env.filters["and_list"] = and_list +jinja_env.filters["or_list"] = or_list + def format_prompt(prompt_key: str, **kwargs) -> str: try: @@ -19,9 +26,5 @@ def format_prompt(prompt_key: str, **kwargs) -> str: def format_str(template_str: str, **kwargs) -> str: - env = Environment() - env.filters["describe"] = describe_entity - env.filters["name"] = name_entity - - template = env.from_string(template_str) + template = jinja_env.from_string(template_str) return template.render(**kwargs)