From e1c72c3717cd4881318a24a5a515a97a3d1d7856 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 7 Jun 2024 21:18:56 -0500 Subject: [PATCH] dry model dumping, improve DM prompt --- prompts/llama-base.yml | 8 ++++---- taleweave/main.py | 24 ++++++++++++++---------- taleweave/models/base.py | 10 ++++++++++ taleweave/server/websocket.py | 4 ++-- taleweave/state.py | 7 +++---- taleweave/utils/systems.py | 5 ++--- taleweave/utils/template.py | 2 +- 7 files changed, 36 insertions(+), 24 deletions(-) diff --git a/prompts/llama-base.yml b/prompts/llama-base.yml index c04d4b4..7e7e09f 100644 --- a/prompts/llama-base.yml +++ b/prompts/llama-base.yml @@ -209,10 +209,10 @@ prompts: # world generation world_generate_dungeon_master: | - You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme}}. Be - creative and original, creating a world that is visually detailed and full of curious details. Do not repeat - yourself unless you are given the same prompt with the same characters, room, and context. {{flavor}}. The theme of - the world must be: {{theme}}. + You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme | punctuate}} Be + creative and original, creating a world that is visually detailed, consistent, and plausible within the context of + the setting. Do not repeat yourself unless you are given the same prompt with the same characters, room, and + context. {{flavor | punctuate}} The theme of the world must be: {{theme | punctuate}} world_generate_world_broadcast_theme: | Generating a {{theme}} with {{room_count}} rooms diff --git a/taleweave/main.py b/taleweave/main.py index 402afa2..b87d2e1 100644 --- a/taleweave/main.py +++ b/taleweave/main.py @@ -10,9 +10,10 @@ from packit.memory import make_limited_memory from packit.utils import logger_with_colors +# configure logging +# this is the only taleweave import allowed before the logger has been created from taleweave.utils.file import load_yaml -# configure logging LOG_PATH = "logging.json" try: if path.exists(LOG_PATH): @@ -30,6 +31,15 @@ load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True) +# start the debugger, if needed +if environ.get("DEBUG", "false").lower() == "true": + import debugpy + + debugpy.listen(5679) + logger.info("waiting for debugger to attach...") + debugpy.wait_for_client() + + if True: from taleweave.context import ( get_prompt_library, @@ -52,14 +62,6 @@ from taleweave.state import create_agents, save_world, save_world_state from taleweave.utils.template import format_prompt -# start the debugger, if needed -if environ.get("DEBUG", "false").lower() == "true": - import debugpy - - debugpy.listen(5679) - logger.info("waiting for debugger to attach...") - debugpy.wait_for_client() - def int_or_inf(value: str) -> float | int: if value == "inf": @@ -415,7 +417,9 @@ def snapshot_system(world: World, turn: int, data: None = None) -> None: world_builder = Agent( "dungeon master", format_prompt( - "world_generate_dungeon_master", flavor=args.flavor, theme=world.theme + "world_generate_dungeon_master", + flavor=world_prompt.flavor, + theme=world_prompt.theme, ), {}, llm, diff --git a/taleweave/models/base.py b/taleweave/models/base.py index eb6ad1d..98f2cfe 100644 --- a/taleweave/models/base.py +++ b/taleweave/models/base.py @@ -1,6 +1,8 @@ from typing import TYPE_CHECKING, Dict from uuid import uuid4 +from pydantic import RootModel + if TYPE_CHECKING: from dataclasses import dataclass else: @@ -16,6 +18,14 @@ class BaseModel: id: str +def dump_model(cls, model: BaseModel) -> Dict: + return RootModel[cls](model).model_dump() + + +def dump_model_json(cls, model: BaseModel) -> str: + return RootModel[cls](model).model_dump_json(indent=2) + + def uuid() -> str: return uuid4().hex diff --git a/taleweave/server/websocket.py b/taleweave/server/websocket.py index 9e7bc05..943729d 100644 --- a/taleweave/server/websocket.py +++ b/taleweave/server/websocket.py @@ -10,7 +10,6 @@ import websockets from PIL import Image -from pydantic import RootModel from taleweave.context import ( broadcast, @@ -20,6 +19,7 @@ set_character_agent, subscribe, ) +from taleweave.models.base import dump_model from taleweave.models.config import WebsocketServerConfig from taleweave.models.entity import World, WorldEntity from taleweave.models.event import ( @@ -343,7 +343,7 @@ def server_system(world: World, turn: int, data: Any | None = None): def server_event(event: GameEvent): - json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump() + json_event: Dict[str, Any] = dump_model(event.__class__, event) json_event.update( { "id": event.id, diff --git a/taleweave/state.py b/taleweave/state.py index fb9b927..7a461b4 100644 --- a/taleweave/state.py +++ b/taleweave/state.py @@ -5,13 +5,13 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from packit.agent import Agent, agent_easy_connect -from pydantic import RootModel from taleweave.context import ( get_all_character_agents, get_game_config, set_character_agent, ) +from taleweave.models.base import dump_model, dump_model_json from taleweave.models.entity import World from taleweave.player import LocalPlayer from taleweave.utils.template import format_prompt @@ -58,10 +58,9 @@ def graph_world(world: World, turn: int): def snapshot_world(world: World, turn: int): # save the world itself, along with the turn number and the memory of each agent - json_world = RootModel[World](world).model_dump() + json_world = dump_model(World, world) json_memory = {} - for character, agent in get_all_character_agents(): json_memory[character.name] = list(agent.memory or []) @@ -97,7 +96,7 @@ def restore_memory( def save_world(world, filename): with open(filename, "w") as f: - json_world = RootModel[World](world).model_dump_json(indent=2) + json_world = dump_model_json(World, world) f.write(json_world) diff --git a/taleweave/utils/systems.py b/taleweave/utils/systems.py index 1b006d8..e312a30 100644 --- a/taleweave/utils/systems.py +++ b/taleweave/utils/systems.py @@ -1,5 +1,4 @@ -from pydantic import RootModel - +from taleweave.models.base import dump_model from taleweave.utils.file import load_yaml, save_yaml @@ -11,6 +10,6 @@ def load_system_data(cls, file): def save_system_data(cls, file, model): - data = RootModel[cls](model).model_dump() + data = dump_model(cls, model) with open(file, "w") as f: save_yaml(f, data) diff --git a/taleweave/utils/template.py b/taleweave/utils/template.py index 0ad65ac..e5c6ebe 100644 --- a/taleweave/utils/template.py +++ b/taleweave/utils/template.py @@ -28,7 +28,7 @@ def the_prefix(name: str) -> str: return f"the {name}" -def punctuate(name: str, suffix: str) -> str: +def punctuate(name: str, suffix: str = ".") -> str: if name[-1] in [".", "!", "?", suffix]: return name