From 174657e8b1125d0a5087440a65381b7bcad13bc1 Mon Sep 17 00:00:00 2001
From: vcanaa
Date: Thu, 4 May 2023 14:20:46 -0700
Subject: [PATCH] EntityGym

---
 .gitignore                          |   1 +
 entity_envs/__init__.py             |   2 +
 entity_envs/entity_base_env.py      | 183 ++++++++++++++++++++++++++++
 entity_envs/entity_env.py           | 162 ++++++++++++++++++++
 entity_envs/predefined_envs.py      |  20 +++
 envs/connection_provider.py         |  31 +++--
 envs/entity_env.py                  |   0
 envs/kill_enemy_objective.py        |   4 +-
 tests/test_entity_env.py            |  50 ++++++++
 train_kill_enemy.py                 |   8 +-
 train_kill_enemy_with_entity_env.py |  22 ++++
 trainer/trainer.py                  |   4 +-
 12 files changed, 467 insertions(+), 20 deletions(-)
 create mode 100644 entity_envs/__init__.py
 create mode 100644 entity_envs/entity_base_env.py
 create mode 100644 entity_envs/entity_env.py
 create mode 100644 entity_envs/predefined_envs.py
 delete mode 100644 envs/entity_env.py
 create mode 100644 tests/test_entity_env.py
 create mode 100644 train_kill_enemy_with_entity_env.py

diff --git a/.gitignore b/.gitignore
index 16c4dd4..64825e8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@
 config*.json
 data/
 ignore/
+checkpoints/
 logs/
 models/
 optuna/
diff --git a/entity_envs/__init__.py b/entity_envs/__init__.py
new file mode 100644
index 0000000..df06936
--- /dev/null
+++ b/entity_envs/__init__.py
@@ -0,0 +1,2 @@
+from .entity_base_env import *
+from .entity_env import *
\ No newline at end of file
diff --git a/entity_envs/entity_base_env.py b/entity_envs/entity_base_env.py
new file mode 100644
index 0000000..c2458c4
--- /dev/null
+++ b/entity_envs/entity_base_env.py
@@ -0,0 +1,183 @@
+from abc import abstractmethod
+import logging
+from typing import Any, List, Mapping, Optional, Dict
+
+from entity_gym.env import GlobalCategoricalAction, Environment, Observation, Action, ActionName, ObsSpace, ActionSpace, GlobalCategoricalActionSpace
+from entity_gym.env import Entity as EntityGym
+
+from common.constants import DASH, DOWN, JUMP, LEFT, RIGHT, SHOOT, UP
+from common.entity import Entity, to_entities
+
+from envs.connection_provider import TowerfallProcess
+
+class TowerfallEntityEnv(Environment):
+    def __init__(self,
+                 towerfall: TowerfallProcess,
+                 record_path: Optional[str] = None,
+                 verbose: int = 0):
+        logging.info('Initializing TowerfallEntityEnv')
+        self.towerfall = towerfall
+        self.verbose = verbose
+        self.connection = self.towerfall.join(timeout=5, verbose=self.verbose)
+        self.connection.record_path = record_path
+        self._draw_elems = []
+        self.is_init_sent = False
+
+        logging.info('Initialized TowerfallEntityEnv')
+
+    def _is_reset_valid(self) -> bool:
+        '''
+        Use this to check whether the initialization is valid.
+        This is useful to collect information about the environment, programmatically construct a sequence of tasks,
+        and then reset the environment again with the proper reset instructions.
+
+        Returns:
+            True if the reset is valid, False otherwise, in which case the environment will be reset again.
+        '''
+        return True
+
+    def _send_reset(self):
+        '''
+        Sends the reset instruction to the game. Override this to change the starting conditions.
+        '''
+        self.towerfall.send_reset(verbose=self.verbose)
+
+    @abstractmethod
+    def _post_reset(self) -> Observation:
+        '''
+        Hook for a gym-style reset call. Subclasses should populate and return the initial observation.
+
+        Returns:
+            The Observation for the first step of the episode.
+        '''
+        raise NotImplementedError
+
+    @abstractmethod
+    def _post_observe(self) -> Observation:
+        raise NotImplementedError
+
+    def draws(self, draw_elem):
+        '''
+        Draws an element on the screen. This is useful for debugging.
+        '''
+        self._draw_elems.append(draw_elem)
+
+    def reset(self) -> Observation:
+        while True:
+            self._send_reset()
+            if not self.is_init_sent:
+                state_init = self.connection.read_json()
+                assert state_init['type'] == 'init', state_init['type']
+                self.index = state_init['index']
+                self.connection.write_json(dict(type='result', success=True))
+
+                self.state_scenario = self.connection.read_json()
+                assert self.state_scenario['type'] == 'scenario', self.state_scenario['type']
+                self.connection.write_json(dict(type='result', success=True))
+                self.is_init_sent = True
+            else:
+                self.connection.write_json(dict(type='commands', command="", id=self.state_update['id']))
+
+            self.frame = 0
+            self.state_update = self.connection.read_json()
+            assert self.state_update['type'] == 'update', self.state_update['type']
+            self.entities = to_entities(self.state_update['entities'])
+            self.me = self._get_own_archer(self.entities)
+            if self._is_reset_valid():
+                break
+
+        return self._post_reset()
+
+    @classmethod
+    def obs_space(cls) -> ObsSpace:
+        return ObsSpace(
+            global_features=["y", "arrow_count"],
+            entities={
+                "enemies": EntityGym(features=["x", "y", "facing"]),
+                "arrows": EntityGym(features=["x", "y", "stuck"]),
+            })
+
+    @classmethod
+    def action_space(cls) -> Dict[ActionName, ActionSpace]:
+        return {
+            "hor": GlobalCategoricalActionSpace(
+                [".", "l", "r"],
+            ),
+            "ver": GlobalCategoricalActionSpace(
+                [".", "d", "u"],
+            ),
+            "shoot": GlobalCategoricalActionSpace(
+                [".", "s"],
+            ),
+            "dash": GlobalCategoricalActionSpace(
+                [".", "d"],
+            ),
+            "jump": GlobalCategoricalActionSpace(
+                [".", "j"],
+            ),
+        }
+
+    def observe(self) -> Observation:
+        self.state_update = self.connection.read_json()
+        assert self.state_update['type'] == 'update'
+        self.entities = to_entities(self.state_update['entities'])
+        self.me = self._get_own_archer(self.entities)
+        obs = self._post_observe()
+        # logging.info(f'Observation: {obs.__dict__}')
+        return obs
+
+    def _actions_to_command(self, actions: Mapping[ActionName, Action]) -> str:
+        command = ''
+        hor: Action = actions['hor']
+        assert isinstance(hor, GlobalCategoricalAction)
+        if hor.index == 1:
+            command += LEFT
+        elif hor.index == 2:
+            command += RIGHT
+
+        ver: Action = actions['ver']
+        assert isinstance(ver, GlobalCategoricalAction)
+        if ver.index == 1:
+            command += DOWN
+        elif ver.index == 2:
+            command += UP
+
+        jump: Action = actions['jump']
+        assert isinstance(jump, GlobalCategoricalAction)
+        if jump.index == 1:
+            command += JUMP
+
+        dash: Action = actions['dash']
+        assert isinstance(dash, GlobalCategoricalAction)
+        if dash.index == 1:
+            command += DASH
+
+        shoot: Action = actions['shoot']
+        assert isinstance(shoot, GlobalCategoricalAction)
+        if shoot.index == 1:
+            command += SHOOT
+        return command
+
+    def act(self, actions: Mapping[ActionName, Action]) -> Observation:
+        command = self._actions_to_command(actions)
+
+        resp: dict[str, Any] = dict(
+            type='commands',
+            command=command,
+            id=self.state_update['id']
+        )
+        if self._draw_elems:
+            resp['draws'] = self._draw_elems
+        self.connection.write_json(resp)
+        self._draw_elems.clear()
+
+        return self.observe()
+
+    def _get_own_archer(self, entities: List[Entity]) -> Optional[Entity]:
+        for e in entities:
+            if e.type == 'archer':
+                if e['playerIndex'] == self.index:
+                    return e
+        return None
\ No newline at end of file
diff --git a/entity_envs/entity_env.py b/entity_envs/entity_env.py
new file mode 100644
index 0000000..b548dd6
--- /dev/null
+++ b/entity_envs/entity_env.py
@@ -0,0 +1,162 @@
+import logging
+import random
+
+from entity_gym.env import Observation, GlobalCategoricalActionMask
+
+from typing import Any, Optional
+
+import numpy as np
+from common.constants import HH, HW
+from common.entity import Entity, Vec2
+from entity_envs.entity_base_env import TowerfallEntityEnv
+
+from envs.connection_provider import TowerfallProcess, TowerfallProcessProvider
+
+
+class TowerfallEntityEnvImpl(TowerfallEntityEnv):
+    def __init__(self,
+                 record_path: Optional[str] = None,
+                 verbose: int = 0):
+        towerfall_provider = TowerfallProcessProvider('entity-env-trainer')
+        towerfall = towerfall_provider.get_process(
+            fastrun=True,
+            reuse=False,
+            config=dict(
+                mode='sandbox',
+                level='3',
+                agents=[dict(type='remote', team='blue', archer='green')]
+            ))
+        super().__init__(towerfall, record_path, verbose)
+        self.enemy_count = 2
+        self.min_distance = 50
+        self.max_distance = 100
+        self.episode_max_len = 60*6
+        self.action_mask = {
+            'hor': GlobalCategoricalActionMask(np.array([[True, True, True]])),
+            'ver': GlobalCategoricalActionMask(np.array([[True, True, True]])),
+            'jump': GlobalCategoricalActionMask(np.array([[True, True]])),
+            'dash': GlobalCategoricalActionMask(np.array([[True, True]])),
+            'shoot': GlobalCategoricalActionMask(np.array([[True, True]])),
+        }
+
+    def _is_reset_valid(self) -> bool:
+        return True
+
+    def _send_reset(self):
+        reset_entities = self._get_reset_entities()
+        self.towerfall.send_reset(reset_entities, verbose=self.verbose)
+
+    def _post_reset(self) -> Observation:
+        assert self.me, 'No player found after reset'
+        targets = list(e for e in self.entities if e['isEnemy'])
+        self.prev_enemy_ids = set(t['id'] for t in targets)
+
+        self.done = False
+        self.episode_len = 0
+        self.reward = 0
+        self.prev_arrow_count = len(self.me['arrows'])
+        return self._get_obs(targets, [])
+
+    def _post_observe(self) -> Observation:
+        targets = list(e for e in self.entities if e['isEnemy'])
+        self._update_reward(targets)
+        self.episode_len += 1
+        arrows = list(e for e in self.entities if e['type'] == 'arrow')
+        return self._get_obs(targets, arrows)
+
+    def _get_reset_entities(self) -> Optional[list[dict]]:
+        p = Vec2(160, 110)
+        entities: list[dict[str, Any]] = [dict(type='archer', pos=p.dict())]
+        for i in range(self.enemy_count):
+            sign = random.randint(0, 1)*2 - 1
+            d = random.uniform(self.min_distance, self.max_distance) * sign
+            enemy = dict(
+                type='slime',
+                pos=(p + Vec2(d, -5)).dict(),
+                facing=-sign)
+            entities.append(enemy)
+        return entities
+
+    def _update_reward(self, enemies: list[Entity]):
+        '''
+        Updates the reward and checks whether the episode is done.
+        '''
+        # Negative reward for getting killed or reaching the end of the episode
+        self.reward = 0
+        if not self.me or self.episode_len >= self.episode_max_len:
+            self.done = True
+            self.reward -= 1
+
+        # Positive reward for killing an enemy
+        enemy_ids = set(t['id'] for t in enemies)
+        for id in self.prev_enemy_ids - enemy_ids:
+            self.reward += 1
+
+        if self.me:
+            arrow_count = len(self.me['arrows'])
+            delta_arrow = arrow_count - self.prev_arrow_count
+            if delta_arrow > 0:
+                self.reward += delta_arrow * 0.2
+            else:
+                self.reward += delta_arrow * 0.1
+            self.prev_arrow_count = arrow_count
+
+        if self.reward != 0:
+            logging.info(f'Reward: {self.reward}')
+
+        self.prev_enemy_ids = enemy_ids
+        if len(self.prev_enemy_ids) == 0:
+            self.done = True
+
+    def limit(self, x: float, a: float, b: float) -> float:
+        return x+b-a if x < a else x-b+a if x > b else x
+
+    def _get_obs(self, enemies: list[Entity], arrows: list[Entity]) -> Observation:
+        if not self.me:
+            return Observation(
+                done=self.done,
+                reward=self.reward,
+                actions=self.action_mask,
+                global_features=np.array([0, 0], dtype=np.float32),
+                entities={
+                    'enemies': [],
+                    'arrows': []
+                }
+            )
+
+        enemy_states = []
+        for enemy in enemies:
+            enemy_states.append(np.array(
+                [
+                    self.limit((enemy.p.x - self.me.p.x) / HW, -1, 1),
+                    self.limit((enemy.p.y - self.me.p.y) / HH, -1, 1),
+                    enemy['facing']
+                ],
+                dtype=np.float32
+            ))
+
+        arrow_states = []
+        for arrow in arrows:
+            arrow_states.append(np.array(
+                [
+                    self.limit((arrow.p.x - self.me.p.x) / HW, -1, 1),
+                    self.limit((arrow.p.y - self.me.p.y) / HH, -1, 1),
+                    1 if arrow['state'] == 'stuck' else 0
+                ],
+                dtype=np.float32
+            ))
+
+        return Observation(
+            done=self.done,
+            reward=self.reward,
+            actions=self.action_mask,
+            global_features=np.array([
+                (self.me.p.y-110) / HH,
+                self.prev_arrow_count,
+            ], dtype=np.float32),
+            entities={
+                'enemies': enemy_states,
+                'arrows': arrow_states
+            }
+        )
diff --git a/entity_envs/predefined_envs.py b/entity_envs/predefined_envs.py
new file mode 100644
index 0000000..5d012de
--- /dev/null
+++ b/entity_envs/predefined_envs.py
@@ -0,0 +1,20 @@
+from typing import Any, Optional
+
+from common.grid import GridView
+from entity_envs.entity_env import TowerfallEntityEnvImpl
+from envs.connection_provider import TowerfallProcessProvider
+
+def create_kill_enemy(configs: dict[str, Any], record_path: Optional[str]=None, verbose=0):
+    towerfall_provider = TowerfallProcessProvider('default')
+    towerfall = towerfall_provider.get_process(
+        fastrun=True,
+        config=dict(
+            mode='sandbox',
+            level='1',
+            fps=90,
+            agents=[dict(type='remote', team='blue', archer='green')]
+        ))
+    # Note: TowerfallEntityEnvImpl spawns its own TowerFall process in __init__ and takes no towerfall argument.
+    return TowerfallEntityEnvImpl(
+        record_path=record_path,
+        verbose=verbose)
\ No newline at end of file
diff --git a/envs/connection_provider.py b/envs/connection_provider.py
index 1793407..8294980 100644
--- a/envs/connection_provider.py
+++ b/envs/connection_provider.py
@@ -117,22 +117,24 @@ def __init__(self, name: str):
             ]
         )
 
-    def get_process(self, fastrun: bool = False, nographics: bool = False, config = None, verbose=0) -> TowerfallProcess:
+    def get_process(self, fastrun: bool = False, nographics: bool = False, config = None, verbose=0, reuse: bool = True) -> TowerfallProcess:
         if not config:
             config = self.default_config
         selected_process = None
         while not selected_process:
-            # Try to find an existing process that is not in use
-            def is_suitable_process(process: TowerfallProcess):
-                if process.fastrun != fastrun:
-                    return False
-                if process.nographics != nographics:
-                    return False
-                if process.pid in self._processes_in_use:
-                    return False
-                return True
-            selected_process = next((p for p in self.processes if is_suitable_process(p)), None)
+            selected_process = None
+            if reuse:
+                # Try to find an existing process that is not in use
+                def is_suitable_process(process: TowerfallProcess):
+                    if process.fastrun != fastrun:
+                        return False
+                    if process.nographics != nographics:
+                        return False
+                    if process.pid in self._processes_in_use:
+                        return False
+                    return True
+                selected_process = next((p for p in self.processes if is_suitable_process(p)), None)
 
             # If no process was found, start a new one
             if not selected_process:
@@ -161,8 +163,13 @@ def release_process(self, process: TowerfallProcess):
         self._processes_in_use.remove(process.pid)
 
     def close(self):
+        logging.info('Closing all processes...')
         for process in self.processes:
-            os.kill(process.pid, signal.SIGTERM)
+            try:
+                os.kill(process.pid, signal.SIGTERM)
+            except Exception as ex:
+                logging.error(f'Failed to kill process {process.pid}: {ex}')
+                continue
 
     def _get_port(self, pid: int) -> int:
         port_path = os.path.join(self.towerfall_path, 'ports', str(pid))
diff --git a/envs/entity_env.py b/envs/entity_env.py
deleted file mode 100644
index e69de29..0000000
diff --git a/envs/kill_enemy_objective.py b/envs/kill_enemy_objective.py
index ce06aa5..e0278c7 100644
--- a/envs/kill_enemy_objective.py
+++ b/envs/kill_enemy_objective.py
@@ -75,7 +75,7 @@ def _update_reward(self, player: Optional[Entity], targets: list[Entity]):
         self.rew = 0
         if not player or self.episode_len >= self.episode_max_len:
             self.done = True
-            self.rew -= self.bounty
+            self.rew -= self.bounty / 5
 
         if len(targets) < self.n_targets_prev:
             self.rew = self.bounty * (self.n_targets_prev - len(targets))
@@ -99,7 +99,7 @@ def _update_obs(self, player: Optional[Entity], targets: list[Entity], obs_dict:
         target_by_dist.sort(key=lambda x: x[0])
 
         for i, (_, target) in enumerate(target_by_dist):
-            obs_target[i*3 + 1] = 1
+            obs_target[i*3] = 1
             obs_target[i*3 + 1] = self.limit(target.p.x - player.p.x / HW, -1, 1)
             obs_target[i*3 + 2] = self.limit(target.p.y - player.p.y / HH, -1, 1)
         obs_dict['targets'] = obs_target
diff --git a/tests/test_entity_env.py b/tests/test_entity_env.py
new file mode 100644
index 0000000..65778d5
--- /dev/null
+++ b/tests/test_entity_env.py
@@ -0,0 +1,50 @@
+import sys
+sys.path.insert(0, 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall/aimod')
+
+import random
+from entity_gym.env import GlobalCategoricalAction, GlobalCategoricalActionSpace
+from entity_envs.entity_env import TowerfallEntityEnvImpl
+from envs.connection_provider import TowerfallProcessProvider
+
+
+import logging
+
+from typing import Any
+
+class NoLevelFormatter(logging.Formatter):
+    def format(self, record):
+        return record.getMessage()
+
+logging.basicConfig(level=logging.INFO)
+logging.getLogger().handlers[0].setFormatter(NoLevelFormatter())
+
+def create_env() -> TowerfallEntityEnvImpl:
+    towerfall_provider = TowerfallProcessProvider('test-entity-env')
+    towerfall = towerfall_provider.get_process(
+        fastrun=True,
+        config=dict(
+            mode='sandbox',
+            level='2',
+            fps=90,
+            agents=[dict(type='remote', team='blue', archer='green')]
+        ), verbose=1)
+    env = TowerfallEntityEnvImpl(
+        verbose=0)
+    return env
+
+
+env = create_env()
+
+logging.info('Evaluating')
+n_episodes = 500
+env.reset()
+for ep in range(n_episodes):
+    actions = {}
+    for action_name, action in env.action_space().items():
+        assert isinstance(action, GlobalCategoricalActionSpace)
+        idx = random.randint(0, len(action.index_to_label) - 1)
+        actions[action_name] = GlobalCategoricalAction(idx, action.index_to_label[idx])
+    obs = env.act(actions)
+    if obs.done:
+        env.reset()
+
diff --git a/train_kill_enemy.py b/train_kill_enemy.py
index 1080443..a62ef69 100644
--- a/train_kill_enemy.py
+++ b/train_kill_enemy.py
@@ -17,10 +17,10 @@ def get_configs():
             policy = 'MultiInputPolicy',
             n_steps = 1024,
             batch_size = 64,
-            learning_rate= 1e-4,
+            learning_rate= 2*1e-5,
             policy_kwargs= dict(
                 # net_arch = [64, 64]
-                net_arch = [256] * 2
+                net_arch = [128] * 2
             ),
         ),
         grid_params=dict(
@@ -33,12 +33,12 @@ def get_configs():
             min_distance=50,
             max_distance=120,
             bounty=5,
-            episode_max_len=60*4
+            episode_max_len=60*6
         ),
 
         learn_params = dict(),
         actions_params = dict(
             can_shoot = False,
-            can_dash = False,
+            can_dash = True,
         )
     )
diff --git a/train_kill_enemy_with_entity_env.py b/train_kill_enemy_with_entity_env.py
new file mode 100644
index 0000000..0b4af73
--- /dev/null
+++ b/train_kill_enemy_with_entity_env.py
@@ -0,0 +1,22 @@
+from enn_trainer import TrainConfig, State, init_train_state, train
+import hyperstate
+from common import logging_options
+
+from entity_envs.entity_env import TowerfallEntityEnvImpl
+from envs.connection_provider import TowerfallProcessProvider
+
+
+logging_options.set_default()
+
+
+@hyperstate.stateful_command(TrainConfig, State, init_train_state)
+def main(state_manager: hyperstate.StateManager) -> None:
+    try:
+        train(state_manager=state_manager, env=TowerfallEntityEnvImpl)
+    finally:
+        towerfall_provider = TowerfallProcessProvider('entity-env-trainer')
+        towerfall_provider.close()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/trainer/trainer.py b/trainer/trainer.py
index da040f0..5e03086 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -142,12 +142,12 @@ def fork_training(self,
        model, _ = self.load_from_trial(load_project_name, load_trial_name, load_model_name, monitored_env)
        self._train_model(model, env, total_steps, configs, project_name, trial_name)
 
-    def evaluate_model(self, env_fn: Callable[[dict[str, Any]], TowerfallBlankEnv], n_episodes: int, project_name: str, trial_name: str, model_name: str):
+    def evaluate_model(self, env_fn: Callable[[dict[str, Any]], Env], n_episodes: int, project_name: str, trial_name: str, model_name: str):
        model, configs = self.load_from_trial(project_name, trial_name, model_name)
        env = env_fn(configs)
        evaluate_policy(model, env=env, n_eval_episodes=n_episodes, render=False, deterministic=False)
 
-    def evaluate_all_models(self, env_fn: Callable[[dict[str, Any]], TowerfallBlankEnv], n_episodes: int, project_name: str, trial_name: str):
+    def evaluate_all_models(self, env_fn: Callable[[dict[str, Any]], Env], n_episodes: int, project_name: str, trial_name: str):
        trial_path = self.get_trial_path(project_name, trial_name)
        logging.info(f'Loading experiment from {trial_path}')
        with open(os.path.join(trial_path, 'hparams.json'), 'r') as file:
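
For reference, a minimal driver loop for the new entity-gym environment, mirroring tests/test_entity_env.py above. This is a sketch, not part of the patch; it assumes a local TowerFall install that TowerfallProcessProvider can launch and that the aimod directory is already on sys.path.

import random

from entity_gym.env import GlobalCategoricalAction, GlobalCategoricalActionSpace
from entity_envs.entity_env import TowerfallEntityEnvImpl

# TowerfallEntityEnvImpl spawns and connects to its own TowerFall process in __init__.
env = TowerfallEntityEnvImpl(verbose=0)
obs = env.reset()
for _ in range(1000):
    # Sample one label uniformly from each global categorical action head.
    actions = {}
    for name, space in env.action_space().items():
        assert isinstance(space, GlobalCategoricalActionSpace)
        idx = random.randint(0, len(space.index_to_label) - 1)
        actions[name] = GlobalCategoricalAction(idx, space.index_to_label[idx])
    obs = env.act(actions)
    if obs.done:
        obs = env.reset()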