Skip to content

Commit

Permalink
dry model dumping, improve DM prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jun 8, 2024
1 parent 9f435ee commit e1c72c3
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 24 deletions.
8 changes: 4 additions & 4 deletions prompts/llama-base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,10 @@ prompts:
# world generation
world_generate_dungeon_master: |
You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme}}. Be
creative and original, creating a world that is visually detailed and full of curious details. Do not repeat
yourself unless you are given the same prompt with the same characters, room, and context. {{flavor}}. The theme of
the world must be: {{theme}}.
You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme | punctuate}} Be
creative and original, creating a world that is visually detailed, consistent, and plausible within the context of
the setting. Do not repeat yourself unless you are given the same prompt with the same characters, room, and
context. {{flavor | punctuate}} The theme of the world must be: {{theme | punctuate}}
world_generate_world_broadcast_theme: |
Generating a {{theme}} with {{room_count}} rooms
Expand Down
24 changes: 14 additions & 10 deletions taleweave/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from packit.memory import make_limited_memory
from packit.utils import logger_with_colors

# configure logging
# this is the only taleweave import allowed before the logger has been created
from taleweave.utils.file import load_yaml

# configure logging
LOG_PATH = "logging.json"
try:
if path.exists(LOG_PATH):
Expand All @@ -30,6 +31,15 @@

load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)

# start the debugger, if needed
if environ.get("DEBUG", "false").lower() == "true":
import debugpy

debugpy.listen(5679)
logger.info("waiting for debugger to attach...")
debugpy.wait_for_client()


if True:
from taleweave.context import (
get_prompt_library,
Expand All @@ -52,14 +62,6 @@
from taleweave.state import create_agents, save_world, save_world_state
from taleweave.utils.template import format_prompt

# start the debugger, if needed
if environ.get("DEBUG", "false").lower() == "true":
import debugpy

debugpy.listen(5679)
logger.info("waiting for debugger to attach...")
debugpy.wait_for_client()


def int_or_inf(value: str) -> float | int:
if value == "inf":
Expand Down Expand Up @@ -415,7 +417,9 @@ def snapshot_system(world: World, turn: int, data: None = None) -> None:
world_builder = Agent(
"dungeon master",
format_prompt(
"world_generate_dungeon_master", flavor=args.flavor, theme=world.theme
"world_generate_dungeon_master",
flavor=world_prompt.flavor,
theme=world_prompt.theme,
),
{},
llm,
Expand Down
10 changes: 10 additions & 0 deletions taleweave/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import TYPE_CHECKING, Dict
from uuid import uuid4

from pydantic import RootModel

if TYPE_CHECKING:
from dataclasses import dataclass
else:
Expand All @@ -16,6 +18,14 @@ class BaseModel:
id: str


def dump_model(cls, model: BaseModel) -> Dict:
return RootModel[cls](model).model_dump()


def dump_model_json(cls, model: BaseModel) -> str:
return RootModel[cls](model).model_dump_json(indent=2)


def uuid() -> str:
return uuid4().hex

Expand Down
4 changes: 2 additions & 2 deletions taleweave/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import websockets
from PIL import Image
from pydantic import RootModel

from taleweave.context import (
broadcast,
Expand All @@ -20,6 +19,7 @@
set_character_agent,
subscribe,
)
from taleweave.models.base import dump_model
from taleweave.models.config import WebsocketServerConfig
from taleweave.models.entity import World, WorldEntity
from taleweave.models.event import (
Expand Down Expand Up @@ -343,7 +343,7 @@ def server_system(world: World, turn: int, data: Any | None = None):


def server_event(event: GameEvent):
json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump()
json_event: Dict[str, Any] = dump_model(event.__class__, event)
json_event.update(
{
"id": event.id,
Expand Down
7 changes: 3 additions & 4 deletions taleweave/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from packit.agent import Agent, agent_easy_connect
from pydantic import RootModel

from taleweave.context import (
get_all_character_agents,
get_game_config,
set_character_agent,
)
from taleweave.models.base import dump_model, dump_model_json
from taleweave.models.entity import World
from taleweave.player import LocalPlayer
from taleweave.utils.template import format_prompt
Expand Down Expand Up @@ -58,10 +58,9 @@ def graph_world(world: World, turn: int):

def snapshot_world(world: World, turn: int):
# save the world itself, along with the turn number and the memory of each agent
json_world = RootModel[World](world).model_dump()
json_world = dump_model(World, world)

json_memory = {}

for character, agent in get_all_character_agents():
json_memory[character.name] = list(agent.memory or [])

Expand Down Expand Up @@ -97,7 +96,7 @@ def restore_memory(

def save_world(world, filename):
with open(filename, "w") as f:
json_world = RootModel[World](world).model_dump_json(indent=2)
json_world = dump_model_json(World, world)
f.write(json_world)


Expand Down
5 changes: 2 additions & 3 deletions taleweave/utils/systems.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pydantic import RootModel

from taleweave.models.base import dump_model
from taleweave.utils.file import load_yaml, save_yaml


Expand All @@ -11,6 +10,6 @@ def load_system_data(cls, file):


def save_system_data(cls, file, model):
data = RootModel[cls](model).model_dump()
data = dump_model(cls, model)
with open(file, "w") as f:
save_yaml(f, data)
2 changes: 1 addition & 1 deletion taleweave/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def the_prefix(name: str) -> str:
return f"the {name}"


def punctuate(name: str, suffix: str) -> str:
def punctuate(name: str, suffix: str = ".") -> str:
if name[-1] in [".", "!", "?", suffix]:
return name

Expand Down

0 comments on commit e1c72c3

Please sign in to comment.