From d0ddeca2113e2a35cfbe8a164c958d5756b917b7 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 10 Jun 2024 19:27:43 -0500 Subject: [PATCH] SSL support for websocket server, basic world editor, various prompt fixes --- client/src/events.tsx | 2 +- client/src/main.tsx | 23 ++- prompts/llama-base.yml | 20 ++- taleweave/bot/discord.py | 12 +- taleweave/editor.py | 323 ++++++++++++++++++++++++++++++++++ taleweave/generate.py | 22 ++- taleweave/main.py | 17 +- taleweave/models/config.py | 8 + taleweave/server/websocket.py | 16 +- taleweave/simulate.py | 71 ++++++-- taleweave/systems/digest.py | 19 +- taleweave/utils/template.py | 3 + 12 files changed, 489 insertions(+), 47 deletions(-) create mode 100644 taleweave/editor.py diff --git a/client/src/events.tsx b/client/src/events.tsx index 6a76514..88bc525 100644 --- a/client/src/events.tsx +++ b/client/src/events.tsx @@ -230,7 +230,7 @@ export function RenderEventItem(props: EventItemProps) { >{prompt} } /> - + {Object.entries(images).map(([name, image]) => openImage(image as string)}> Render diff --git a/client/src/main.tsx b/client/src/main.tsx index bacffba..cd99e6e 100644 --- a/client/src/main.tsx +++ b/client/src/main.tsx @@ -1,9 +1,24 @@ -import { doesExist } from '@apextoaster/js-utils'; +import { doesExist, mustDefault } from '@apextoaster/js-utils'; import { createRoot } from 'react-dom/client'; import React, { StrictMode } from 'react'; import { App } from './app.js'; +export const DEFAULT_SOCKET_PORT = 8001; + +export function getSocketProtocol(protocol: string) { + if (protocol === 'https:') { + return 'wss:'; + } + + return 'ws:'; +} + +export function getSocketAddress(protocol: string, hostname: string, port = DEFAULT_SOCKET_PORT) { + const socketProtocol = getSocketProtocol(protocol); + return `${socketProtocol}://${hostname}:${port}/`; +} + window.addEventListener('DOMContentLoaded', () => { const history = document.querySelector('#history'); @@ -12,7 +27,11 @@ window.addEventListener('DOMContentLoaded', () => { throw new Error('History element not found'); } + const protocol = window.location.protocol; const hostname = window.location.hostname; + const search = new URLSearchParams(window.location.search); + const socketAddress = mustDefault(search.get('socket'), getSocketAddress(protocol, hostname)); + const root = createRoot(history); - root.render(); + root.render(); }); diff --git a/prompts/llama-base.yml b/prompts/llama-base.yml index ebe7b6d..d5d250f 100644 --- a/prompts/llama-base.yml +++ b/prompts/llama-base.yml @@ -193,9 +193,9 @@ prompts: {{name}} will happen in {{turns}} turn # agent stuff - world_agent_backstory: | - {{ character.backstory }} world_agent_backstory_other: | + {{ character.backstory }} + world_agent_backstory: | You are {{character | name}}, a character in a text-based role-playing game. Your character's backstory is: {{ character.backstory }} Explore the world, interact with other characters, and complete quests to advance the story. @@ -218,11 +218,11 @@ prompts: Generating a {{theme}} with {{room_count}} rooms world_generate_room_name: | - Generate one room, area, or location that would make sense in the world of {{world_theme}}. + Generate one room, area, or location that would make sense in the world of {{world_theme}}. {{ additional_prompt | punctuate }} Only respond with the room name in title case, do not include the description or any other text. Do not prefix the name with "the", do not wrap it in quotes. The existing rooms are: {{existing_rooms}} world_generate_room_description: | - Generate a detailed description of the {{name}} area. What does it look like? + Generate a detailed description of the {{name}} area. {{ additional_prompt | punctuate }} What does it look like? What does it smell like? What can be seen or heard? world_generate_room_broadcast_room: | Generating room: {{name}} @@ -356,6 +356,11 @@ prompts: world_simulate_character_action_error_json: | Your last reply was not a valid action or the action you tried to use does not exist. Please try again, being careful to reply with a valid function call in JSON format. The available actions are: {{actions}}. + world_simulate_character_action_error_action: | + You cannot use the '{{action}}' action because {{message | punctuate}} + world_simulate_character_action_error_unknown_tool: | + That action is not available during the action phase or it does not exist. Please try again using a different + action. The available actions are: {{actions}}. world_simulate_character_planning: | You are about to start your turn. Plan your next action carefully. Take notes and schedule events to help keep track of your goals. @@ -379,6 +384,11 @@ prompts: You have no upcoming events. world_simulate_character_planning_events_item: | {{event.name}} in {{turns}} turns + world_simulate_character_planning_error_action: | + You cannot perform the '{{action}}' action because {{message | punctuate}} world_simulate_character_planning_error_json: | Your last reply was not a valid action or the action you tried to use does not exist. Please try again, being - careful to reply with a valid function call in JSON format. The available actions are: {{actions}}. \ No newline at end of file + careful to reply with a valid function call in JSON format. The available actions are: {{actions}}. + world_simulate_character_planning_error_unknown_tool: | + That action is not available during the planning phase or it does not exist. Please try again using a different + action. The available actions are: {{actions}}. \ No newline at end of file diff --git a/taleweave/bot/discord.py b/taleweave/bot/discord.py index 0b31260..5a6a9cf 100644 --- a/taleweave/bot/discord.py +++ b/taleweave/bot/discord.py @@ -238,17 +238,17 @@ def bot_main(): client.run(environ["DISCORD_TOKEN"]) def send_main(): - from time import sleep + # from time import sleep while True: - sleep(0.1) - if event_queue.empty(): - # logger.debug("no events to prompt") - continue + # sleep(0.05) + # if event_queue.empty(): + # logger.debug("no events to prompt") + # continue # wait for pending messages to send, to keep them in order if len(active_tasks) > 0: - logger.debug("waiting for active tasks to complete") + # logger.debug("waiting for active tasks to complete") continue event = event_queue.get() diff --git a/taleweave/editor.py b/taleweave/editor.py new file mode 100644 index 0000000..89d3767 --- /dev/null +++ b/taleweave/editor.py @@ -0,0 +1,323 @@ +import argparse +from os import path +from typing import List, Tuple + +from taleweave.context import get_dungeon_master, get_game_systems, set_game_systems +from taleweave.game_system import GameSystem +from taleweave.generate import ( + generate_character, + generate_item, + generate_portals, + generate_room, +) +from taleweave.main import load_or_initialize_system_data +from taleweave.models.base import dump_model +from taleweave.models.entity import World, WorldState +from taleweave.plugins import load_plugin +from taleweave.utils.file import load_yaml, save_yaml +from taleweave.utils.search import ( + find_character, + find_item, + find_portal, + find_room, + list_characters, + list_items, + list_portals, + list_rooms, +) +from taleweave.utils.world import describe_entity + + +ENTITY_TYPES = ["room", "portal", "item", "character"] + + +def parse_args(): + parser = argparse.ArgumentParser(description="Taleweave Editor") + parser.add_argument("--state", type=str, help="State file to edit") + parser.add_argument("--world", type=str, help="World file to edit") + parser.add_argument("--systems", type=str, nargs="*", help="Game systems to load") + + subparsers = parser.add_subparsers(dest="command", help="Command to execute") + subparsers.required = True + + # Set up the 'list' command + list_parser = subparsers.add_parser( + "list", help="List all entities or entities of a specific type" + ) + list_parser.add_argument( + "type", help="Type of entity to list", choices=ENTITY_TYPES, nargs="?" + ) + + # Set up the 'describe' command + describe_parser = subparsers.add_parser("describe", help="Describe an entity") + describe_parser.add_argument( + "type", help="Type of entity to describe", choices=ENTITY_TYPES + ) + describe_parser.add_argument("entity", type=str, help="Entity to describe") + + # Set up the 'create' command + create_parser = subparsers.add_parser("create", help="Create an entity") + create_parser.add_argument( + "type", help="Type of entity to create", choices=ENTITY_TYPES + ) + create_parser.add_argument("name", type=str, help="Name of the entity to create") + create_parser.add_argument("--room", type=str, help="Room the entity is in") + + # Set up the 'generate' command + generate_parser = subparsers.add_parser("generate", help="Generate an entity") + generate_parser.add_argument( + "type", help="Type of entity to generate", choices=ENTITY_TYPES + ) + generate_parser.add_argument( + "prompt", type=str, help="Prompt to generate the entity" + ) + generate_parser.add_argument("--room", type=str, help="Room the entity is in") + + # Set up the 'delete' command + delete_parser = subparsers.add_parser("delete", help="Delete an entity") + delete_parser.add_argument( + "type", help="Type of entity to delete", choices=ENTITY_TYPES + ) + delete_parser.add_argument("entity", type=str, help="Entity to delete") + + # Set up the 'update' command + update_parser = subparsers.add_parser("update", help="Update an entity") + update_parser.add_argument( + "type", help="Type of entity to update", choices=ENTITY_TYPES + ) + update_parser.add_argument("entity", type=str, help="Entity to update") + update_parser.add_argument("--backstory", type=str, help="Backstory of the entity") + update_parser.add_argument( + "--description", type=str, help="Description of the entity" + ) + + return parser.parse_args() + + +def load_world(state_file, world_file) -> Tuple[World, WorldState | None]: + systems = [] + + if state_file and path.exists(state_file): + with open(state_file, "r") as f: + state = WorldState(**load_yaml(f)) + + load_or_initialize_system_data(world_file, systems, state.world) + + return (state.world, state) + + if world_file and path.exists(world_file): + with open(world_file, "r") as f: + world = World(**load_yaml(f)) + + load_or_initialize_system_data(world_file, systems, world) + + return (world, None) + + raise ValueError("No state or world file found") + + +def save_world(state_file, world_file, world: World, state: WorldState | None): + if state: + print(f"Saving world {world.name} to {state_file}") + return + + with open(state_file, "w") as f: + save_yaml(f, dump_model(WorldState, state)) + else: + print(f"Saving world {world.name} to {world_file}") + return + + with open(world_file, "w") as f: + save_yaml(f, dump_model(World, world)) + + +def command_list(args): + print(f"Listing {args.type}s") + world, _ = load_world(args.state, args.world) + print(world.name) + + if args.type == "room": + for room in list_rooms(world): + print(room.name) + + if args.type == "portal": + for portal in list_portals(world): + print(portal.name) + + if args.type == "item": + for item in list_items( + world, include_character_inventory=True, include_item_inventory=True + ): + print(item.name) + + if args.type == "character": + for character in list_characters(world): + print(character.name) + + +def command_describe(args): + print(f"Describing {args.entity}") + world, _ = load_world(args.state, args.world) + print(world.name) + + if args.type == "room": + room = find_room(world, args.entity) + if not room: + print(f"Room {args.entity} not found") + else: + print(describe_entity(room)) + + if args.type == "portal": + portal = find_portal(world, args.entity) + if not portal: + print(f"Portal {args.entity} not found") + else: + print(describe_entity(portal)) + + if args.type == "item": + item = find_item( + world, + args.entity, + include_character_inventory=True, + include_item_inventory=True, + ) + if not item: + print(f"Item {args.entity} not found") + else: + print(describe_entity(item)) + + if args.type == "character": + character = find_character(world, args.entity) + if not character: + print(f"Character {args.entity} not found") + else: + print(describe_entity(character)) + + +def command_create(args): + print(f"Create {args.type} named {args.name}") + world, state = load_world(args.state, args.world) + print(world.name) + + # TODO: Create the entity + + save_world(args.state, args.world, world, state) + + +def command_generate(args): + print(f"Generate {args.type} with prompt: {args.prompt}") + world, state = load_world(args.state, args.world) + print(world.name) + + dungeon_master = get_dungeon_master() + systems = get_game_systems() + + # TODO: Generate the entity + if args.type == "room": + room = generate_room(dungeon_master, world, systems) + world.rooms.append(room) + + if args.type == "portal": + portal = generate_portals(dungeon_master, world, "TODO", "TODO", systems) + # TODO: Add portal to room and generate reverse portal from destination room + + if args.type == "item": + item = generate_item(dungeon_master, world, systems) + # TODO: Add item to room or character inventory + + if args.type == "character": + character = generate_character( + dungeon_master, world, systems, "TODO", args.prompt + ) + # TODO: Add character to room + + save_world(args.state, args.world, world, state) + + +def command_delete(args): + print(f"Delete {args.entity}") + world, state = load_world(args.state, args.world) + print(world.name) + + # TODO: Delete the entity + + save_world(args.state, args.world, world, state) + + +def command_update(args): + print(f"Update {args.entity}") + world, state = load_world(args.state, args.world) + print(world.name) + + if args.type == "room": + room = find_room(world, args.entity) + if not room: + print(f"Room {args.entity} not found") + else: + print(describe_entity(room)) + + if args.type == "portal": + portal = find_portal(world, args.entity) + if not portal: + print(f"Portal {args.entity} not found") + else: + print(describe_entity(portal)) + + if args.type == "item": + item = find_item( + world, + args.entity, + include_character_inventory=True, + include_item_inventory=True, + ) + if not item: + print(f"Item {args.entity} not found") + else: + print(describe_entity(item)) + + if args.type == "character": + character = find_character(world, args.entity) + if not character: + print(f"Character {args.entity} not found") + else: + if args.backstory: + character.backstory = args.backstory + + if args.description: + character.description = args.description + + print(describe_entity(character)) + + save_world(args.state, args.world, world, state) + + +COMMAND_TABLE = { + "list": command_list, + "describe": command_describe, + "create": command_create, + "generate": command_generate, + "delete": command_delete, + "update": command_update, +} + + +def main(): + args = parse_args() + print(args) + + # load game systems before executing commands + systems: List[GameSystem] = [] + for system_name in args.systems or []: + print(f"loading extra systems from {system_name}") + module_systems = load_plugin(system_name) + print(f"loaded extra systems: {module_systems}") + systems.extend(module_systems) + + set_game_systems(systems) + + command = COMMAND_TABLE[args.command] + command(args) + + +if __name__ == "__main__": + main() diff --git a/taleweave/generate.py b/taleweave/generate.py index 56707dc..c5ef13d 100644 --- a/taleweave/generate.py +++ b/taleweave/generate.py @@ -102,6 +102,7 @@ def generate_room( agent: Agent, world: World, systems: List[GameSystem], + additional_prompt: str = "", current_room: int | None = None, total_rooms: int | None = None, ) -> Room: @@ -111,6 +112,7 @@ def generate_room( agent, get_prompt("world_generate_room_name"), context={ + "additional_prompt": additional_prompt, "world_theme": world.theme, "existing_rooms": existing_rooms, "current_room": current_room, @@ -121,7 +123,13 @@ def generate_room( ) broadcast_generated(format_prompt("world_generate_room_broadcast_room", name=name)) - desc = agent(get_prompt("world_generate_room_description"), name=name) + desc = agent( + format_prompt( + "world_generate_room_description", + name=name, + additional_prompt=additional_prompt, + ) + ) actions = {} room = Room(name=name, description=desc, items=[], characters=[], actions=actions) @@ -581,12 +589,6 @@ def generate_world( world = World(name=name, rooms=[], theme=theme, order=[]) set_current_world(world) - # initialize the systems - for system in systems: - if system.initialize: - data = system.initialize(world) - set_system_data(system.name, data) - # generate the rooms for i in range(room_count): try: @@ -603,6 +605,12 @@ def generate_world( # generate portals to link the rooms together link_rooms(agent, world, systems) + # initialize the systems + for system in systems: + if system.initialize: + data = system.initialize(world) + set_system_data(system.name, data) + # ensure characters act in a stable order world.order = [ character.name for room in world.rooms for character in room.characters diff --git a/taleweave/main.py b/taleweave/main.py index b87d2e1..e9f318c 100644 --- a/taleweave/main.py +++ b/taleweave/main.py @@ -212,10 +212,12 @@ def load_prompt_library(args) -> None: return None -def load_or_initialize_system_data(args, systems: List[GameSystem], world: World): +def load_or_initialize_system_data( + world_path: str, systems: List[GameSystem], world: World +): for system in systems: if system.data: - system_data_file = f"{args.world}.{system.name}.json" + system_data_file = f"{world_path}.{system.name}.json" if path.exists(system_data_file): logger.info(f"loading system data from {system_data_file}") @@ -273,7 +275,7 @@ def load_or_generate_world( state = WorldState(**load_yaml(f)) set_current_turn(state.turn) - load_or_initialize_system_data(args, systems, state.world) + load_or_initialize_system_data(args.world, systems, state.world) memory = state.memory turn = state.turn @@ -283,7 +285,7 @@ def load_or_generate_world( with open(world_file, "r") as f: world = World(**load_yaml(f)) - load_or_initialize_system_data(args, systems, world) + load_or_initialize_system_data(args.world, systems, world) else: logger.info(f"generating a new world using theme: {world_prompt.theme}") world = generate_world( @@ -293,10 +295,11 @@ def load_or_generate_world( systems, room_count=args.rooms, ) - load_or_initialize_system_data(args, systems, world) + load_or_initialize_system_data(args.world, systems, world) - save_world(world, world_file) - save_system_data(args, systems) + # TODO: check if there have been any changes before saving + save_world(world, world_file) + save_system_data(args, systems) new_rooms = [] for i in range(add_rooms): diff --git a/taleweave/models/config.py b/taleweave/models/config.py index f780b8b..003c155 100644 --- a/taleweave/models/config.py +++ b/taleweave/models/config.py @@ -33,10 +33,18 @@ class RenderConfig: steps: int | IntRange +@dataclass +class WebsocketServerSSLConfig: + cert: str + key: str | None = None + password: str | None = None + + @dataclass class WebsocketServerConfig: host: str port: int + ssl: WebsocketServerSSLConfig | None = None @dataclass diff --git a/taleweave/server/websocket.py b/taleweave/server/websocket.py index 943729d..fc69088 100644 --- a/taleweave/server/websocket.py +++ b/taleweave/server/websocket.py @@ -324,8 +324,22 @@ def run_sockets(): async def server_main(): config = get_game_config() + ssl_context = None + if config.server.websocket.ssl: + from ssl import PROTOCOL_TLS_SERVER, SSLContext + + ssl_context = SSLContext(PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain( + config.server.websocket.ssl.cert, + keyfile=config.server.websocket.ssl.key, + password=config.server.websocket.ssl.password, + ) + async with websockets.serve( - handler, config.server.websocket.host, config.server.websocket.port + handler, + config.server.websocket.host, + config.server.websocket.port, + ssl=ssl_context, ): logger.info("websocket server started") await asyncio.Future() # run forever diff --git a/taleweave/simulate.py b/taleweave/simulate.py index 83cd758..f66ea9e 100644 --- a/taleweave/simulate.py +++ b/taleweave/simulate.py @@ -103,6 +103,10 @@ def result_parser(value, **kwargs): # trim suffixes that are used elsewhere value = value.removesuffix("END").strip() + # fix the "action_ move" whitespace issue + if '"action_ ' in value: + value = value.replace('"action_ ', '"action_') + # fix unbalanced curly braces if value.startswith("{") and not value.endswith("}"): open_count = value.count("{") @@ -124,12 +128,35 @@ def result_parser(value, **kwargs): broadcast(event) return result - except ToolError: - raise ActionError( - format_prompt( - "world_simulate_character_action_error_json", actions=action_names + except ToolError as e: + e_str = str(e) + if e_str and "Error running tool" in e_str: + # extract the tool name and rest of the message from the error + # the format is: "Error running tool: : " + action_name, message = e_str.split(":", 1) + action_name = action_name.removeprefix("Error running tool").strip() + message = message.strip() + raise ActionError( + format_prompt( + "world_simulate_character_action_error_action", + action=action_name, + message=message, + ) + ) + elif e_str and "Unknown tool" in e_str: + raise ActionError( + format_prompt( + "world_simulate_character_action_error_unknown_tool", + actions=action_names, + ) + ) + else: + raise ActionError( + format_prompt( + "world_simulate_character_action_error_json", + actions=action_names, + ) ) - ) # prompt and act logger.info("starting turn for character: %s", character.name) @@ -209,13 +236,35 @@ def prompt_character_planning( def result_parser(value, **kwargs): try: return function_result(value, **kwargs) - except ToolError: - raise ActionError( - format_prompt( - "world_simulate_character_planning_error_json", - actions=planner_toolbox.list_tools(), + except ToolError as e: + e_str = str(e) + if e_str and "Error running tool" in e_str: + # extract the tool name and rest of the message from the error + # the format is: "Error running tool: : " + action_name, message = e_str.split(":", 2) + action_name = action_name.removeprefix("Error running tool").strip() + message = message.strip() + raise ActionError( + format_prompt( + "world_simulate_character_planning_error_action", + action=action_name, + message=message, + ) + ) + elif e_str and "Unknown tool" in e_str: + raise ActionError( + format_prompt( + "world_simulate_character_planning_error_unknown_tool", + actions=planner_toolbox.list_tools(), + ) + ) + else: + raise ActionError( + format_prompt( + "world_simulate_character_planning_error_json", + actions=planner_toolbox.list_tools(), + ) ) - ) logger.info("starting planning for character: %s", character.name) _, condition_end, result_parser = make_keyword_condition( diff --git a/taleweave/systems/digest.py b/taleweave/systems/digest.py index e81e0e5..ed34e91 100644 --- a/taleweave/systems/digest.py +++ b/taleweave/systems/digest.py @@ -41,11 +41,11 @@ def create_move_digest( if not source_portal: raise ValueError(f"Could not find source portal for {destination_portal.name}") - mode = "self" if (event.character == active_character) else "other" - mood = "enter" if (destination_room == active_room) else "exit" + character_mode = "self" if (event.character == active_character) else "other" + direction_mode = "enter" if (destination_room == active_room) else "exit" message = format_str( - f"digest_move_{mode}_{mood}", + f"digest_move_{character_mode}_{direction_mode}", destination_portal=destination_portal, destination_room=destination_room, direction=direction, @@ -67,10 +67,15 @@ def create_turn_digest( if isinstance(event, ActionEvent): # special handling for move actions if event.action == "action_move": - message = create_move_digest( - world, active_room, active_character, event - ) - messages.append(message) + try: + message = create_move_digest( + world, active_room, active_character, event + ) + messages.append(message) + except Exception: + logger.exception( + "error formatting digest for move event: %s", event + ) elif event.character == active_character or event.room == active_room: prompt_key = f"digest_{event.action}" if prompt_key in library.prompts: diff --git a/taleweave/utils/template.py b/taleweave/utils/template.py index e5c6ebe..1640ddd 100644 --- a/taleweave/utils/template.py +++ b/taleweave/utils/template.py @@ -29,6 +29,9 @@ def the_prefix(name: str) -> str: def punctuate(name: str, suffix: str = ".") -> str: + if len(name) == 0: + return name + if name[-1] in [".", "!", "?", suffix]: return name