From e4b2fcbcd1e1559a9a0d2d39fe47e8c7b9e709d8 Mon Sep 17 00:00:00 2001 From: vcanaa Date: Fri, 19 May 2023 16:14:57 -0700 Subject: [PATCH] New towerfall client --- .gitignore | 1 - FollowCloseTargetCurriculum_episodes_20.json | 1 + bots/botquest.py | 6 +- bots/botrecorder.py | 4 +- common/__init__.py | 2 - common/controls.py | 17 +- common/entity.py | 6 +- common/gamereplay.py | 8 +- common/grid.py | 12 +- create_move_data.ipynb | 18 +- entity_envs/entity_base_env.py | 30 +-- entity_envs/entity_env.py | 14 +- entity_envs/predefined_envs.py | 17 +- envs/__init__.py | 1 - envs/base_env.py | 23 +- envs/blank_env.py | 16 +- envs/connection_provider.py | 240 ------------------- envs/curriculums.py | 14 +- envs/kill_enemy_objective.py | 16 +- envs/objectives.py | 12 +- envs/observations.py | 26 +- envs/predefined_envs.py | 14 +- evaluate_policy.py | 18 +- exper_rl_ppo_trainer_curr.py | 4 +- notebooks/train_move_data.ipynb | 10 +- run_simple_network_bot.py | 2 +- synchronization/__init__.py | 1 + {common => synchronization}/namedmutex.py | 0 tests/test_connection_provider.py | 138 ----------- tests/test_entity_env.py | 31 +-- tests/test_env.py | 22 +- tests/test_grid.py | 11 +- towerfall/__init__.py | 4 + {common => towerfall}/connection.py | 51 ++-- towerfall/towerfall.py | 210 ++++++++++++++++ train_kill_enemy_with_entity_env.py | 4 +- trainer/trainer.py | 32 +-- 37 files changed, 420 insertions(+), 616 deletions(-) create mode 100644 FollowCloseTargetCurriculum_episodes_20.json delete mode 100644 envs/connection_provider.py create mode 100644 synchronization/__init__.py rename {common => synchronization}/namedmutex.py (100%) delete mode 100644 tests/test_connection_provider.py create mode 100644 towerfall/__init__.py rename {common => towerfall}/connection.py (57%) create mode 100644 towerfall/towerfall.py diff --git a/.gitignore b/.gitignore index 64825e8..f04c3bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ __pycache__ .ipynb_checkpoints/ -.connection_provider/ .vscode/ config*.json diff --git a/FollowCloseTargetCurriculum_episodes_20.json b/FollowCloseTargetCurriculum_episodes_20.json new file mode 100644 index 0000000..5a0672e --- /dev/null +++ b/FollowCloseTargetCurriculum_episodes_20.json @@ -0,0 +1 @@ +[[{"x": 160, "y": 110, "__Vec2__": true}, {"x": 140, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 144.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 148.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 152.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 156.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 160.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 164.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 168.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 172.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 176.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 180.0, "y": 130, "__Vec2__": true}]] \ No newline at end of file diff --git a/bots/botquest.py b/bots/botquest.py index d2a08eb..0dd9699 100644 --- a/bots/botquest.py +++ b/bots/botquest.py @@ -6,7 +6,7 @@ from .bot import Bot -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple _HOST = "127.0.0.1" _PORT = 12024 @@ -44,14 +44,14 @@ def update(self): self.handle_update(game_state) - def handle_init(self, state: dict): + def handle_init(self, state: Dict[str, Any]): logging.info("handle_init") self.state_init = state random.seed(state['index']) self._connection.write('.') - def handle_scenario(self, state: dict): + def handle_scenario(self, state: Dict[str, Any]): logging.info("handle_scenario") self.state_scenario = state self.gv.set_scenario(state) diff --git a/bots/botrecorder.py b/bots/botrecorder.py index c781a29..5c91cb8 100644 --- a/bots/botrecorder.py +++ b/bots/botrecorder.py @@ -74,7 +74,7 @@ def update(self): self.handle_update(game_state) - def handle_init(self, state: dict): + def handle_init(self, state: Dict[str, Any]): logging.info("handle_init") if hasattr(self, 'replay'): dir_path = Path(os.path.join('replays', REPLAY_NAME)) @@ -88,7 +88,7 @@ def handle_init(self, state: dict): self.connection.write('.') - def handle_scenario(self, state: dict): + def handle_scenario(self, state: Dict[str, Any]): logging.info("handle_scenario") self.replay.handle_scenario(state) self.stateScenario = state diff --git a/common/__init__.py b/common/__init__.py index 272a34c..62db4b5 100644 --- a/common/__init__.py +++ b/common/__init__.py @@ -1,10 +1,8 @@ from .common import * -from .connection import * from .constants import * from .controls import * from .entity import * from .gamereplay import * from .grid import * from .logging_options import * -from .namedmutex import * from .pathing import * diff --git a/common/controls.py b/common/controls.py index 6934dc2..0eb64cb 100644 --- a/common/controls.py +++ b/common/controls.py @@ -1,20 +1,15 @@ -import sys import json import logging -import numpy as np - from math import atan2, pi -from threading import Thread, Lock - -from pyjoystick.sdl2 import Key, Joystick, run_event_loop - -from common import reply, Vec2 - -from .connection import Connection - +from threading import Lock, Thread from typing import Optional + +import numpy as np from numpy.typing import NDArray +from pyjoystick.sdl2 import Joystick, Key, run_event_loop +from common import Vec2, reply +from towerfall import Connection pi8 = pi / 8 diff --git a/common/entity.py b/common/entity.py index dd534cc..221a337 100644 --- a/common/entity.py +++ b/common/entity.py @@ -10,7 +10,7 @@ class Entity: - def __init__(self, e: dict): + def __init__(self, e: Dict[str, Any]): self.p: Vec2 = vec2_from_dict(e['pos']) self.v: Vec2 = vec2_from_dict(e['vel']) self.s: Vec2 = vec2_from_dict(e['size']) @@ -108,7 +108,7 @@ def div(self, f: float): self.y /= f -def vec2_from_dict(p: dict) -> Vec2: +def vec2_from_dict(p: Dict[str, Any]) -> Vec2: try: return Vec2(p['x'], p['y']) except KeyError: @@ -116,7 +116,7 @@ def vec2_from_dict(p: dict) -> Vec2: raise -def to_entities(entities: List[dict]) -> List[Entity]: +def to_entities(entities: List[Dict[str, Any]]) -> List[Entity]: result = [] for e in entities: result.append(Entity(e)) diff --git a/common/gamereplay.py b/common/gamereplay.py index 03a20f7..fa5c408 100644 --- a/common/gamereplay.py +++ b/common/gamereplay.py @@ -1,12 +1,12 @@ import json -from typing import List +from typing import Any, Dict, List class GameReplay: def __init__(self): - self.state_init: dict - self.state_scenario: dict - self.state_update: List[dict] = [] + self.state_init: Dict[str, Any] + self.state_scenario: Dict[str, Any] + self.state_update: List[Dict[str, Any]] = [] self.actions: List[str] = [] def handle_init(self, state): diff --git a/common/grid.py b/common/grid.py index f083b0a..142e7d1 100644 --- a/common/grid.py +++ b/common/grid.py @@ -1,14 +1,12 @@ import logging +from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import matplotlib.pyplot as plt - +import numpy as np from numpy.typing import NDArray -from .common import Entity, bounded, Vec2, grid_pos -from .constants import WIDTH, HEIGHT, HW, HH - -from typing import List, Tuple, Union, Optional +from .common import Entity, Vec2, bounded, grid_pos +from .constants import HEIGHT, HH, HW, WIDTH def plot_grid(grid: NDArray, name: str): @@ -113,7 +111,7 @@ class GridView(): def __init__(self, grid_factor: int): self.gf: int = grid_factor - def set_scenario(self, game_state: dict): + def set_scenario(self, game_state: Dict[str, Any]): # logging.info(f'Setting scenario in GridView {game_state["grid"]}') self.fixed_grid10 = np.array(game_state['grid']) self.csize: int = int(game_state['cellSize']) diff --git a/create_move_data.ipynb b/create_move_data.ipynb index 9d44a1f..b171b7f 100644 --- a/create_move_data.ipynb +++ b/create_move_data.ipynb @@ -19,11 +19,9 @@ "source": [ "import os\n", "\n", - "from pathlib import Path\n", - "\n", "from common import GameReplay, Entity, GridView, Controls, to_entities\n", "\n", - "from typing import List\n", + "from typing import Dict, List\n", "from numpy.typing import NDArray\n", "\n", "def get_player(entities: List[Entity], index) -> Entity:\n", @@ -35,7 +33,7 @@ " return e\n", " raise Exception('Player not present: ()'.format(index))\n", "\n", - "def match_shape(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:\n", + "def match_shape(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool:\n", " if len(data) == 0:\n", " return True\n", " \n", @@ -48,7 +46,7 @@ " return False\n", " return True\n", " \n", - "def should_extend_input(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:\n", + "def should_extend_input(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool:\n", " if not match_shape(data, entry):\n", " return False\n", " \n", @@ -62,7 +60,7 @@ " # return False\n", " return True\n", "\n", - "def should_extend_output(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:\n", + "def should_extend_output(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool:\n", " if not match_shape(data, entry):\n", " return False\n", " \n", @@ -70,7 +68,7 @@ " return False\n", " return True\n", "\n", - "def is_same_as_last(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:\n", + "def is_same_as_last(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool:\n", " for k, v in entry.items():\n", " if k not in data:\n", " return False\n", @@ -78,7 +76,7 @@ " return False \n", " return True\n", " \n", - "def extend_data(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool: \n", + "def extend_data(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool: \n", " for k, v in entry.items():\n", " if k not in data:\n", " data[k] = [v]\n", @@ -87,7 +85,7 @@ " \n", " return True\n", "\n", - "def process_replay(filepath, inputs: dict[str, List[NDArray]], outputs: dict[str, List[NDArray]]):\n", + "def process_replay(filepath, inputs: Dict[str, List[NDArray]], outputs: Dict[str, List[NDArray]]):\n", " replay = GameReplay()\n", " replay.load(filepath)\n", " print('actions:', len(replay.actions))\n", @@ -150,7 +148,7 @@ "\n", "import shutil\n", "\n", - "def save_data(name: str, type: str, data: dict[str, List[NDArray]]):\n", + "def save_data(name: str, type: str, data: Dict[str, List[NDArray]]):\n", " dir_path = os.path.join('data', name, type)\n", " if os.path.exists(dir_path):\n", " shutil.rmtree(dir_path)\n", diff --git a/entity_envs/entity_base_env.py b/entity_envs/entity_base_env.py index 7dd0c66..5bf7b2a 100644 --- a/entity_envs/entity_base_env.py +++ b/entity_envs/entity_base_env.py @@ -8,28 +8,24 @@ from common.constants import DASH, DOWN, JUMP, LEFT, RIGHT, SHOOT, UP from common.entity import Entity, to_entities -from envs.connection_provider import TowerfallProcess, TowerfallProcessProvider +from towerfall import Towerfall class TowerfallEntityEnv(Environment): def __init__(self, - towerfall: Optional[TowerfallProcess] = None, + towerfall: Optional[Towerfall] = None, record_path: Optional[str] = None, verbose: int = 0): logging.info('Initializing TowerfallEntityEnv') self.verbose = verbose - if towerfall: - self.connection = towerfall.join(timeout=5, verbose=self.verbose) - self.towerfall = towerfall - else: - self.connection, self.towerfall = TowerfallProcessProvider().join_new( - fastrun=True, - # nographics=True, - config=dict( - mode='sandbox', - level='3', - agents=[dict(type='remote', team='blue', archer='green')]), - timeout=5, - verbose=self.verbose) + self.towerfall = towerfall if towerfall else Towerfall(fastrun=True, + # nographics=True, + config=dict( + mode='sandbox', + level='3', + agents=[dict(type='remote', team='blue', archer='green')]), + timeout=5, + verbose=self.verbose) + self.connection = self.towerfall.join() self.connection.record_path = record_path self._draw_elems = [] self.is_init_sent = False @@ -53,7 +49,7 @@ def _send_reset(self): Returns: True if hard reset, False if soft reset. ''' - self.towerfall.send_reset(verbose=self.verbose) + self.towerfall.send_reset() @abstractmethod def _post_reset(self) -> Observation: @@ -174,7 +170,7 @@ def _actions_to_command(self, actions: Mapping[ActionName, Action]) -> str: def act(self, actions: Mapping[ActionName, Action]) -> Observation: command = self._actions_to_command(actions) - resp: dict[str, Any] = dict( + resp: Dict[str, Any] = dict( type='commands', command=command, id=self.state_update['id'] diff --git a/entity_envs/entity_env.py b/entity_envs/entity_env.py index e00db3f..d6356f4 100644 --- a/entity_envs/entity_env.py +++ b/entity_envs/entity_env.py @@ -1,16 +1,14 @@ -import logging import random from entity_gym.env import Observation, GlobalCategoricalActionMask -from typing import Any, Optional +from typing import Any, Dict, List, Optional import numpy as np from common.constants import HH, HW from common.entity import Entity, Vec2 from entity_envs.entity_base_env import TowerfallEntityEnv -from envs.connection_provider import TowerfallProcess, TowerfallProcessProvider class TowerfallEntityEnvImpl(TowerfallEntityEnv): @@ -35,7 +33,7 @@ def _is_reset_valid(self) -> bool: def _send_reset(self): reset_entities = self._get_reset_entities() - self.towerfall.send_reset(reset_entities, verbose=self.verbose) + self.towerfall.send_reset(reset_entities) def _post_reset(self) -> Observation: assert self.me, 'No player found after reset' @@ -55,9 +53,9 @@ def _post_observe(self) -> Observation: arrows = list(e for e in self.entities if e['type'] == 'arrow') return self._get_obs(targets, arrows) - def _get_reset_entities(self) -> Optional[list[dict]]: + def _get_reset_entities(self) -> Optional[List[Dict[str, Any]]]: p = Vec2(160, 110) - entities: list[dict[str, Any]] = [dict(type='archer', pos=p.dict())] + entities: List[Dict[str, Any]] = [dict(type='archer', pos=p.dict())] for i in range(self.enemy_count): sign = random.randint(0, 1)*2 - 1 d = random.uniform(self.min_distance, self.max_distance) * sign @@ -68,7 +66,7 @@ def _get_reset_entities(self) -> Optional[list[dict]]: entities.append(enemy) return entities - def _update_reward(self, enemies: list[Entity]): + def _update_reward(self, enemies: List[Entity]): ''' Updates the reward and checks if the episode is done. ''' @@ -102,7 +100,7 @@ def _update_reward(self, enemies: list[Entity]): def limit(self, x: float, a: float, b: float) -> float: return x+b-a if x < a else x-b+a if x > b else x - def _get_obs(self, enemies: list[Entity], arrows: list[Entity]) -> Observation: + def _get_obs(self, enemies: List[Entity], arrows: List[Entity]) -> Observation: if not self.me: return Observation( done=self.done, diff --git a/entity_envs/predefined_envs.py b/entity_envs/predefined_envs.py index 5d012de..02ab789 100644 --- a/entity_envs/predefined_envs.py +++ b/entity_envs/predefined_envs.py @@ -1,20 +1,9 @@ -from typing import Any, Optional +from typing import Optional -from common.grid import GridView from entity_envs.entity_env import TowerfallEntityEnvImpl -from envs.connection_provider import TowerfallProcessProvider -def create_kill_enemy(configs: dict[str, Any], record_path: Optional[str]=None, verbose=0): - towerfall_provider = TowerfallProcessProvider('default') - towerfall = towerfall_provider.get_process( - fastrun=True, - config=dict( - mode='sandbox', - level='1', - fps=90, - agents=[dict(type='remote', team='blue', archer='green')] - )) + +def create_kill_enemy(record_path: Optional[str]=None, verbose=0): return TowerfallEntityEnvImpl( - towerfall=towerfall, record_path=record_path, verbose=verbose) \ No newline at end of file diff --git a/envs/__init__.py b/envs/__init__.py index 213aea9..78b5985 100644 --- a/envs/__init__.py +++ b/envs/__init__.py @@ -4,5 +4,4 @@ from .curriculums import * from .objectives import * from .observations import * -from .connection_provider import * from .predefined_envs import * \ No newline at end of file diff --git a/envs/base_env.py b/envs/base_env.py index 15c036e..38c3893 100644 --- a/envs/base_env.py +++ b/envs/base_env.py @@ -1,17 +1,14 @@ -import json import logging - from abc import ABC, abstractmethod +from typing import Any, List, Optional, Tuple -from common import Connection, Entity, to_entities - -from .actions import TowerfallActions -from .connection_provider import TowerfallProcess - -from typing import List, Optional, Tuple, Any +from gym import Env from numpy.typing import NDArray -from gym import Env +from common import Entity, to_entities +from towerfall.towerfall import Towerfall + +from .actions import TowerfallActions class TowerfallEnv(Env, ABC): @@ -23,14 +20,14 @@ class TowerfallEnv(Env, ABC): param actions: The actions that the agent can take. If None, the default actions are used. ''' def __init__(self, - towerfall: TowerfallProcess, + towerfall: Towerfall, actions: Optional[TowerfallActions] = None, record_path: Optional[str] = None, verbose: int = 0): logging.info('Initializing TowerfallEnv') self.towerfall = towerfall self.verbose = verbose - self.connection = self.towerfall.join(timeout=5, verbose=self.verbose) + self.connection = self.towerfall.join(timeout=5) self.connection.record_path = record_path if actions: self.actions = actions @@ -58,7 +55,7 @@ def _send_reset(self): Returns: True if hard reset, False if soft reset. ''' - self.towerfall.send_reset(verbose=self.verbose) + self.towerfall.send_reset() @abstractmethod def _post_reset(self) -> Tuple[NDArray, dict]: @@ -122,7 +119,7 @@ def step(self, actions: NDArray) -> Tuple[NDArray, float, bool, object]: ''' command = self.actions._actions_to_command(actions) - resp: dict[str, Any] = dict( + resp: Dict[str, Any] = dict( type='commands', command=command, id=self.state_update['id'] diff --git a/envs/blank_env.py b/envs/blank_env.py index 67377e9..d9ed07a 100644 --- a/envs/blank_env.py +++ b/envs/blank_env.py @@ -1,21 +1,21 @@ import logging +from typing import List, Optional, Tuple from gym import spaces -from .base_env import TowerfallEnv +from towerfall.towerfall import Towerfall + from .actions import TowerfallActions -from .observations import TowerfallObservation +from .base_env import TowerfallEnv from .objectives import TowerfallObjective -from .connection_provider import TowerfallProcess - -from typing import Tuple, Optional +from .observations import TowerfallObservation class TowerfallBlankEnv(TowerfallEnv): '''A blank environment that can be customized with the addition of observations and an objective.''' def __init__(self, - towerfall: TowerfallProcess, - observations: list[TowerfallObservation], + towerfall: Towerfall, + observations: List[TowerfallObservation], objective: TowerfallObjective, actions: Optional[TowerfallActions]=None, record_path: Optional[str]=None, @@ -38,7 +38,7 @@ def _is_reset_valid(self) -> bool: def _send_reset(self): reset_entities = self.objective.get_reset_entities() - self.towerfall.send_reset(reset_entities, verbose=self.verbose) + self.towerfall.send_reset(reset_entities) def _post_reset(self) -> dict: obs_dict = {} diff --git a/envs/connection_provider.py b/envs/connection_provider.py deleted file mode 100644 index 5245c03..0000000 --- a/envs/connection_provider.py +++ /dev/null @@ -1,240 +0,0 @@ -import os -import json -import psutil -import signal -import logging -import time - -from subprocess import Popen, PIPE - -from common import Connection - -from typing import Any, Optional, Tuple - -from common.namedmutex import NamedMutex - - -_HOST = '127.0.0.1' - - -class TowerfallProcess: - ''' - Offers an interface with a Towerfall process. - - params pid: The process ID of the Towerfall process. - params port: The port that the Towerfall process is listening on. - params config: The current configuration of the Towerfall process. - ''' - def __init__(self, pid: int, port: int, fastrun: bool, nographics: bool, config: dict[str, Any] = {}): - self.pid = pid - self.port = port - self.fastrun = fastrun - self.nographics = nographics - self.config: dict[str, Any] = config - self.connections: list[Connection] = [] - - def to_dict(self) -> dict[str, Any]: - return dict( - pid=self.pid, - port=self.port, - fastrun=self.fastrun, - nographics=self.nographics, - config=self.config - ) - - def join(self, timeout: float = 2, verbose=0) -> Connection: - connection = Connection(_HOST, self.port, timeout, verbose) - connection.write_json(dict(type='join')) - resp = connection.read_json() - if resp['type'] != 'result': - raise Exception(f'Unexpected response type: {resp["type"]}') - if not resp['success']: - raise Exception(f'Failed to join process {self.pid}: {resp["message"]}') - self.connections.append(connection) - - def on_close(): - self.connections.remove(connection) - connection.on_close = on_close - return connection - - def send_reset(self, entities: Optional[list[dict]] = None, timeout: float = 2, verbose=0): - resp = self.send_request_json(dict(type='reset', entities=entities), timeout, verbose) - if resp['type'] != 'result': - raise Exception(f'Unexpected response type: {resp["type"]}') - if not resp['success']: - raise Exception(f'Failed to reset process {self.pid}: {resp["message"]}') - if verbose > 0: - logging.info(f'Successfully reset process {self.pid}') - - def send_config(self, config = None, timeout: float = 2, verbose=0): - if not config: - config = self.config - - resp = self.send_request_json(dict(type='config', config=config), timeout, verbose) - if resp['type'] != 'result': - raise Exception(f'Unexpected response type: {resp["type"]}') - if not resp['success']: - raise Exception(f'Failed to config process {self.pid}: {resp["message"]}') - logging.info(f'Successfully applied config to process {self.pid}') - self.config = config - - def send_request_json(self, obj: dict[str, Any], timeout: float = 2, verbose=0): - connection = Connection(_HOST, self.port, timeout, verbose) - connection.write_json(obj) - return connection.read_json() - - -class TowerfallProcessProvider: - ''' - Creates and manages Towerfall processes. - - params name: Name of the connection provider. Used to separate different connection providers states. - ''' - def __init__(self, name: Optional[str] = None, - # towerfall_path: str = 'C:/Users/vcanaa/towerfall/TowerFall', - towerfall_path: str = 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall'): - self.towerfall_path = towerfall_path - self.towerfall_path_exe = os.path.join(self.towerfall_path, 'TowerFall.exe') - self.name = name - self.processes = [] - if self.name: - self.connection_path = os.path.join('.connection_provider', self.name) - os.makedirs(self.connection_path, exist_ok=True) - self.state_path = os.path.join(self.connection_path, 'state.json') - - if os.path.exists(self.state_path): - with open(self.state_path, 'r') as file: - for process_data in json.loads(file.read()): - try: - psutil.Process(process_data['pid']) - self.processes.append(TowerfallProcess(**process_data)) - except psutil.NoSuchProcess: - continue - self._save_state() - - self._processes_in_use = set() - - self.default_config = dict( - mode='sandbox', - level='2', - fastrun=False, - agents=[ - dict(type='remote', team='blue', archer='green') - ] - ) - - def get_process(self, fastrun: bool = False, nographics: bool = False, config = None, verbose=0, reuse: bool = True) -> TowerfallProcess: - if not config: - config = self.default_config - - selected_process = None - while not selected_process: - selected_process = None - if reuse: - # Try to find an existing process that is not in use - def is_suitable_process(process: TowerfallProcess): - if process.fastrun != fastrun: - return False - if process.nographics != nographics: - return False - if process.pid in self._processes_in_use: - return False - return True - selected_process = next((p for p in self.processes if is_suitable_process(p)), None) - - # If no process can be reused, start a new one - if not selected_process: - logging.info(f'Starting new process {self.towerfall_path_exe}') - pargs = [self.towerfall_path_exe, '--noconfig'] - if fastrun: - pargs.append('--fastrun') - if nographics: - pargs.append('--nographics') - # Multiple TowerFall.exe can't be started at the same time, due to conflict accessing Content folder. - with NamedMutex(f'TowerfallProcessProvider_{self.name}'): - process = Popen(pargs, cwd=self.towerfall_path) - port = self._get_port(process.pid) - selected_process = TowerfallProcess(process.pid, port, fastrun, nographics) - self.processes.append(selected_process) - self._save_state() - time.sleep(2) # Give some time for game to load. There is currently no way to tell if the game loaded. - - - try: - selected_process.send_config(config, verbose=verbose) - except: - os.kill(selected_process.pid, signal.SIGTERM) - selected_process = None - self._processes_in_use.add(selected_process.pid) - self._save_state() - return selected_process - - def release_process(self, process: TowerfallProcess): - self._processes_in_use.remove(process.pid) - - def join_new(self, fastrun=True, nographics=False, config = None, timeout=2, verbose=0) -> Tuple[Connection, TowerfallProcess]: - connection = None - process = None - logging.info('Create a new process and join') - while not connection or not process: - try: - process = self.get_process(fastrun, nographics, config, verbose, reuse=False) - connection = process.join(timeout, verbose) - except Exception as ex: - logging.error(f'Failed to create and join new process: {ex}') - if process: - self.kill_process(process.pid) - if connection: - connection.close() - - return connection, process - - def kill_process(self, pid): - try: - os.kill(pid, signal.SIGTERM) - except Exception as ex: - logging.error(f'Failed to kill process {pid}: {ex}') - finally: - self.processes.remove(next(p for p in self.processes if p.pid == pid)) - self._save_state() - - def close(self): - logging.info(f'Closing all processes in context {self.name}...') - for process in self.processes: - try: - os.kill(process.pid, signal.SIGTERM) - except Exception as ex: - logging.error(f'Failed to kill process {process.pid}: {ex}') - continue - - @classmethod - def close_all(cls): - logging.info('Closing all TowerFall.exe processes...') - for process in psutil.process_iter(attrs=['pid', 'name']): - # logging.info(f'Checking process {process.pid} {process.name()}') - if process.name() != 'TowerFall.exe': - continue - try: - logging.info(f'Killing process {process.pid}...') - os.kill(process.pid, signal.SIGTERM) - except Exception as ex: - logging.error(f'Failed to kill process {process.pid}: {ex}') - continue - - def _get_port(self, pid: int) -> int: - port_path = os.path.join(self.towerfall_path, 'ports', str(pid)) - tries = 0 - print(f'Waiting for port file {port_path} to be created...') - while not os.path.exists(port_path) and tries < 20: - time.sleep(0.2) - tries += 1 - with open(port_path, 'r') as file: - return int(file.readline()) - - def _save_state(self): - if self.name: - with open(self.state_path, 'w') as file: - file.write(json.dumps([p.to_dict() for p in self.processes], indent=2)) - - def _match_config(self, config1, config2): - return False diff --git a/envs/curriculums.py b/envs/curriculums.py index b8e481e..778f8c3 100644 --- a/envs/curriculums.py +++ b/envs/curriculums.py @@ -12,7 +12,7 @@ from .objectives import FollowTargetObjective from .objectives import TowerfallObjective -from typing import Tuple, Iterable, Optional +from typing import Any, Dict, List, Tuple, Iterable, Optional class TaskEncoder(json.JSONEncoder): @@ -71,7 +71,7 @@ def n_episodes(self): return 1 return len(self.start_ends) - def is_reset_valid(self, state_scenario: dict, player: Optional[Entity], entities: list[Entity]) -> bool: + def is_reset_valid(self, state_scenario: Dict[str, Any], player: Optional[Entity], entities: List[Entity]) -> bool: if self.initialized: return True @@ -82,7 +82,7 @@ def is_reset_valid(self, state_scenario: dict, player: Optional[Entity], entitie # self.gv = GridView(5) self.gv.set_scenario(state_scenario) self.gv.update(entities, player) - self.start_ends: list[Tuple[Vec2, Vec2]] = [] + self.start_ends: List[Tuple[Vec2, Vec2]] = [] hsize = player.s / 2 for start in self._pick_all_starts(): if self.gv.is_region_collision(start - hsize, start + hsize): @@ -110,10 +110,10 @@ def is_reset_valid(self, state_scenario: dict, player: Optional[Entity], entitie self.initialized = True return False - def extend_obs_space(self, obs_space_dict: dict[str, Space]): + def extend_obs_space(self, obs_space_dict: Dict[str, Space]): self.objective.extend_obs_space(obs_space_dict) - def get_reset_entities(self) -> Optional[list[dict]]: + def get_reset_entities(self) -> Optional[List[Dict[str, Any]]]: if not self.initialized: return None self.objective.env = self.env @@ -125,11 +125,11 @@ def get_reset_entities(self) -> Optional[list[dict]]: self.start, self.end = self.start_ends[self.task_idx] return [dict(type='archer', pos=self.start.dict())] - def post_reset(self, state_scenario: dict, player: Optional[Entity], entities: list[Entity], obs_dict: dict): + def post_reset(self, state_scenario: Dict[str, Any], player: Optional[Entity], entities: List[Entity], obs_dict: Dict[str, Any]): target = (self.end.x, self.end.y) self.objective.post_reset(state_scenario, player, entities, obs_dict, target) - def post_step(self, player: Optional[Entity], entities: list[Entity], command: str, obs_dict: dict): + def post_step(self, player: Optional[Entity], entities: List[Entity], command: str, obs_dict: Dict[str, Any]): self.objective.post_step(player, entities, command, obs_dict) self.rew = self.objective.rew self.done = self.objective.done diff --git a/envs/kill_enemy_objective.py b/envs/kill_enemy_objective.py index e0278c7..04c987a 100644 --- a/envs/kill_enemy_objective.py +++ b/envs/kill_enemy_objective.py @@ -1,5 +1,5 @@ import random -from typing import Any, Optional +from typing import Any, Dict, List, Optional import numpy as np from gym import Space, spaces @@ -31,15 +31,15 @@ def __init__(self, self.episode_len = 0 self.obs_space = spaces.Box(low=-1, high = 1, shape=(3*self.enemy_count,), dtype=np.float32) - def extend_obs_space(self, obs_space_dict: dict[str, Space]): + def extend_obs_space(self, obs_space_dict: Dict[str, Space]): if 'targets' in obs_space_dict: raise Exception('Observation space already has \'target\'') target_space = {} obs_space_dict['targets'] = self.obs_space - def get_reset_entities(self) -> Optional[list[dict]]: + def get_reset_entities(self) -> Optional[List[Dict[str, Any]]]: p = Vec2(160, 110) - entities: list[dict[str, Any]] = [dict(type='archer', pos=p.dict())] + entities: List[Dict[str, Any]] = [dict(type='archer', pos=p.dict())] for i in range(self.enemy_count): sign = random.randint(0, 1)*2 - 1 d = random.uniform(self.min_distance, self.max_distance) * sign @@ -50,7 +50,7 @@ def get_reset_entities(self) -> Optional[list[dict]]: entities.append(enemy) return entities - def post_reset(self, state_scenario: dict, player: Optional[Entity], entities: list[Entity], obs_dict: dict): + def post_reset(self, state_scenario: Dict[str, Any], player: Optional[Entity], entities: List[Entity], obs_dict: Dict[str, Any]): assert player targets = list(e for e in entities if e['type'] == self.enemy_type) assert len(targets) > 0, 'No targets found' @@ -62,13 +62,13 @@ def post_reset(self, state_scenario: dict, player: Optional[Entity], entities: l self._update_obs(player, targets, obs_dict) - def post_step(self, player: Optional[Entity], entities: list[Entity], command: str, obs_dict: dict): + def post_step(self, player: Optional[Entity], entities: List[Entity], command: str, obs_dict: Dict[str, Any]): targets = list(e for e in entities if e['id'] in self.target_ids) self._update_reward(player, targets) self.episode_len += 1 self._update_obs(player, targets, obs_dict) - def _update_reward(self, player: Optional[Entity], targets: list[Entity]): + def _update_reward(self, player: Optional[Entity], targets: List[Entity]): ''' Updates the reward and checks if the episode is done. ''' @@ -87,7 +87,7 @@ def _update_reward(self, player: Optional[Entity], targets: list[Entity]): def limit(self, x: float, a: float, b: float) -> float: return x+b-a if x < a else x-b+a if x > b else x - def _update_obs(self, player: Optional[Entity], targets: list[Entity], obs_dict: dict): + def _update_obs(self, player: Optional[Entity], targets: List[Entity], obs_dict: Dict[str, Any]): obs_target = np.zeros((3*self.enemy_count,), dtype=np.float32) if not player: obs_dict['targets'] = obs_target diff --git a/envs/objectives.py b/envs/objectives.py index a179baf..a6ea7e6 100644 --- a/envs/objectives.py +++ b/envs/objectives.py @@ -11,7 +11,7 @@ from .base_env import TowerfallEnv from .observations import TowerfallObservation -from typing import Optional, Tuple +from typing import Any, Dict, List, Mapping, Optional, Tuple from numpy.typing import NDArray @@ -21,10 +21,10 @@ def __init__(self): self.rew: float self.env: TowerfallEnv - def is_reset_valid(self, state_scenario: dict, player: Optional[Entity], entities: list[Entity]) -> bool: + def is_reset_valid(self, state_scenario: Mapping[str, Any], player: Optional[Entity], entities: List[Entity]) -> bool: return True - def get_reset_entities(self) -> list[dict]: + def get_reset_entities(self) -> list[Dict[str, Any]]: '''Specifies how the environment needs to be reset.''' return [] @@ -51,12 +51,12 @@ def __init__(self, grid_view: Optional[GridView], distance: float=8, max_distanc self.rew_dc = rew_dc self.obs_space = spaces.Box(low=-2, high = 2, shape=(2,), dtype=np.float32) - def extend_obs_space(self, obs_space_dict: dict[str, Space]): + def extend_obs_space(self, obs_space_dict: Dict[str, Space]): if 'target' in obs_space_dict: raise Exception('Observation space already has \'target\'') obs_space_dict['target'] = self.obs_space - def post_reset(self, state_scenario: dict, player: Optional[Entity], entities: list[Entity], obs_dict: dict, target: Optional[Tuple[float, float]] = None): + def post_reset(self, state_scenario: Mapping[str, Any], player: Optional[Entity], entities: List[Entity], obs_dict: Dict[str, Any], target: Optional[Tuple[float, float]] = None): if not player: obs_dict['target'] = self.obs_target return @@ -67,7 +67,7 @@ def post_reset(self, state_scenario: dict, player: Optional[Entity], entities: l self._set_random_target(player) obs_dict['target'] = self.obs_target - def post_step(self, player: Optional[Entity], entities: list[Entity], command: str, obs_dict: dict): + def post_step(self, player: Optional[Entity], entities: List[Entity], command: str, obs_dict: Dict[str, Any]): self._update_reward(player) self.episode_len += 1 if player: diff --git a/envs/observations.py b/envs/observations.py index 57cadd1..ab6b8a1 100644 --- a/envs/observations.py +++ b/envs/observations.py @@ -6,7 +6,7 @@ from common import Entity, GridView, JUMP, DASH, SHOOT -from typing import Sequence, Optional, Tuple, Union +from typing import Any, Dict, List, Mapping, Sequence, Optional, Tuple, Union class TowerfallObservation(ABC): @@ -14,17 +14,17 @@ class TowerfallObservation(ABC): Base class for observations. ''' @abstractmethod - def extend_obs_space(self, obs_space_dict: dict[str, Space]): + def extend_obs_space(self, obs_space_dict: Mapping[str, Space]): '''Adds the new definitions to observations to obs_space.''' raise NotImplementedError() @abstractmethod - def post_reset(self, state_scenario: dict, player: Optional[Entity], entities: list[Entity], obs_dict: dict): + def post_reset(self, state_scenario: Mapping[str, Any], player: Optional[Entity], entities: List[Entity], obs_dict: Dict[str, Any]): '''Hook for a gym reset call. Adds observations to obs_dict.''' raise NotImplementedError @abstractmethod - def post_step(self, player: Optional[Entity], entities: list[Entity], command: str, obs_dict: dict): + def post_step(self, player: Optional[Entity], entities: List[Entity], command: str, obs_dict: Mapping[str, Any]): '''Hook for a gym step call. Adds observations to obs_dict.''' raise NotImplementedError @@ -34,7 +34,7 @@ class PlayerObservation(TowerfallObservation): def __init__(self, exclude: Optional[Sequence[str]] = None): self.exclude = exclude - def extend_obs_space(self, obs_space_dict: dict[str, Space]): + def extend_obs_space(self, obs_space_dict: Dict[str, Space]): def try_add_obs(key, value): if self.exclude and key in self.exclude: return @@ -53,13 +53,13 @@ def try_add_obs(key, value): try_add_obs('onWall', spaces.Discrete(2)) try_add_obs('vel', spaces.Box(low=-2, high=2, shape=(2,), dtype=np.float32)) - def post_reset(self, state_scenario: dict, player: Optional[Entity], entities: list[Entity], obs_dict: dict): + def post_reset(self, state_scenario: Mapping[str, Any], player: Optional[Entity], entities: List[Entity], obs_dict: Dict[str, Any]): self._extend_obs(player, '', obs_dict) - def post_step(self, player: Optional[Entity], entities: list[Entity], command: str, obs_dict): + def post_step(self, player: Optional[Entity], entities: List[Entity], command: str, obs_dict: Dict[str, Any]): self._extend_obs(player, command, obs_dict) - def _extend_obs(self, player: Optional[Entity], command: str, obs_dict: dict): + def _extend_obs(self, player: Optional[Entity], command: str, obs_dict: Dict[str, Any]): def try_add_obs(key, value): if self.exclude and key in self.exclude: return @@ -100,25 +100,25 @@ def __init__(self, grid_view: GridView, sight: Optional[Union[Tuple[int, int], i self.frame = 0 self.add_grid = add_grid - def extend_obs_space(self, obs_space_dict: dict[str, Space]): + def extend_obs_space(self, obs_space_dict: Dict[str, Space]): if self.add_grid: if 'grid' in obs_space_dict: raise Exception('Observation space already has \'grid\'') obs_space_dict['grid'] = self.obs_space - def post_reset(self, state_scenario: dict, player: Optional[Entity], entities: list[Entity], obs_dict: dict): + def post_reset(self, state_scenario: Mapping[str, Any], player: Optional[Entity], entities: List[Entity], obs_dict: Dict[str, Any]): self.frame = 0 self.gv.set_scenario(state_scenario) self._update_grid(player, entities) self._extend_obs(obs_dict) - def post_step(self, player: Optional[Entity], entities: list[Entity], command: str, obs_dict: dict): + def post_step(self, player: Optional[Entity], entities: List[Entity], command: str, obs_dict: Dict[str, Any]): self._update_grid(player, entities) self._extend_obs(obs_dict) self.frame += 1 - def _update_grid(self, player: Optional[Entity], entities: list[Entity]): + def _update_grid(self, player: Optional[Entity], entities: List[Entity]): if player: self.gv.update(entities, player) else: @@ -126,6 +126,6 @@ def _update_grid(self, player: Optional[Entity], entities: list[Entity]): self.gv.update(entities, self.prev_player) self.prev_player = player - def _extend_obs(self, obs_dict: dict): + def _extend_obs(self, obs_dict: Dict[str, Any]): if self.add_grid: obs_dict['grid'] = self.gv.view(self.sight) diff --git a/envs/predefined_envs.py b/envs/predefined_envs.py index 3417a83..19c663f 100644 --- a/envs/predefined_envs.py +++ b/envs/predefined_envs.py @@ -1,19 +1,18 @@ -from typing import Any, Optional +from typing import Any, Dict, Optional from common.grid import GridView from envs.actions import TowerfallActions from envs.blank_env import TowerfallBlankEnv -from envs.connection_provider import TowerfallProcessProvider from envs.curriculums import FollowCloseTargetCurriculum from envs.kill_enemy_objective import KillEnemyObjective from envs.observations import GridObservation +from towerfall import Towerfall -def create_simple_move_env(configs: dict[str, Any], record_path: Optional[str]=None, verbose=0): +def create_simple_move_env(configs: Dict[str, Any], record_path: Optional[str]=None, verbose=0): grid_view = GridView(grid_factor=5) objective = FollowCloseTargetCurriculum(grid_view, **configs['objective_params']) - towerfall_provider = TowerfallProcessProvider('default') - towerfall = towerfall_provider.get_process( + towerfall = Towerfall( fastrun=True, config=dict( mode='sandbox', @@ -31,10 +30,9 @@ def create_simple_move_env(configs: dict[str, Any], record_path: Optional[str]=N verbose=verbose) -def create_kill_enemy(configs: dict[str, Any], record_path: Optional[str]=None, verbose=0): +def create_kill_enemy(configs: Dict[str, Any], record_path: Optional[str]=None, verbose=0): objective = KillEnemyObjective(**configs['objective_params']) - towerfall_provider = TowerfallProcessProvider('default') - towerfall = towerfall_provider.get_process( + towerfall = Towerfall( fastrun=True, config=dict( mode='sandbox', diff --git a/evaluate_policy.py b/evaluate_policy.py index 4e877e8..c06463d 100644 --- a/evaluate_policy.py +++ b/evaluate_policy.py @@ -1,20 +1,16 @@ import argparse import logging import os -import numpy as np -import time import json from stable_baselines3 import PPO -from stable_baselines3.common.monitor import load_results, Monitor -from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.evaluation import evaluate_policy -from stable_baselines3.common.results_plotter import ts2xy, plot_results -from envs import TowerfallBlankEnv, GridObservation, PlayerObservation, FollowTargetObjective, FollowCloseTargetCurriculum +from envs import TowerfallBlankEnv, GridObservation, PlayerObservation, FollowCloseTargetCurriculum -from common import Connection, GridView +from common import GridView class NoLevelFormatter(logging.Formatter): def format(self, record): @@ -74,10 +70,10 @@ def evaluate(load_from: str): logging.info(f'Running evaluation for {last_model}') logging.info('Deterministic=False') - evaluate_policy(model, - env=env, - n_eval_episodes=50, - render=False, + evaluate_policy(model, + env=env, + n_eval_episodes=50, + render=False, deterministic=False, callback=input_blocking_callback) # logging.info('Deterministic=True') diff --git a/exper_rl_ppo_trainer_curr.py b/exper_rl_ppo_trainer_curr.py index ed6446a..b2a17ab 100644 --- a/exper_rl_ppo_trainer_curr.py +++ b/exper_rl_ppo_trainer_curr.py @@ -5,6 +5,8 @@ import time import json +from towerfall.connection import Connection + try: import wandb except ImportError: @@ -19,7 +21,7 @@ from stable_baselines3.common.callbacks import BaseCallback from envs import TowerfallBlankEnv, GridObservation, PlayerObservation, FollowTargetObjective, FollowCloseTargetCurriculum -from common import Connection, GridView +from common import GridView from typing import Any diff --git a/notebooks/train_move_data.ipynb b/notebooks/train_move_data.ipynb index 17823ed..2d6bf03 100644 --- a/notebooks/train_move_data.ipynb +++ b/notebooks/train_move_data.ipynb @@ -80,7 +80,7 @@ "from torch.utils.data import Dataset, random_split\n", "from collections import OrderedDict\n", "\n", - "from typing import Tuple, Optional\n", + "from typing import Any, Dict, List, Tuple, Optional\n", "\n", "\n", "def load_tensor(*segments):\n", @@ -88,24 +88,24 @@ " return t\n", "\n", "\n", - "def get_slice(data_dict: dict[str, th.Tensor], idx):\n", + "def get_slice(data_dict: Dict[str, th.Tensor], idx):\n", " return {k: v[idx] for k,v in data_dict.items()}\n", "\n", - "def get_avg(data_dict: dict[str, th.Tensor], keys: Optional[list[str]]):\n", + "def get_avg(data_dict: Dict[str, th.Tensor], keys: Optional[List[str]]):\n", " r = {}\n", " for k,v in data_dict.items():\n", " if keys==None or keys and k in keys:\n", " r[k] = v.mean(dim=0).item()\n", " return r\n", "\n", - "def get_std(data_dict: dict[str, th.Tensor], keys: Optional[list[str]]):\n", + "def get_std(data_dict: Dict[str, th.Tensor], keys: Optional[List[str]]):\n", " r = {}\n", " for k,v in data_dict.items():\n", " if keys==None or keys and k in keys:\n", " r[k] = v.std(dim=0).item()\n", " return r\n", "\n", - "def view(data_dict: dict, fn_dict: Optional[dict] = None) -> OrderedDict:\n", + "def view(data_dict: Dict, fn_dict: Optional[Dict[str, Any]] = None) -> OrderedDict:\n", " if not fn_dict:\n", " return OrderedDict(data_dict)\n", " r = OrderedDict()\n", diff --git a/run_simple_network_bot.py b/run_simple_network_bot.py index 69f266c..45c24a9 100644 --- a/run_simple_network_bot.py +++ b/run_simple_network_bot.py @@ -26,7 +26,7 @@ def reply(): "type":"commands", "command": ''.join(pressed) })) - + # print(json.dumps({ # "type":"commands", # "command": ''.join(pressed) diff --git a/synchronization/__init__.py b/synchronization/__init__.py new file mode 100644 index 0000000..af18f23 --- /dev/null +++ b/synchronization/__init__.py @@ -0,0 +1 @@ +from .namedmutex import * \ No newline at end of file diff --git a/common/namedmutex.py b/synchronization/namedmutex.py similarity index 100% rename from common/namedmutex.py rename to synchronization/namedmutex.py diff --git a/tests/test_connection_provider.py b/tests/test_connection_provider.py deleted file mode 100644 index 8cf5608..0000000 --- a/tests/test_connection_provider.py +++ /dev/null @@ -1,138 +0,0 @@ -import sys -sys.path.insert(0, 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall/aimod') - -import time -import timeit -import logging -import random - -from common import Connection -from envs import TowerfallProcessProvider, TowerfallProcess - -from typing import Any - -_VERBOSE = 0 -_TIMEOUT = 4 - -class NoLevelFormatter(logging.Formatter): - def format(self, record): - return record.getMessage() - -logging.basicConfig(level=logging.INFO) -logging.getLogger().handlers[0].setFormatter(NoLevelFormatter()) - - -process_provider = TowerfallProcessProvider('test') - -_starting_x = [64, 256, 128, 192] - -def get_random_command() -> str: - s = '' - p = 0.1 - keys = ['u', 'd', 'l', 'r', 'j', 'z', 's'] - for key in keys: - if random.random() < p: - s += key - return s - - -def get_config(agent_count: int) -> dict[str, Any]: - return dict( - mode='sandbox', - level='1', - fps=30, - agents=[dict(type='remote', team='blue', archer='green')]*agent_count) - - -def get_process(agent_count: int) -> TowerfallProcess: - return process_provider.get_process( - fastrun=True, - # nographics=True, - config=get_config(agent_count)) - - -def join(towerfall: TowerfallProcess, agent_count: int) -> list[Connection]: - connections = [] - for i in range(agent_count): - conn = towerfall.join(timeout=_TIMEOUT, verbose=_VERBOSE) - conn.log_cap = 100 - connections.append(conn) - return connections - - -def reset(towerfall: TowerfallProcess, agent_count: int) -> list[dict]: - y = 110 - entities = [dict(type='archer', pos=dict(x=_starting_x[i], y=y)) for i in range(agent_count)] - entities.append(dict(type='slime', pos=dict(x=160, y=y))) - towerfall.send_reset(entities, verbose=_VERBOSE) - return entities - - -def receive_init(connections: list[Connection]): - for i in range(len(connections)): - # init - state_init = connections[i].read_json() - assert state_init['type'] == 'init', state_init['type'] - connections[i].write_json(dict(type='result', success=True)) - - for i in range(len(connections)): - # scenario - state_scenario = connections[i].read_json() - assert state_scenario['type'] == 'scenario', state_scenario['type'] - connections[i].write_json(dict(type='result', success=True)) - - -def receive_update(connections: list[Connection], entities: list[dict], length: int): - now = time.time() - for j in range(length): - for i in range(len(connections)): - # update - state_update = connections[i].read_json() - assert state_update['type'] == 'update', state_update['type'] - # if j == 0: - # pos = [e['pos'] for e in state_update['entities'] if e['type'] == 'archer' and e['playerIndex']==i][0] - # diff = abs(pos['x'] - entities[i]['pos']['x']) - # assert diff < 2, f"{pos['x']} != {entities[i]['pos']['x']}, diff = {diff}" - connections[i].write_json(dict(type='commands', command=get_random_command(), id=state_update['id'])) - dt = time.time() - now - logging.info(f'fps: {length/dt:.2f}') - - -def run_many_resets(towerfall: TowerfallProcess, agent_count: int, reset_count: int): - connections = join(towerfall, agent_count) - entities = reset(towerfall, agent_count) - - receive_init(connections) - receive_update(connections, entities, length=20) - - for i in range(reset_count): - entities = reset(towerfall, agent_count) - receive_update(connections, entities, length=20) - - -def run_session(): - agent_count = 1 - reset_count = 5 - towerfall = get_process(agent_count) - run_many_resets(towerfall, agent_count, reset_count) - - # agent_count = 1 - # towerfall.send_config(get_config(agent_count), verbose=_VERBOSE) - # run_many_resets(towerfall, agent_count, reset_count) - - # agent_count = 3 - # towerfall.send_config(get_config(agent_count), verbose=_VERBOSE) - # run_many_resets(towerfall, agent_count, reset_count) - - # agent_count = 4 - # towerfall.send_config(get_config(agent_count), verbose=_VERBOSE) - # run_many_resets(towerfall, agent_count, reset_count) - - # process_provider.release_process(towerfall) - - -n_it = 1 -elapsed_time = timeit.timeit(run_session, number=n_it) / n_it -print(f'Elapsed time: {elapsed_time:.2f} s') - -# process_provider.close() \ No newline at end of file diff --git a/tests/test_entity_env.py b/tests/test_entity_env.py index 65778d5..d8fa851 100644 --- a/tests/test_entity_env.py +++ b/tests/test_entity_env.py @@ -1,35 +1,20 @@ import sys -sys.path.insert(0, 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall/aimod') - -import random -from entity_gym.env import GlobalCategoricalAction, GlobalCategoricalActionSpace -from entity_envs.entity_env import TowerfallEntityEnvImpl -from envs.connection_provider import TowerfallProcessProvider +sys.path.insert(0, '.') import logging +import random -from typing import Any +from entity_gym.env import (GlobalCategoricalAction, + GlobalCategoricalActionSpace) -class NoLevelFormatter(logging.Formatter): - def format(self, record): - return record.getMessage() +from common import logging_options +from entity_envs.entity_env import TowerfallEntityEnvImpl -logging.basicConfig(level=logging.INFO) -logging.getLogger().handlers[0].setFormatter(NoLevelFormatter()) +logging_options.set_default() def create_env() -> TowerfallEntityEnvImpl: - towerfall_provider = TowerfallProcessProvider('test-entity-env') - towerfall = towerfall_provider.get_process( - fastrun=True, - config=dict( - mode='sandbox', - level='2', - fps=90, - agents=[dict(type='remote', team='blue', archer='green')] - ), verbose=1) - env = TowerfallEntityEnvImpl( - verbose=0) + env = TowerfallEntityEnvImpl(verbose=0) return env diff --git a/tests/test_env.py b/tests/test_env.py index d33b2ac..4ef36c2 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -1,26 +1,21 @@ import sys -sys.path.insert(0, 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall/aimod') -import logging - -from envs import TowerfallBlankEnv, FollowCloseTargetCurriculum, GridObservation, PlayerObservation, TowerfallProcessProvider - -from common import GridView +sys.path.insert(0, '.') +import logging from typing import Any -class NoLevelFormatter(logging.Formatter): - def format(self, record): - return record.getMessage() +from common import GridView, logging_options +from envs import (FollowCloseTargetCurriculum, GridObservation, + PlayerObservation, TowerfallBlankEnv) +from towerfall import Towerfall -logging.basicConfig(level=logging.INFO) -logging.getLogger().handlers[0].setFormatter(NoLevelFormatter()) +logging_options.set_default() def create_env(configs) -> TowerfallBlankEnv: grid_view = GridView(grid_factor=5) objective = FollowCloseTargetCurriculum(grid_view, **configs['objective_params']) - towerfall_provider = TowerfallProcessProvider('default') - towerfall = towerfall_provider.get_process(config=dict( + towerfall = Towerfall(config=dict( mode='sandbox', level='2', agents=[dict(type='remote', team='blue', archer='green')] @@ -32,7 +27,6 @@ def create_env(configs) -> TowerfallBlankEnv: PlayerObservation() ], objective=objective, - actions= verbose=1) # check_env(env) return env diff --git a/tests/test_grid.py b/tests/test_grid.py index 2f9cdde..420d3b8 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,11 +1,14 @@ import sys -sys.path.insert(0, 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall/aimod') -import numpy as np -from common import crop_grid, Vec2 +sys.path.insert(0, '.') + +from typing import List +import numpy as np from numpy.typing import NDArray +from common import Vec2, crop_grid + m = 8 n = 6 grid = np.arange(m*n).reshape((m, n)) @@ -23,7 +26,7 @@ def print_crop(g): print(''.join(s)) -def test_combinations(bl: Vec2, tr: Vec2, grid: NDArray, expected: list): +def test_combinations(bl: Vec2, tr: Vec2, grid: NDArray, expected: List[List[int]]): for dx in range(-320, 321, 320): for dy in range(-240, 241, 240): d = Vec2(dx, dy) diff --git a/towerfall/__init__.py b/towerfall/__init__.py new file mode 100644 index 0000000..5470b55 --- /dev/null +++ b/towerfall/__init__.py @@ -0,0 +1,4 @@ +from .connection import * +from .towerfall import * + +__all__ = ['Connection', 'Towerfall'] \ No newline at end of file diff --git a/common/connection.py b/towerfall/connection.py similarity index 57% rename from common/connection.py rename to towerfall/connection.py index 2bdc10e..2a214ab 100644 --- a/common/connection.py +++ b/towerfall/connection.py @@ -1,17 +1,24 @@ -import socket import json - import logging - -from typing import Callable - -from typing import Any +import socket +from typing import Any, Callable, Mapping _BYTE_ORDER = 'big' _ENCODING = 'ascii' +_LOCALHOST = '127.0.0.1' class Connection: - def __init__(self, ip: str, port: int, timeout: float = 0, verbose=0, log_cap=50, record_path=None): + ''' + Connection to a Towerfall server. It is used to send and receive messages. + + params port: Port of the server. + params ip: Ip address of the server. + params timeout: Timeout for the socket. + params verbose: Verbosity level. 0: no logging, 1: much logging. + params log_cap: Maximum number of characters to log. + params record_path: Path to a file to record the messages sent and received. + ''' + def __init__(self, port: int, ip: str = _LOCALHOST, timeout: float = 0, verbose=0, log_cap=50, record_path=None): self.verbose = verbose self.log_cap = log_cap self.record_path = record_path @@ -26,6 +33,9 @@ def __del__(self): self.close() def close(self): + ''' + Closes the socket. + ''' if hasattr(self, '_socket'): if self.verbose > 0: logging.info('Closing socket') @@ -34,10 +44,13 @@ def close(self): if hasattr(self, 'on_close'): self.on_close() - def write(self, msg): + def write(self, msg: str): + ''' + Writes a new message following the game's protocol. + ''' size = len(msg) if self.verbose > 0: - logging.info('Writing: %s %s', size, self.cap(msg)) + logging.info('Writing: %s %s', size, self._cap(msg)) self._socket.sendall(size.to_bytes(2, byteorder=_BYTE_ORDER)) self._socket.sendall(msg.encode(_ENCODING)) @@ -45,7 +58,10 @@ def write(self, msg): with open(self.record_path, 'a') as file: file.write(msg + '\n') - def read(self): + def read(self) -> str: + ''' + Reads a message following the game's protocol. + ''' try: header: bytes = self._socket.recv(2) size = int.from_bytes(header, _BYTE_ORDER) @@ -54,7 +70,7 @@ def read(self): payload = self._socket.recv(size) resp = payload.decode(_ENCODING) if self.verbose > 0: - logging.info('Read: %s', self.cap(resp)) + logging.info('Read: %s', self._cap(resp)) if self.record_path: with open(self.record_path, 'a') as file: file.write(resp + '\n') @@ -63,12 +79,17 @@ def read(self): logging.error(f'Socket timeout {self._socket.getsockname()}') raise ex - - def read_json(self): + def read_json(self) -> Mapping[str, Any]: + ''' + Reads a message and parses it to json. + ''' return json.loads(self.read()) - def write_json(self, obj: dict[str, Any]): + def write_json(self, obj: Mapping[str, Any]): + ''' + Convert the object to json and writes it. + ''' self.write(json.dumps(obj)) - def cap(self, value: str) -> str: + def _cap(self, value: str) -> str: return value[:self.log_cap] + '...' if len(value) > self.log_cap else value diff --git a/towerfall/towerfall.py b/towerfall/towerfall.py new file mode 100644 index 0000000..c902317 --- /dev/null +++ b/towerfall/towerfall.py @@ -0,0 +1,210 @@ +import json +import logging +import os +import signal +import time +from io import TextIOWrapper +from typing import Any, Callable, Dict, List, Mapping, Optional + +import psutil +from psutil import Popen + +from .connection import Connection + +class TowerfallError(Exception): + pass + +class Towerfall: + ''' + Interfaces a Towerfall process. + + params fastrun: Whether to run the Towerfall process in fast mode. This allows more than 60 fps. + params nographics: Whether to run the Towerfall process without graphics. This might reach higher fps. + params config: The current configuration of the Towerfall process. + params pool_name: The name of the pool to use. Using different pools among clients make sure they will not compete for the same game instances. + params towerfall_path: The path to the Towerfall.exe. + params timeout: The timeout for the management connections. + params verbose: The verbosity level. 0: no logging, 1: much logging. + ''' + def __init__(self, + fastrun: bool = True, + nographics: bool = False, + config: Mapping[str, Any] = {}, + pool_name: str = 'default', + towerfall_path: str = 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall', + timeout: float = 2, + verbose: int = 0): + self.fastrun = fastrun + self.nographics = nographics + self.config: Mapping[str, Any] = config + self.towerfall_path = towerfall_path + self.towerfall_path_exe = os.path.join(self.towerfall_path, 'TowerFall.exe') + self.pool_name = pool_name + self.pool_path = os.path.join(self.towerfall_path, 'pools', self.pool_name) + self.timeout = timeout + self.verbose = verbose + tries = 0 + while True: + self.port = self._attain_game_port() + + try: + self.open_connection = Connection(self.port, timeout=timeout, verbose=verbose) + self.send_config(config) + break + except TowerfallError: + if tries > 3: + raise TowerfallError('Could not config a Towerfall process.') + tries += 1 + + def join(self, timeout: float = 2) -> Connection: + ''' + Joins a towerfall game. + + params timeout: Timeout in seconds to wait for a response. The same timeout will be used on calls to get the observations. + + returns: A connection to a Towerfall game. This should be used by the agent to interact with the game. + ''' + connection = Connection(self.port, timeout=timeout, verbose=self.verbose) + connection.write_json(dict(type='join')) + response = connection.read_json() + if response['type'] != 'result': + raise TowerfallError(f'Unexpected response type: {response["type"]}') + if not response['success']: + raise TowerfallError(f'Failed to join the game. Port: {self.port}, Response: {response["message"]}') + self._try_log(logging.info, f'Successfully joined the game. Port: {self.port}') + return connection + + def send_reset(self, entities: Optional[List[Dict[str, Any]]] = None): + ''' + Sends a game reset. This will recreate the entities in the game in the same scenario. To change the scenario, use send_config. + + params entities: The entities to reset. If None, the entities specified in the last reset will be used. + ''' + + response = self.send_request_json(dict(type='reset', entities=entities)) + if response['type'] != 'result': + raise TowerfallError(f'Unexpected response type: {response["type"]}') + if not response['success']: + raise TowerfallError(f'Failed to reset the game. Port: {self.port}, Response: {response["message"]}') + self._try_log(logging.info, f'Successfully reset the game. Port: {self.port}') + + def send_config(self, config = None): + ''' + Sendns a game configuration. This will restart the session of the game in the specified scenario and specified number of agents. + + params config: The configuration to send. If None, the configuration specified in the last config will be used. + ''' + if config: + self.config = config + else: + config = self.config + + response = self.send_request_json(dict(type='config', config=config)) + if response['type'] != 'result': + raise TowerfallError(f'Unexpected response type: {response["type"]}') + if not response['success']: + raise TowerfallError(f'Failed to configure the game. Port: {self.port}, Response: {response["message"]}') + self.config = config + + def send_request_json(self, obj: Mapping[str, Any]): + self.open_connection.write_json(obj) + return self.open_connection.read_json() + + @classmethod + def close_all(cls): + ''' + Closes all Towerfall processes. + ''' + logging.info('Closing all TowerFall.exe processes...') + for process in psutil.process_iter(attrs=['pid', 'name']): + # logging.info(f'Checking process {process.pid} {process.name()}') + if process.name() != 'TowerFall.exe': + continue + try: + logging.info(f'Killing process {process.pid}...') + os.kill(process.pid, signal.SIGTERM) + except Exception as ex: + logging.error(f'Failed to kill process {process.pid}: {ex}') + continue + + def close(self): + ''' + Close the management connection. This will free the Towerfall process to be used by other clients. + ''' + self.open_connection.close() + + def _attain_game_port(self) -> int: + # with self._get_pool_mutex(): + metadata = self._find_compatible_metadata() + + if not metadata: + self._try_log(logging.info, f'Starting new process from {self.towerfall_path_exe}.') + pargs = [self.towerfall_path_exe, '--noconfig'] + if self.fastrun: + pargs.append('--fastrun') + if self.nographics: + pargs.append('--nographics') + + Popen(pargs, cwd=self.towerfall_path) + + tries = 0 + self._try_log(logging.info, f'Waiting for available process.') + while not metadata and tries < 10: + time.sleep(2) + metadata = self._find_compatible_metadata() + tries += 1 + if not metadata: + raise TowerfallError('Could not find or create a Towerfall process.') + + return metadata['port'] + + def _find_compatible_metadata(self) -> Optional[Mapping[str, Any]]: + if not os.path.exists(self.pool_path): + return None + for file_name in os.listdir(self.pool_path): + try: + pid = int(file_name) + psutil.Process(pid) + except (ValueError, psutil.NoSuchProcess): + os.remove(os.path.join(self.pool_path, file_name)) + continue + with open(os.path.join(self.pool_path, file_name), 'r') as file: + try: + metadata = Towerfall._load_metadata(file) + except (ValueError, json.JSONDecodeError, FileNotFoundError) as ex: + self._try_log(logging.warning, f'Invalid metadata file {file_name}. Exception: {ex}') + continue + if self._is_compatible_metadata(metadata): + return metadata + return None + + @staticmethod + def _load_metadata(file: TextIOWrapper) -> Mapping[str, Any]: + metadata = json.load(file) + if 'port' not in metadata: + raise ValueError('Port not found in metadata.') + try: + metadata['port'] = int(metadata['port']) + except ValueError: + raise ValueError(f'Port is not an integer. Port: {metadata["port"]}') + + if 'fastrun' not in metadata: + metadata['fastrun'] = False + if 'nographics' not in metadata: + metadata['nographics'] = False + return metadata + + def _is_compatible_metadata(self, metadata: Mapping[str, Any]) -> bool: + if metadata['fastrun'] != self.fastrun: + return False + if metadata['nographics'] != self.nographics: + return False + return True + + # def _get_pool_mutex(self) -> NamedMutex: + # return NamedMutex(f'Towerfall_pool_{self.pool_name}') + + def _try_log(self, log_fn: Callable[[str], None], message: str): + if self.verbose > 0: + log_fn(message) + diff --git a/train_kill_enemy_with_entity_env.py b/train_kill_enemy_with_entity_env.py index 2d80ad1..773647c 100644 --- a/train_kill_enemy_with_entity_env.py +++ b/train_kill_enemy_with_entity_env.py @@ -4,7 +4,7 @@ from common import logging_options from entity_envs.entity_env import TowerfallEntityEnvImpl -from envs.connection_provider import TowerfallProcessProvider +from towerfall import Towerfall logging_options.set_default() @@ -16,7 +16,7 @@ def main(state_manager: hyperstate.StateManager) -> None: train(state_manager=state_manager, env=TowerfallEntityEnvImpl) finally: logging.info('Closing all Towerfall processes') - TowerfallProcessProvider.close_all() + Towerfall.close_all() if __name__ == "__main__": diff --git a/trainer/trainer.py b/trainer/trainer.py index 5e03086..eecbe65 100644 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -1,13 +1,13 @@ import json import logging import os -from typing import Any, Callable, Optional, Tuple -from gym import Env +from typing import Any, Callable, Dict, Optional, Tuple +import yaml +from gym import Env from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.monitor import Monitor -import yaml from envs.blank_env import TowerfallBlankEnv @@ -19,7 +19,7 @@ class Trainer: def __init__(self, export_wandb: bool = True): self.export_wandb = export_wandb - def init_wandb(self, configs: dict[str, Any], project_name: str, trial_name: str): + def init_wandb(self, configs: Dict[str, Any], project_name: str, trial_name: str): from wandb.wandb_run import Run import wandb @@ -76,7 +76,7 @@ def load_from_trial(self, project_name: str, trial_name: str, model_name: Option def get_trial_path(self, project_name: str, trial_name: str): return f'tmp/{project_name}/{trial_name}' - def _train_model(self, model, env: TowerfallBlankEnv, total_steps: int, configs: dict[str, Any], project_name: str, trial_name: str) -> float: + def _train_model(self, model, env: TowerfallBlankEnv, total_steps: int, configs: Dict[str, Any], project_name: str, trial_name: str) -> float: trial_path = self.get_trial_path(project_name, trial_name) os.makedirs(trial_path, exist_ok=True) @@ -113,7 +113,7 @@ def _train_model(self, model, env: TowerfallBlankEnv, total_steps: int, configs: self.run.finish() return train_callback.best_mean_reward - def train(self, env: TowerfallBlankEnv, total_steps: int, configs: dict[str, Any], project_name: str, trial_name: str): + def train(self, env: TowerfallBlankEnv, total_steps: int, configs: Dict[str, Any], project_name: str, trial_name: str): trial_path = self.get_trial_path(project_name, trial_name) logging.info(f'Creating Monitor in {trial_path}') monitored_env = Monitor(env, os.path.join(trial_path, 'monitor')) @@ -127,14 +127,14 @@ def train(self, env: TowerfallBlankEnv, total_steps: int, configs: dict[str, Any return self._train_model(model, env, total_steps, configs, project_name, trial_name) def fork_training(self, - env: TowerfallBlankEnv, - total_steps: int, - configs: dict[str, Any], - project_name: str, - trial_name: str, - load_project_name: str, - load_trial_name: str, - load_model_name: str): + env: TowerfallBlankEnv, + total_steps: int, + configs: Dict[str, Any], + project_name: str, + trial_name: str, + load_project_name: str, + load_trial_name: str, + load_model_name: str): trial_path = self.get_trial_path(project_name, trial_name) logging.info(f'Creating Monitor in {trial_path}') monitored_env = Monitor(env, os.path.join(trial_path, 'monitor')) @@ -142,12 +142,12 @@ def fork_training(self, model, _ = self.load_from_trial(load_project_name, load_trial_name, load_model_name, monitored_env) self._train_model(model, env, total_steps, configs, project_name, trial_name) - def evaluate_model(self, env_fn: Callable[[dict[str, Any]], Env], n_episodes: int, project_name: str, trial_name: str, model_name: str): + def evaluate_model(self, env_fn: Callable[[Dict[str, Any]], Env], n_episodes: int, project_name: str, trial_name: str, model_name: str): model, configs = self.load_from_trial(project_name, trial_name, model_name) env = env_fn(configs) evaluate_policy(model, env=env, n_eval_episodes=n_episodes, render=False, deterministic=False) - def evaluate_all_models(self, env_fn: Callable[[dict[str, Any]], Env], n_episodes: int, project_name: str, trial_name: str): + def evaluate_all_models(self, env_fn: Callable[[Dict[str, Any]], Env], n_episodes: int, project_name: str, trial_name: str): trial_path = self.get_trial_path(project_name, trial_name) logging.info(f'Loading experiment from {trial_path}') with open(os.path.join(trial_path, 'hparams.json'), 'r') as file: