Skip to content

Commit

Permalink
New towerfall client
Browse files Browse the repository at this point in the history
  • Loading branch information
vcanaa committed May 19, 2023
1 parent dd8cd2b commit e4b2fcb
Show file tree
Hide file tree
Showing 37 changed files with 420 additions and 616 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
__pycache__
.ipynb_checkpoints/

.connection_provider/
.vscode/

config*.json
Expand Down
1 change: 1 addition & 0 deletions FollowCloseTargetCurriculum_episodes_20.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[[{"x": 160, "y": 110, "__Vec2__": true}, {"x": 140, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 144.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 148.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 152.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 156.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 160.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 164.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 168.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 172.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 176.0, "y": 130, "__Vec2__": true}], [{"x": 160, "y": 110, "__Vec2__": true}, {"x": 180.0, "y": 130, "__Vec2__": true}]]
6 changes: 3 additions & 3 deletions bots/botquest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .bot import Bot

from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

_HOST = "127.0.0.1"
_PORT = 12024
Expand Down Expand Up @@ -44,14 +44,14 @@ def update(self):
self.handle_update(game_state)


def handle_init(self, state: dict):
def handle_init(self, state: Dict[str, Any]):
logging.info("handle_init")
self.state_init = state
random.seed(state['index'])
self._connection.write('.')


def handle_scenario(self, state: dict):
def handle_scenario(self, state: Dict[str, Any]):
logging.info("handle_scenario")
self.state_scenario = state
self.gv.set_scenario(state)
Expand Down
4 changes: 2 additions & 2 deletions bots/botrecorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def update(self):
self.handle_update(game_state)


def handle_init(self, state: dict):
def handle_init(self, state: Dict[str, Any]):
logging.info("handle_init")
if hasattr(self, 'replay'):
dir_path = Path(os.path.join('replays', REPLAY_NAME))
Expand All @@ -88,7 +88,7 @@ def handle_init(self, state: dict):
self.connection.write('.')


def handle_scenario(self, state: dict):
def handle_scenario(self, state: Dict[str, Any]):
logging.info("handle_scenario")
self.replay.handle_scenario(state)
self.stateScenario = state
Expand Down
2 changes: 0 additions & 2 deletions common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from .common import *
from .connection import *
from .constants import *
from .controls import *
from .entity import *
from .gamereplay import *
from .grid import *
from .logging_options import *
from .namedmutex import *
from .pathing import *
17 changes: 6 additions & 11 deletions common/controls.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
import sys
import json
import logging
import numpy as np

from math import atan2, pi
from threading import Thread, Lock

from pyjoystick.sdl2 import Key, Joystick, run_event_loop

from common import reply, Vec2

from .connection import Connection

from threading import Lock, Thread
from typing import Optional

import numpy as np
from numpy.typing import NDArray
from pyjoystick.sdl2 import Joystick, Key, run_event_loop

from common import Vec2, reply
from towerfall import Connection

pi8 = pi / 8

Expand Down
6 changes: 3 additions & 3 deletions common/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class Entity:
def __init__(self, e: dict):
def __init__(self, e: Dict[str, Any]):
self.p: Vec2 = vec2_from_dict(e['pos'])
self.v: Vec2 = vec2_from_dict(e['vel'])
self.s: Vec2 = vec2_from_dict(e['size'])
Expand Down Expand Up @@ -108,15 +108,15 @@ def div(self, f: float):
self.y /= f


def vec2_from_dict(p: dict) -> Vec2:
def vec2_from_dict(p: Dict[str, Any]) -> Vec2:
try:
return Vec2(p['x'], p['y'])
except KeyError:
sys.stderr.write(str(p))
raise


def to_entities(entities: List[dict]) -> List[Entity]:
def to_entities(entities: List[Dict[str, Any]]) -> List[Entity]:
result = []
for e in entities:
result.append(Entity(e))
Expand Down
8 changes: 4 additions & 4 deletions common/gamereplay.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import json

from typing import List
from typing import Any, Dict, List

class GameReplay:
def __init__(self):
self.state_init: dict
self.state_scenario: dict
self.state_update: List[dict] = []
self.state_init: Dict[str, Any]
self.state_scenario: Dict[str, Any]
self.state_update: List[Dict[str, Any]] = []
self.actions: List[str] = []

def handle_init(self, state):
Expand Down
12 changes: 5 additions & 7 deletions common/grid.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import matplotlib.pyplot as plt

import numpy as np
from numpy.typing import NDArray

from .common import Entity, bounded, Vec2, grid_pos
from .constants import WIDTH, HEIGHT, HW, HH

from typing import List, Tuple, Union, Optional
from .common import Entity, Vec2, bounded, grid_pos
from .constants import HEIGHT, HH, HW, WIDTH


def plot_grid(grid: NDArray, name: str):
Expand Down Expand Up @@ -113,7 +111,7 @@ class GridView():
def __init__(self, grid_factor: int):
self.gf: int = grid_factor

def set_scenario(self, game_state: dict):
def set_scenario(self, game_state: Dict[str, Any]):
# logging.info(f'Setting scenario in GridView {game_state["grid"]}')
self.fixed_grid10 = np.array(game_state['grid'])
self.csize: int = int(game_state['cellSize'])
Expand Down
18 changes: 8 additions & 10 deletions create_move_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
"source": [
"import os\n",
"\n",
"from pathlib import Path\n",
"\n",
"from common import GameReplay, Entity, GridView, Controls, to_entities\n",
"\n",
"from typing import List\n",
"from typing import Dict, List\n",
"from numpy.typing import NDArray\n",
"\n",
"def get_player(entities: List[Entity], index) -> Entity:\n",
Expand All @@ -35,7 +33,7 @@
" return e\n",
" raise Exception('Player not present: ()'.format(index))\n",
"\n",
"def match_shape(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:\n",
"def match_shape(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool:\n",
" if len(data) == 0:\n",
" return True\n",
" \n",
Expand All @@ -48,7 +46,7 @@
" return False\n",
" return True\n",
" \n",
"def should_extend_input(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:\n",
"def should_extend_input(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool:\n",
" if not match_shape(data, entry):\n",
" return False\n",
" \n",
Expand All @@ -62,23 +60,23 @@
" # return False\n",
" return True\n",
"\n",
"def should_extend_output(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:\n",
"def should_extend_output(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool:\n",
" if not match_shape(data, entry):\n",
" return False\n",
" \n",
" if np.linalg.norm(entry['dpos']) > 50:\n",
" return False\n",
" return True\n",
"\n",
"def is_same_as_last(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:\n",
"def is_same_as_last(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool:\n",
" for k, v in entry.items():\n",
" if k not in data:\n",
" return False\n",
" if not np.array_equal(data[k][-1], v):\n",
" return False \n",
" return True\n",
" \n",
"def extend_data(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool: \n",
"def extend_data(data: Dict[str, List[NDArray]], entry: Dict[str, NDArray]) -> bool: \n",
" for k, v in entry.items():\n",
" if k not in data:\n",
" data[k] = [v]\n",
Expand All @@ -87,7 +85,7 @@
" \n",
" return True\n",
"\n",
"def process_replay(filepath, inputs: dict[str, List[NDArray]], outputs: dict[str, List[NDArray]]):\n",
"def process_replay(filepath, inputs: Dict[str, List[NDArray]], outputs: Dict[str, List[NDArray]]):\n",
" replay = GameReplay()\n",
" replay.load(filepath)\n",
" print('actions:', len(replay.actions))\n",
Expand Down Expand Up @@ -150,7 +148,7 @@
"\n",
"import shutil\n",
"\n",
"def save_data(name: str, type: str, data: dict[str, List[NDArray]]):\n",
"def save_data(name: str, type: str, data: Dict[str, List[NDArray]]):\n",
" dir_path = os.path.join('data', name, type)\n",
" if os.path.exists(dir_path):\n",
" shutil.rmtree(dir_path)\n",
Expand Down
30 changes: 13 additions & 17 deletions entity_envs/entity_base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,24 @@
from common.constants import DASH, DOWN, JUMP, LEFT, RIGHT, SHOOT, UP
from common.entity import Entity, to_entities

from envs.connection_provider import TowerfallProcess, TowerfallProcessProvider
from towerfall import Towerfall

class TowerfallEntityEnv(Environment):
def __init__(self,
towerfall: Optional[TowerfallProcess] = None,
towerfall: Optional[Towerfall] = None,
record_path: Optional[str] = None,
verbose: int = 0):
logging.info('Initializing TowerfallEntityEnv')
self.verbose = verbose
if towerfall:
self.connection = towerfall.join(timeout=5, verbose=self.verbose)
self.towerfall = towerfall
else:
self.connection, self.towerfall = TowerfallProcessProvider().join_new(
fastrun=True,
# nographics=True,
config=dict(
mode='sandbox',
level='3',
agents=[dict(type='remote', team='blue', archer='green')]),
timeout=5,
verbose=self.verbose)
self.towerfall = towerfall if towerfall else Towerfall(fastrun=True,
# nographics=True,
config=dict(
mode='sandbox',
level='3',
agents=[dict(type='remote', team='blue', archer='green')]),
timeout=5,
verbose=self.verbose)
self.connection = self.towerfall.join()
self.connection.record_path = record_path
self._draw_elems = []
self.is_init_sent = False
Expand All @@ -53,7 +49,7 @@ def _send_reset(self):
Returns:
True if hard reset, False if soft reset.
'''
self.towerfall.send_reset(verbose=self.verbose)
self.towerfall.send_reset()

@abstractmethod
def _post_reset(self) -> Observation:
Expand Down Expand Up @@ -174,7 +170,7 @@ def _actions_to_command(self, actions: Mapping[ActionName, Action]) -> str:
def act(self, actions: Mapping[ActionName, Action]) -> Observation:
command = self._actions_to_command(actions)

resp: dict[str, Any] = dict(
resp: Dict[str, Any] = dict(
type='commands',
command=command,
id=self.state_update['id']
Expand Down
14 changes: 6 additions & 8 deletions entity_envs/entity_env.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import logging
import random

from entity_gym.env import Observation, GlobalCategoricalActionMask

from typing import Any, Optional
from typing import Any, Dict, List, Optional

import numpy as np
from common.constants import HH, HW
from common.entity import Entity, Vec2
from entity_envs.entity_base_env import TowerfallEntityEnv

from envs.connection_provider import TowerfallProcess, TowerfallProcessProvider


class TowerfallEntityEnvImpl(TowerfallEntityEnv):
Expand All @@ -35,7 +33,7 @@ def _is_reset_valid(self) -> bool:

def _send_reset(self):
reset_entities = self._get_reset_entities()
self.towerfall.send_reset(reset_entities, verbose=self.verbose)
self.towerfall.send_reset(reset_entities)

def _post_reset(self) -> Observation:
assert self.me, 'No player found after reset'
Expand All @@ -55,9 +53,9 @@ def _post_observe(self) -> Observation:
arrows = list(e for e in self.entities if e['type'] == 'arrow')
return self._get_obs(targets, arrows)

def _get_reset_entities(self) -> Optional[list[dict]]:
def _get_reset_entities(self) -> Optional[List[Dict[str, Any]]]:
p = Vec2(160, 110)
entities: list[dict[str, Any]] = [dict(type='archer', pos=p.dict())]
entities: List[Dict[str, Any]] = [dict(type='archer', pos=p.dict())]
for i in range(self.enemy_count):
sign = random.randint(0, 1)*2 - 1
d = random.uniform(self.min_distance, self.max_distance) * sign
Expand All @@ -68,7 +66,7 @@ def _get_reset_entities(self) -> Optional[list[dict]]:
entities.append(enemy)
return entities

def _update_reward(self, enemies: list[Entity]):
def _update_reward(self, enemies: List[Entity]):
'''
Updates the reward and checks if the episode is done.
'''
Expand Down Expand Up @@ -102,7 +100,7 @@ def _update_reward(self, enemies: list[Entity]):
def limit(self, x: float, a: float, b: float) -> float:
return x+b-a if x < a else x-b+a if x > b else x

def _get_obs(self, enemies: list[Entity], arrows: list[Entity]) -> Observation:
def _get_obs(self, enemies: List[Entity], arrows: List[Entity]) -> Observation:
if not self.me:
return Observation(
done=self.done,
Expand Down
17 changes: 3 additions & 14 deletions entity_envs/predefined_envs.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,9 @@
from typing import Any, Optional
from typing import Optional

from common.grid import GridView
from entity_envs.entity_env import TowerfallEntityEnvImpl
from envs.connection_provider import TowerfallProcessProvider

def create_kill_enemy(configs: dict[str, Any], record_path: Optional[str]=None, verbose=0):
towerfall_provider = TowerfallProcessProvider('default')
towerfall = towerfall_provider.get_process(
fastrun=True,
config=dict(
mode='sandbox',
level='1',
fps=90,
agents=[dict(type='remote', team='blue', archer='green')]
))

def create_kill_enemy(record_path: Optional[str]=None, verbose=0):
return TowerfallEntityEnvImpl(
towerfall=towerfall,
record_path=record_path,
verbose=verbose)
1 change: 0 additions & 1 deletion envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@
from .curriculums import *
from .objectives import *
from .observations import *
from .connection_provider import *
from .predefined_envs import *
Loading

0 comments on commit e4b2fcb

Please sign in to comment.