Skip to content

Commit

Permalink
EntityGym
Browse files Browse the repository at this point in the history
  • Loading branch information
vcanaa committed May 4, 2023
1 parent 396cea8 commit 174657e
Show file tree
Hide file tree
Showing 12 changed files with 467 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ config*.json

data/
ignore/
checkpoints/
logs/
models/
optuna/
Expand Down
2 changes: 2 additions & 0 deletions entity_envs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .entity_base_env import *
from .entity_env import *
183 changes: 183 additions & 0 deletions entity_envs/entity_base_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from abc import abstractmethod
import logging
from typing import Any, List, Mapping, Optional, Dict

from entity_gym.env import GlobalCategoricalAction, Environment, Observation, Action, ActionName, ObsSpace, ActionSpace, GlobalCategoricalActionSpace
from entity_gym.env import Entity as EntityGym

from common.constants import DASH, DOWN, JUMP, LEFT, RIGHT, SHOOT, UP
from common.entity import Entity, to_entities

from envs.connection_provider import TowerfallProcess

class TowerfallEntityEnv(Environment):
def __init__(self,
towerfall: TowerfallProcess,
record_path: Optional[str] = None,
verbose: int = 0):
logging.info('Initializing TowerfallEntityEnv')
self.towerfall = towerfall
self.verbose = verbose
self.connection = self.towerfall.join(timeout=5, verbose=self.verbose)
self.connection.record_path = record_path
self._draw_elems = []
self.is_init_sent = False

logging.info('Initialized TowerfallEnv')

def _is_reset_valid(self) -> bool:
'''
Use this to make check if the initiallization is valid.
This is useful to collect information about the environment to programmatically construct a sequence of tasks, then
reset the environment again with the proper reseet instructions.
Returns:
True if the reset is valid, False otherwise, then the environment will be reset again.
'''
return True

def _send_reset(self):
'''
Sends the reset instruction to the game. Overwrite this to change the starting conditions.
Returns:
True if hard reset, False if soft reset.
'''
self.towerfall.send_reset(verbose=self.verbose)

@abstractmethod
def _post_reset(self) -> Observation:
'''
Hook for a gym reset call. Subclass should populate and return the same as a reset in gym API.
Returns:
A tuple of (observation, info)
'''
raise NotImplementedError

@abstractmethod
def _post_observe(self) -> Observation:
raise NotImplementedError

def draws(self, draw_elem):
'''
Draws an element on the screen. This is useful for debugging.
'''
self._draw_elems.append(draw_elem)

def reset(self) -> Observation:
while True:
self._send_reset()
if not self.is_init_sent:
state_init = self.connection.read_json()
assert state_init['type'] == 'init', state_init['type']
self.index = state_init['index']
self.connection.write_json(dict(type='result', success=True))

self.state_scenario = self.connection.read_json()
assert self.state_scenario['type'] == 'scenario', self.state_scenario['type']
self.connection.write_json(dict(type='result', success=True))
self.is_init_sent = True
else:
self.connection.write_json(dict(type='commands', command="", id=self.state_update['id']))

self.frame = 0
self.state_update = self.connection.read_json()
assert self.state_update['type'] == 'update', self.state_update['type']
self.entities = to_entities(self.state_update['entities'])
self.me = self._get_own_archer(self.entities)
if self._is_reset_valid():
break

return self._post_reset()

@classmethod
def obs_space(cls) -> ObsSpace:
return ObsSpace(
global_features=["y", "arrow_count"],
entities={
"enemies": EntityGym(features=["x", "y", "facing"]),
"arrows": EntityGym(features=["x", "y", "stuck"]),
})

@classmethod
def action_space(cls) -> Dict[ActionName, ActionSpace]:
return {
"hor": GlobalCategoricalActionSpace(
[".", "l", "r"],
),
"ver": GlobalCategoricalActionSpace(
[".", "d", "u"],
),
"shoot": GlobalCategoricalActionSpace(
[".", "s"],
),
"dash": GlobalCategoricalActionSpace(
[".", "d"],
),
"jump": GlobalCategoricalActionSpace(
[".", "j"],
),
}

def observe(self) -> Observation:
self.state_update = self.connection.read_json()
assert self.state_update['type'] == 'update'
self.entities = to_entities(self.state_update['entities'])
self.me = self._get_own_archer(self.entities)
obs = self._post_observe()
# logging.info(f'Observation: {obs.__dict__}')
return obs

def _actions_to_command(self, actions: Mapping[ActionName, Action]) -> str:
command = ''
hor: Action = actions['hor']
assert isinstance(hor, GlobalCategoricalAction)
if hor.index == 1:
command += LEFT
elif hor.index == 2:
command += RIGHT

ver: Action = actions['hor']
assert isinstance(ver, GlobalCategoricalAction)
if ver.index == 1:
command += DOWN
elif ver.index == 2:
command += UP

jump: Action = actions['jump']
assert isinstance(jump, GlobalCategoricalAction)
if jump.index == 1:
command += JUMP

dash: Action = actions['dash']
assert isinstance(dash, GlobalCategoricalAction)
if dash.index == 1:
command += DASH

shoot: Action = actions['shoot']
assert isinstance(shoot, GlobalCategoricalAction)
if shoot.index == 1:
command += SHOOT
return command

def act(self, actions: Mapping[ActionName, Action]) -> Observation:
command = self._actions_to_command(actions)

resp: dict[str, Any] = dict(
type='commands',
command=command,
id=self.state_update['id']
)
if self._draw_elems:
resp['draws'] = self._draw_elems
self.connection.write_json(resp)
self._draw_elems.clear()

return self.observe()

def _get_own_archer(self, entities: List[Entity]) -> Optional[Entity]:
for e in entities:
if e.type == 'archer':
if e['playerIndex'] == self.index:
return e
return None
162 changes: 162 additions & 0 deletions entity_envs/entity_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import logging
import random

from entity_gym.env import Observation, GlobalCategoricalActionMask

from typing import Any, Optional

import numpy as np
from common.constants import HH, HW
from common.entity import Entity, Vec2
from entity_envs.entity_base_env import TowerfallEntityEnv

from envs.connection_provider import TowerfallProcess, TowerfallProcessProvider


class TowerfallEntityEnvImpl(TowerfallEntityEnv):
def __init__(self,
record_path: Optional[str]=None,
verbose: int = 0):
towerfall_provider = TowerfallProcessProvider('entity-env-trainer')
towerfall = towerfall_provider.get_process(
fastrun=True,
reuse=False,
config=dict(
mode='sandbox',
level='3',
agents=[dict(type='remote', team='blue', archer='green')]
))
super().__init__(towerfall, record_path, verbose)
self.enemy_count = 2
self.min_distance = 50
self.max_distance = 100
self.episode_max_len = 60*6
self.action_mask = {
'hor': GlobalCategoricalActionMask(np.array([[True, True, True]])),
'ver': GlobalCategoricalActionMask(np.array([[True, True, True]])),
'jump': GlobalCategoricalActionMask(np.array([[True, True]])),
'dash': GlobalCategoricalActionMask(np.array([[True, True]])),
'shoot': GlobalCategoricalActionMask(np.array([[True, True]])),
}

def _is_reset_valid(self) -> bool:
return True

def _send_reset(self):
reset_entities = self._get_reset_entities()
self.towerfall.send_reset(reset_entities, verbose=self.verbose)

def _post_reset(self) -> Observation:
assert self.me, 'No player found after reset'
targets = list(e for e in self.entities if e['isEnemy'])
self.prev_enemy_ids = set(t['id'] for t in targets)

self.done = False
self.episode_len = 0
self.reward = 0
self.prev_arrow_count = len(self.me['arrows'])
return self._get_obs(targets, [])

def _post_observe(self) -> Observation:
targets = list(e for e in self.entities if e['isEnemy'])
self._update_reward(targets)
self.episode_len += 1
arrows = list(e for e in self.entities if e['type'] == 'arrow')
return self._get_obs(targets, arrows)

def _get_reset_entities(self) -> Optional[list[dict]]:
p = Vec2(160, 110)
entities: list[dict[str, Any]] = [dict(type='archer', pos=p.dict())]
for i in range(self.enemy_count):
sign = random.randint(0, 1)*2 - 1
d = random.uniform(self.min_distance, self.max_distance) * sign
enemy = dict(
type='slime',
pos=(p + Vec2(d, -5)).dict(),
facing=-sign)
entities.append(enemy)
return entities

def _update_reward(self, enemies: list[Entity]):
'''
Updates the reward and checks if the episode is done.
'''
# Negative reward for getting killed or end of episode
self.reward = 0
if not self.me or self.episode_len >= self.episode_max_len:
self.done = True
self.reward -= 1

# Positive reward for killing an enemy
enemy_ids = set(t['id'] for t in enemies)
for id in self.prev_enemy_ids - enemy_ids:
self.reward += 1

if self.me:
arrow_count = len(self.me['arrows'])
delta_arrow = arrow_count - self.prev_arrow_count
if delta_arrow > 0:
self.reward += delta_arrow * 0.2
else:
self.reward += delta_arrow * 0.1
self.prev_arrow_count = arrow_count


if self.reward != 0:
logging.info(f'Reward: {self.reward}')

self.prev_enemy_ids = enemy_ids
if len(self.prev_enemy_ids) == 0:
self.done = True

def limit(self, x: float, a: float, b: float) -> float:
return x+b-a if x < a else x-b+a if x > b else x

def _get_obs(self, enemies: list[Entity], arrows: list[Entity]) -> Observation:
if not self.me:
return Observation(
done=self.done,
reward=self.reward,
actions=self.action_mask,
global_features=np.array([0, 0], dtype=np.float32),
entities={
'enemies': [],
'arrows': []
}
)

enemie_states = []
for enemy in enemies:
enemie_states.append(np.array(
[
self.limit((enemy.p.x - self.me.p.x) / HW, -1, 1),
self.limit((enemy.p.y - self.me.p.y) / HH, -1, 1),
enemy['facing']
],
dtype=np.float32
))

arrow_states = []
for arrow in arrows:
arrow_states.append(np.array(
[
self.limit((arrow.p.x - self.me.p.x) / HW, -1, 1),
self.limit((arrow.p.y - self.me.p.y) / HH, -1, 1),
1 if arrow['state'] == 'stuck' else 0
],
dtype=np.float32
))

return Observation(
done=self.done,
reward=self.reward,
actions=self.action_mask,
global_features=np.array([
(self.me.p.y-110) / HH,
self.prev_arrow_count,
], dtype=np.float32),
entities={
'enemies': enemie_states,
'arrows': arrow_states
}
)
20 changes: 20 additions & 0 deletions entity_envs/predefined_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any, Optional

from common.grid import GridView
from entity_envs.entity_env import TowerfallEntityEnvImpl
from envs.connection_provider import TowerfallProcessProvider

def create_kill_enemy(configs: dict[str, Any], record_path: Optional[str]=None, verbose=0):
towerfall_provider = TowerfallProcessProvider('default')
towerfall = towerfall_provider.get_process(
fastrun=True,
config=dict(
mode='sandbox',
level='1',
fps=90,
agents=[dict(type='remote', team='blue', archer='green')]
))
return TowerfallEntityEnvImpl(
towerfall=towerfall,
record_path=record_path,
verbose=verbose)
Loading

0 comments on commit 174657e

Please sign in to comment.