-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
467 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ config*.json | |
|
||
data/ | ||
ignore/ | ||
checkpoints/ | ||
logs/ | ||
models/ | ||
optuna/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .entity_base_env import * | ||
from .entity_env import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
from abc import abstractmethod | ||
import logging | ||
from typing import Any, List, Mapping, Optional, Dict | ||
|
||
from entity_gym.env import GlobalCategoricalAction, Environment, Observation, Action, ActionName, ObsSpace, ActionSpace, GlobalCategoricalActionSpace | ||
from entity_gym.env import Entity as EntityGym | ||
|
||
from common.constants import DASH, DOWN, JUMP, LEFT, RIGHT, SHOOT, UP | ||
from common.entity import Entity, to_entities | ||
|
||
from envs.connection_provider import TowerfallProcess | ||
|
||
class TowerfallEntityEnv(Environment):
    '''
    Base entity-gym environment that talks to a Towerfall process over a JSON connection.

    Subclasses implement `_post_reset` and `_post_observe` to turn the latest game
    state (`self.entities`, `self.me`) into an `Observation`. They may also override
    `_send_reset` and `_is_reset_valid` to control the starting conditions.
    '''

    def __init__(self,
                 towerfall: TowerfallProcess,
                 record_path: Optional[str] = None,
                 verbose: int = 0):
        '''
        Joins the given Towerfall process and prepares the connection.

        Params:
            towerfall: The game process to join.
            record_path: Stored on the connection as `record_path`.
            verbose: Verbosity level forwarded to the process calls.
        '''
        logging.info('Initializing TowerfallEntityEnv')
        self.towerfall = towerfall
        self.verbose = verbose
        self.connection = self.towerfall.join(timeout=5, verbose=self.verbose)
        self.connection.record_path = record_path
        # Debug draw elements queued by `draws` and flushed on the next `act`.
        self._draw_elems = []
        # The 'init'/'scenario' handshake is only performed on the first reset.
        self.is_init_sent = False

        logging.info('Initialized TowerfallEnv')

    def _is_reset_valid(self) -> bool:
        '''
        Use this to check whether the initialization is valid.
        This is useful to collect information about the environment to programmatically
        construct a sequence of tasks, then reset the environment again with the proper
        reset instructions.
        Returns:
            True if the reset is valid. If False, the environment will be reset again.
        '''
        return True

    def _send_reset(self):
        '''
        Sends the reset instruction to the game. Override this to change the starting conditions.
        '''
        self.towerfall.send_reset(verbose=self.verbose)

    @abstractmethod
    def _post_reset(self) -> Observation:
        '''
        Hook called at the end of `reset`, once a valid initial state has been read.
        Returns:
            The initial observation of the episode.
        '''
        raise NotImplementedError

    @abstractmethod
    def _post_observe(self) -> Observation:
        '''
        Hook called by `observe` after `self.entities` and `self.me` have been refreshed.
        Returns:
            The observation for the current frame.
        '''
        raise NotImplementedError

    def draws(self, draw_elem):
        '''
        Queues an element to be drawn on the screen. This is useful for debugging.
        Queued elements are sent with the next `act` call and then discarded.
        '''
        self._draw_elems.append(draw_elem)

    def reset(self) -> Observation:
        '''
        Resets the game until `_is_reset_valid` accepts the resulting state.

        On the first reset the 'init' and 'scenario' messages are read and acknowledged.
        On later resets the pending update is answered with an empty command instead.
        Returns:
            The observation produced by `_post_reset`.
        '''
        while True:
            self._send_reset()
            if not self.is_init_sent:
                state_init = self.connection.read_json()
                assert state_init['type'] == 'init', state_init['type']
                self.index = state_init['index']
                self.connection.write_json(dict(type='result', success=True))

                self.state_scenario = self.connection.read_json()
                assert self.state_scenario['type'] == 'scenario', self.state_scenario['type']
                self.connection.write_json(dict(type='result', success=True))
                self.is_init_sent = True
            else:
                # Answer the last update that was read with an empty command.
                self.connection.write_json(dict(type='commands', command="", id=self.state_update['id']))

            self.frame = 0
            self.state_update = self.connection.read_json()
            assert self.state_update['type'] == 'update', self.state_update['type']
            self.entities = to_entities(self.state_update['entities'])
            self.me = self._get_own_archer(self.entities)
            if self._is_reset_valid():
                break

        return self._post_reset()

    @classmethod
    def obs_space(cls) -> ObsSpace:
        '''
        Returns the observation space: two global features plus 'enemies' and 'arrows' entities.
        '''
        return ObsSpace(
            global_features=["y", "arrow_count"],
            entities={
                "enemies": EntityGym(features=["x", "y", "facing"]),
                "arrows": EntityGym(features=["x", "y", "stuck"]),
            })

    @classmethod
    def action_space(cls) -> Dict[ActionName, ActionSpace]:
        '''
        Returns one global categorical action per input axis/button.
        Index 0 of every action means "no input"; see `_actions_to_command` for the mapping.
        '''
        return {
            "hor": GlobalCategoricalActionSpace(
                [".", "l", "r"],
            ),
            "ver": GlobalCategoricalActionSpace(
                [".", "d", "u"],
            ),
            "shoot": GlobalCategoricalActionSpace(
                [".", "s"],
            ),
            "dash": GlobalCategoricalActionSpace(
                [".", "d"],
            ),
            "jump": GlobalCategoricalActionSpace(
                [".", "j"],
            ),
        }

    def observe(self) -> Observation:
        '''
        Reads the next 'update' message, refreshes `self.entities` and `self.me`,
        and returns the observation built by `_post_observe`.
        '''
        self.state_update = self.connection.read_json()
        assert self.state_update['type'] == 'update'
        self.entities = to_entities(self.state_update['entities'])
        self.me = self._get_own_archer(self.entities)
        return self._post_observe()

    def _actions_to_command(self, actions: Mapping[ActionName, Action]) -> str:
        '''
        Translates the chosen categorical actions into the game's command string.

        Params:
            actions: Must contain 'hor', 'ver', 'jump', 'dash' and 'shoot', each a
                GlobalCategoricalAction whose index follows `action_space`.
        Returns:
            The concatenation of the constants for every active input, in the order
            horizontal, vertical, jump, dash, shoot.
        '''
        def chosen(name: str) -> int:
            action: Action = actions[name]
            assert isinstance(action, GlobalCategoricalAction)
            return action.index

        command = ''

        hor = chosen('hor')
        if hor == 1:
            command += LEFT
        elif hor == 2:
            command += RIGHT

        # Read the vertical action from 'ver'. It used to be read from 'hor', which made
        # UP/DOWN mirror the horizontal choice and ignored the policy's vertical output.
        ver = chosen('ver')
        if ver == 1:
            command += DOWN
        elif ver == 2:
            command += UP

        if chosen('jump') == 1:
            command += JUMP

        if chosen('dash') == 1:
            command += DASH

        if chosen('shoot') == 1:
            command += SHOOT
        return command

    def act(self, actions: Mapping[ActionName, Action]) -> Observation:
        '''
        Sends the command for the given actions (plus any queued draw elements)
        in reply to the last update, then returns the next observation.
        '''
        command = self._actions_to_command(actions)

        resp: dict[str, Any] = dict(
            type='commands',
            command=command,
            id=self.state_update['id']
        )
        if self._draw_elems:
            resp['draws'] = self._draw_elems
        self.connection.write_json(resp)
        # Draw elements are one-shot: drop them once they have been sent.
        self._draw_elems.clear()

        return self.observe()

    def _get_own_archer(self, entities: List[Entity]) -> Optional[Entity]:
        '''
        Returns the archer whose 'playerIndex' equals the index received in the
        'init' message, or None if no such archer is present.
        '''
        for e in entities:
            if e.type == 'archer' and e['playerIndex'] == self.index:
                return e
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
import logging | ||
import random | ||
|
||
from entity_gym.env import Observation, GlobalCategoricalActionMask | ||
|
||
from typing import Any, Optional | ||
|
||
import numpy as np | ||
from common.constants import HH, HW | ||
from common.entity import Entity, Vec2 | ||
from entity_envs.entity_base_env import TowerfallEntityEnv | ||
|
||
from envs.connection_provider import TowerfallProcess, TowerfallProcessProvider | ||
|
||
|
||
class TowerfallEntityEnvImpl(TowerfallEntityEnv):
    '''
    Concrete entity environment: the archer is spawned next to a few slimes and is
    rewarded for killing them and for picking up arrows.
    '''

    def __init__(self,
                 record_path: Optional[str] = None,
                 verbose: int = 0,
                 *,
                 towerfall: Optional[TowerfallProcess] = None):
        '''
        Params:
            record_path: Forwarded to the base environment.
            verbose: Verbosity level forwarded to the base environment.
            towerfall: Optional game process to use. When omitted, a sandbox process
                is requested from the 'entity-env-trainer' provider. Keyword-only so
                existing positional callers are unaffected.
        '''
        if towerfall is None:
            towerfall_provider = TowerfallProcessProvider('entity-env-trainer')
            towerfall = towerfall_provider.get_process(
                fastrun=True,
                reuse=False,
                config=dict(
                    mode='sandbox',
                    level='3',
                    agents=[dict(type='remote', team='blue', archer='green')]
                ))
        super().__init__(towerfall, record_path, verbose)
        self.enemy_count = 2
        # Range of the horizontal spawn offset of each enemy relative to the archer.
        self.min_distance = 50
        self.max_distance = 100
        # Maximum number of observed frames before the episode is ended.
        self.episode_max_len = 60*6
        # Every choice of every action is always allowed.
        self.action_mask = {
            'hor': GlobalCategoricalActionMask(np.array([[True, True, True]])),
            'ver': GlobalCategoricalActionMask(np.array([[True, True, True]])),
            'jump': GlobalCategoricalActionMask(np.array([[True, True]])),
            'dash': GlobalCategoricalActionMask(np.array([[True, True]])),
            'shoot': GlobalCategoricalActionMask(np.array([[True, True]])),
        }

    def _is_reset_valid(self) -> bool:
        '''
        Every reset is accepted.
        '''
        return True

    def _send_reset(self):
        '''
        Sends a reset that places the entities returned by `_get_reset_entities`.
        '''
        reset_entities = self._get_reset_entities()
        self.towerfall.send_reset(reset_entities, verbose=self.verbose)

    def _post_reset(self) -> Observation:
        '''
        Initializes the per-episode bookkeeping and returns the first observation
        (enemies only, no arrows).
        '''
        assert self.me, 'No player found after reset'
        targets = [e for e in self.entities if e['isEnemy']]
        self.prev_enemy_ids = set(t['id'] for t in targets)

        self.done = False
        self.episode_len = 0
        self.reward = 0
        self.prev_arrow_count = len(self.me['arrows'])
        return self._get_obs(targets, [])

    def _post_observe(self) -> Observation:
        '''
        Updates the reward/done state from the current entities and returns the observation.
        '''
        targets = [e for e in self.entities if e['isEnemy']]
        self._update_reward(targets)
        self.episode_len += 1
        arrows = [e for e in self.entities if e['type'] == 'arrow']
        return self._get_obs(targets, arrows)

    def _get_reset_entities(self) -> Optional[list[dict]]:
        '''
        Builds the reset layout: one archer at a fixed position and `enemy_count`
        slimes at a random horizontal offset on either side, facing the archer's side.
        '''
        p = Vec2(160, 110)
        entities: list[dict[str, Any]] = [dict(type='archer', pos=p.dict())]
        for _ in range(self.enemy_count):
            # sign is -1 or +1: which side of the archer the enemy spawns on.
            sign = random.randint(0, 1)*2 - 1
            d = random.uniform(self.min_distance, self.max_distance) * sign
            enemy = dict(
                type='slime',
                pos=(p + Vec2(d, -5)).dict(),
                facing=-sign)
            entities.append(enemy)
        return entities

    def _update_reward(self, enemies: list[Entity]):
        '''
        Updates the reward and checks if the episode is done.

        Reward for the frame:
            -1 if the archer is gone or the episode reached `episode_max_len`,
            +1 for every enemy id that disappeared since the previous frame,
            +0.2 per arrow gained, -0.1 per arrow lost (only while the archer exists).
        The episode is also done when no enemies remain.
        '''
        # Negative reward for getting killed or end of episode
        self.reward = 0
        if not self.me or self.episode_len >= self.episode_max_len:
            self.done = True
            self.reward -= 1

        # Positive reward for killing an enemy
        enemy_ids = set(t['id'] for t in enemies)
        self.reward += len(self.prev_enemy_ids - enemy_ids)

        if self.me:
            arrow_count = len(self.me['arrows'])
            delta_arrow = arrow_count - self.prev_arrow_count
            if delta_arrow > 0:
                self.reward += delta_arrow * 0.2
            else:
                self.reward += delta_arrow * 0.1
            self.prev_arrow_count = arrow_count

        if self.reward != 0:
            logging.info(f'Reward: {self.reward}')

        self.prev_enemy_ids = enemy_ids
        if not enemy_ids:
            self.done = True

    def limit(self, x: float, a: float, b: float) -> float:
        '''
        Wraps x once around the interval [a, b]: a value below a is shifted up by
        (b - a), a value above b is shifted down by (b - a), anything else is returned
        unchanged. Note that this is a single wrap step, not a clamp.
        '''
        return x+b-a if x < a else x-b+a if x > b else x

    def _get_obs(self, enemies: list[Entity], arrows: list[Entity]) -> Observation:
        '''
        Builds the observation for the current frame.

        Enemy and arrow positions are expressed relative to the archer, scaled by
        HW/HH and passed through `limit`. When the archer is gone, an observation
        with zeroed global features and no entities is returned.
        '''
        if not self.me:
            return Observation(
                done=self.done,
                reward=self.reward,
                actions=self.action_mask,
                global_features=np.array([0, 0], dtype=np.float32),
                entities={
                    'enemies': [],
                    'arrows': []
                }
            )

        enemy_states = [
            np.array(
                [
                    self.limit((enemy.p.x - self.me.p.x) / HW, -1, 1),
                    self.limit((enemy.p.y - self.me.p.y) / HH, -1, 1),
                    enemy['facing']
                ],
                dtype=np.float32
            )
            for enemy in enemies
        ]

        arrow_states = [
            np.array(
                [
                    self.limit((arrow.p.x - self.me.p.x) / HW, -1, 1),
                    self.limit((arrow.p.y - self.me.p.y) / HH, -1, 1),
                    1 if arrow['state'] == 'stuck' else 0
                ],
                dtype=np.float32
            )
            for arrow in arrows
        ]

        return Observation(
            done=self.done,
            reward=self.reward,
            actions=self.action_mask,
            global_features=np.array([
                # 110 is the archer's spawn y used in `_get_reset_entities`.
                (self.me.p.y-110) / HH,
                self.prev_arrow_count,
            ], dtype=np.float32),
            entities={
                'enemies': enemy_states,
                'arrows': arrow_states
            }
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from typing import Any, Optional | ||
|
||
from common.grid import GridView | ||
from entity_envs.entity_env import TowerfallEntityEnvImpl | ||
from envs.connection_provider import TowerfallProcessProvider | ||
|
||
def create_kill_enemy(configs: dict[str, Any], record_path: Optional[str]=None, verbose=0):
    '''
    Creates a TowerfallEntityEnvImpl backed by a sandbox process from the 'default' provider.

    Params:
        configs: Accepted for interface compatibility; not read by this function.
        record_path: Forwarded to the environment.
        verbose: Forwarded to the environment.
    Returns:
        The constructed environment.
    '''
    # NOTE(review): relies on TowerfallEntityEnvImpl accepting a `towerfall` keyword - confirm.
    remote_agent = dict(type='remote', team='blue', archer='green')
    game_config = dict(
        mode='sandbox',
        level='1',
        fps=90,
        agents=[remote_agent])
    process = TowerfallProcessProvider('default').get_process(
        fastrun=True,
        config=game_config)
    return TowerfallEntityEnvImpl(
        towerfall=process,
        record_path=record_path,
        verbose=verbose)
Oops, something went wrong.