Skip to content

Commit

Permalink
make sure config is used consistently, start adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jun 1, 2024
1 parent be37d58 commit ef8529e
Show file tree
Hide file tree
Showing 16 changed files with 232 additions and 69 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ venv/
client/node_modules/
client/out/
taleweave/custom_*
.coverage
coverage.*
80 changes: 80 additions & 0 deletions docs/cli.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# TaleWeave AI Command Line Options

The following command line arguments are available when launching the TaleWeave AI engine:

- **--actions**
- **Type:** String
- **Description:** Additional actions to include in the simulation. Note: More than one argument is allowed.

- **--add-rooms**
- **Type:** Integer
- **Default:** 0
- **Description:** The number of new rooms to generate before starting the simulation.

- **--config**
- **Type:** String
- **Description:** The file to load additional configuration from.

- **--discord**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Run the simulation in a Discord bot.

- **--flavor**
- **Type:** String
- **Default:** ""
- **Description:** Additional flavor text for the generated world.

- **--optional-actions**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Include optional actions in the simulation.

- **--player**
- **Type:** String
- **Description:** The name of the character to play as.

- **--prompts**
- **Type:** String
- **Description:** The file to load game prompts from. Note: More than one argument is allowed.

- **--render**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Run the render thread.

- **--render-generated**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Render entities as they are generated.

- **--rooms**
- **Type:** Integer
- **Description:** The number of rooms to generate.

- **--server**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Run the websocket server.

- **--state**
- **Type:** String
- **Description:** The file to save the world state to. Defaults to `$world.state.json` if not set.

- **--turns**
- **Type:** Integer or "inf"
- **Default:** 10
- **Description:** The number of simulation turns to run.

- **--systems**
- **Type:** String
- **Description:** Extra systems to run in the simulation. Note: More than one argument is allowed.

- **--theme**
- **Type:** String
- **Default:** "fantasy"
- **Description:** The theme of the generated world.

- **--world**
- **Type:** String
- **Default:** "world"
- **Description:** The file to save the generated world to.

- **--world-template**
- **Type:** String
- **Description:** The template file to load the world prompt from.
18 changes: 11 additions & 7 deletions taleweave/actions/planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
action_context,
get_agent_for_character,
get_current_turn,
get_game_config,
get_prompt,
)
from taleweave.errors import ActionError
from taleweave.models.config import DEFAULT_CONFIG
from taleweave.models.planning import CalendarEvent
from taleweave.utils.planning import get_recent_notes
from taleweave.utils.prompt import format_prompt

character_config = DEFAULT_CONFIG.world.character


def take_note(fact: str):
"""
Expand All @@ -22,11 +20,13 @@ def take_note(fact: str):
fact: The fact to remember.
"""

config = get_game_config()

with action_context() as (_, action_character):
if fact in action_character.planner.notes:
raise ActionError(get_prompt("action_take_note_error_duplicate"))

if len(action_character.planner.notes) >= character_config.note_limit:
if len(action_character.planner.notes) >= config.world.character.note_limit:
raise ActionError(get_prompt("action_take_note_error_limit"))

action_character.planner.notes.append(fact)
Expand Down Expand Up @@ -103,6 +103,8 @@ def summarize_notes(limit: int) -> str:
limit: The maximum number of notes to keep.
"""

config = get_game_config()

with action_context() as (_, action_character):
notes = action_character.planner.notes
if len(notes) == 0:
Expand All @@ -120,11 +122,11 @@ def summarize_notes(limit: int) -> str:
)

new_notes = [note.strip() for note in summary.split("\n") if note.strip()]
if len(new_notes) > character_config.note_limit:
if len(new_notes) > config.world.character.note_limit:
raise ActionError(
format_prompt(
"action_summarize_notes_error_limit",
limit=character_config.note_limit,
limit=config.world.character.note_limit,
)
)

Expand Down Expand Up @@ -165,7 +167,9 @@ def check_calendar(count: int):
count: The number of upcoming events to read. 5 is usually a good number.
"""

count = min(count, character_config.event_limit)
config = get_game_config()

count = min(count, config.world.character.event_limit)
current_turn = get_current_turn()

with action_context() as (_, action_character):
Expand Down
24 changes: 13 additions & 11 deletions taleweave/bot/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
broadcast,
get_character_agent_for_name,
get_current_world,
get_game_config,
set_character_agent,
subscribe,
)
from taleweave.models.config import DEFAULT_CONFIG, DiscordBotConfig
from taleweave.models.config import DiscordBotConfig
from taleweave.models.event import (
ActionEvent,
GameEvent,
Expand All @@ -38,7 +39,6 @@

logger = getLogger(__name__)
client = None
bot_config: DiscordBotConfig = DEFAULT_CONFIG.bot.discord

active_tasks = set()
event_messages: Dict[int, str | GameEvent] = {}
Expand Down Expand Up @@ -78,29 +78,32 @@ async def on_message(self, message):
if message.author == self.user:
return

config = get_game_config()
author = message.author
channel = message.channel
user_name = author.name # include nick

if message.content.startswith(
bot_config.command_prefix + bot_config.name_command
config.bot.discord.command_prefix + config.bot.discord.name_command
):
world = get_current_world()
if world:
world_message = format_prompt(
"discord_world_active", bot_name=bot_config.name_title, world=world
"discord_world_active",
bot_name=config.bot.discord.name_title,
world=world,
)
else:
world_message = format_prompt(
"discord_world_none", bot_name=bot_config.name_title
"discord_world_none", bot_name=config.bot.discord.name_title
)

await message.channel.send(world_message)
return

if message.content.startswith("!help"):
await message.channel.send(
format_prompt("discord_help", bot_name=bot_config.name_command)
format_prompt("discord_help", bot_name=config.bot.discord.name_command)
)
return

Expand Down Expand Up @@ -172,14 +175,11 @@ def prompt_player(event: PromptEvent):


def launch_bot(config: DiscordBotConfig):
global bot_config
global client

bot_config = config

# message contents need to be enabled for multi-server bots
intents = Intents.default()
if bot_config.content_intent:
if config.content_intent:
intents.message_content = True

client = AdventureClient(intents=intents)
Expand Down Expand Up @@ -246,12 +246,14 @@ def get_active_channels():
if not client:
return []

config = get_game_config()

# return client.private_channels
return [
channel
for guild in client.guilds
for channel in guild.text_channels
if channel.name in bot_config.channels
if channel.name in config.bot.discord.channels
]


Expand Down
15 changes: 15 additions & 0 deletions taleweave/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pyee.base import EventEmitter

from taleweave.game_system import GameSystem
from taleweave.models.config import DEFAULT_CONFIG, Config
from taleweave.models.entity import Character, Room, World
from taleweave.models.event import GameEvent, StatusEvent
from taleweave.models.prompt import PromptLibrary
Expand All @@ -34,6 +35,7 @@

# game context
event_emitter = EventEmitter()
game_config: Config = DEFAULT_CONFIG
game_systems: List[GameSystem] = []
prompt_library: PromptLibrary = PromptLibrary(prompts={})
system_data: Dict[str, Any] = {}
Expand Down Expand Up @@ -160,6 +162,10 @@ def get_dungeon_master() -> Agent:
return dungeon_master


def get_game_config() -> Config:
return game_config


def get_game_systems() -> List[GameSystem]:
return game_systems

Expand All @@ -172,6 +178,10 @@ def get_prompt_library() -> PromptLibrary:
return prompt_library


def get_system_config(system: str) -> Any | None:
return game_config.systems.data.get(system)


def get_system_data(system: str) -> Any | None:
return system_data.get(system)

Expand Down Expand Up @@ -209,6 +219,11 @@ def set_dungeon_master(agent):
dungeon_master = agent


def set_game_config(config: Config):
global game_config
game_config = config


def set_game_systems(systems: Sequence[GameSystem]):
global game_systems
game_systems = list(systems)
Expand Down
19 changes: 16 additions & 3 deletions taleweave/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
from packit.results import enum_result, int_result
from packit.utils import could_be_json

from taleweave.context import broadcast, get_prompt, set_current_world, set_system_data
from taleweave.context import (
broadcast,
get_game_config,
get_prompt,
set_current_world,
set_system_data,
)
from taleweave.game_system import GameSystem
from taleweave.models.config import DEFAULT_CONFIG, WorldConfig
from taleweave.models.effect import (
EffectPattern,
FloatEffectPattern,
Expand All @@ -33,7 +38,10 @@

logger = getLogger(__name__)

world_config: WorldConfig = DEFAULT_CONFIG.world

def get_world_config():
config = get_game_config()
return config.world


def duplicate_name_parser(existing_names: List[str]):
Expand Down Expand Up @@ -112,6 +120,7 @@ def generate_room(
actions = {}
room = Room(name=name, description=desc, items=[], characters=[], actions=actions)

world_config = get_world_config()
item_count = resolve_int_range(world_config.size.room_items) or 0
broadcast_generated(
format_prompt(
Expand Down Expand Up @@ -276,6 +285,7 @@ def generate_item(
item = Item(name=name, description=desc, actions=actions)
generate_system_attributes(agent, world, item, systems)

world_config = get_world_config()
effect_count = resolve_int_range(world_config.size.item_effects) or 0
broadcast_generated(
message=format_prompt(
Expand Down Expand Up @@ -343,6 +353,7 @@ def generate_character(
generate_system_attributes(agent, world, character, systems)

# generate the character's inventory
world_config = get_world_config()
item_count = resolve_int_range(world_config.size.character_items) or 0
broadcast_generated(
message=format_prompt(
Expand Down Expand Up @@ -499,6 +510,7 @@ def link_rooms(
rooms: List[Room] | None = None,
) -> None:
rooms = rooms or world.rooms
world_config = get_world_config()

for room in rooms:
num_portals = resolve_int_range(world_config.size.portals) or 0
Expand Down Expand Up @@ -550,6 +562,7 @@ def generate_world(
systems: List[GameSystem],
room_count: int | None = None,
) -> World:
world_config = get_world_config()
room_count = room_count or resolve_int_range(world_config.size.rooms) or 0

broadcast_generated(message=format_prompt("world_generate_world_broadcast_theme"))
Expand Down
4 changes: 3 additions & 1 deletion taleweave/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@

logger = logger_with_colors(__name__) # , level="DEBUG")

load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True)
load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)

if True:
from taleweave.context import (
get_prompt_library,
get_system_data,
set_current_turn,
set_dungeon_master,
set_game_config,
set_system_data,
subscribe,
)
Expand Down Expand Up @@ -312,6 +313,7 @@ def main():
if args.config:
with open(args.config, "r") as f:
config = Config(**load_yaml(f))
set_game_config(config)
else:
config = DEFAULT_CONFIG

Expand Down
Loading

0 comments on commit ef8529e

Please sign in to comment.