From dd8cd2b2da466e548366be4e6fd21c6effa0578b Mon Sep 17 00:00:00 2001 From: vcanaa Date: Sun, 7 May 2023 18:32:58 -0700 Subject: [PATCH] More experiments with entitygym --- entity_envs/entity_base_env.py | 21 ++++-- entity_envs/entity_env.py | 16 +---- envs/connection_provider.py | 101 +++++++++++++++++++++------- train commands | 7 ++ train_kill_enemies.ron | 64 ++++++++++++++++++ train_kill_enemy_with_entity_env.py | 5 +- 6 files changed, 169 insertions(+), 45 deletions(-) create mode 100644 train commands create mode 100644 train_kill_enemies.ron diff --git a/entity_envs/entity_base_env.py b/entity_envs/entity_base_env.py index c2458c4..7dd0c66 100644 --- a/entity_envs/entity_base_env.py +++ b/entity_envs/entity_base_env.py @@ -8,17 +8,28 @@ from common.constants import DASH, DOWN, JUMP, LEFT, RIGHT, SHOOT, UP from common.entity import Entity, to_entities -from envs.connection_provider import TowerfallProcess +from envs.connection_provider import TowerfallProcess, TowerfallProcessProvider class TowerfallEntityEnv(Environment): def __init__(self, - towerfall: TowerfallProcess, + towerfall: Optional[TowerfallProcess] = None, record_path: Optional[str] = None, verbose: int = 0): logging.info('Initializing TowerfallEntityEnv') - self.towerfall = towerfall self.verbose = verbose - self.connection = self.towerfall.join(timeout=5, verbose=self.verbose) + if towerfall: + self.connection = towerfall.join(timeout=5, verbose=self.verbose) + self.towerfall = towerfall + else: + self.connection, self.towerfall = TowerfallProcessProvider().join_new( + fastrun=True, + # nographics=True, + config=dict( + mode='sandbox', + level='3', + agents=[dict(type='remote', team='blue', archer='green')]), + timeout=5, + verbose=self.verbose) self.connection.record_path = record_path self._draw_elems = [] self.is_init_sent = False @@ -27,7 +38,7 @@ def __init__(self, def _is_reset_valid(self) -> bool: ''' - Use this to make check if the initiallization is valid. + Use this to check if the initiallization is valid. This is useful to collect information about the environment to programmatically construct a sequence of tasks, then reset the environment again with the proper reseet instructions. diff --git a/entity_envs/entity_env.py b/entity_envs/entity_env.py index b548dd6..e00db3f 100644 --- a/entity_envs/entity_env.py +++ b/entity_envs/entity_env.py @@ -17,16 +17,7 @@ class TowerfallEntityEnvImpl(TowerfallEntityEnv): def __init__(self, record_path: Optional[str]=None, verbose: int = 0): - towerfall_provider = TowerfallProcessProvider('entity-env-trainer') - towerfall = towerfall_provider.get_process( - fastrun=True, - reuse=False, - config=dict( - mode='sandbox', - level='3', - agents=[dict(type='remote', team='blue', archer='green')] - )) - super().__init__(towerfall, record_path, verbose) + super().__init__(record_path=record_path, verbose=verbose) self.enemy_count = 2 self.min_distance = 50 self.max_distance = 100 @@ -101,9 +92,8 @@ def _update_reward(self, enemies: list[Entity]): self.reward += delta_arrow * 0.1 self.prev_arrow_count = arrow_count - - if self.reward != 0: - logging.info(f'Reward: {self.reward}') + # if self.reward != 0: + # logging.info(f'Reward: {self.reward}') self.prev_enemy_ids = enemy_ids if len(self.prev_enemy_ids) == 0: diff --git a/envs/connection_provider.py b/envs/connection_provider.py index 8294980..5245c03 100644 --- a/envs/connection_provider.py +++ b/envs/connection_provider.py @@ -9,7 +9,9 @@ from common import Connection -from typing import Any, Optional +from typing import Any, Optional, Tuple + +from common.namedmutex import NamedMutex _HOST = '127.0.0.1' @@ -88,26 +90,30 @@ class TowerfallProcessProvider: params name: Name of the connection provider. Used to separate different connection providers states. ''' - def __init__(self, name: str): - self.towerfall_path = 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall' + def __init__(self, name: Optional[str] = None, + # towerfall_path: str = 'C:/Users/vcanaa/towerfall/TowerFall', + towerfall_path: str = 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall'): + self.towerfall_path = towerfall_path self.towerfall_path_exe = os.path.join(self.towerfall_path, 'TowerFall.exe') - self.connection_path = os.path.join('.connection_provider', name) - os.makedirs(self.connection_path, exist_ok=True) - self.state_path = os.path.join(self.connection_path, 'state.json') + self.name = name self.processes = [] - if os.path.exists(self.state_path): - with open(self.state_path, 'r') as file: - for process_data in json.loads(file.read()): - try: - psutil.Process(process_data['pid']) - self.processes.append(TowerfallProcess(**process_data)) - except psutil.NoSuchProcess: - continue + if self.name: + self.connection_path = os.path.join('.connection_provider', self.name) + os.makedirs(self.connection_path, exist_ok=True) + self.state_path = os.path.join(self.connection_path, 'state.json') + + if os.path.exists(self.state_path): + with open(self.state_path, 'r') as file: + for process_data in json.loads(file.read()): + try: + psutil.Process(process_data['pid']) + self.processes.append(TowerfallProcess(**process_data)) + except psutil.NoSuchProcess: + continue + self._save_state() self._processes_in_use = set() - self._save_state() - self.default_config = dict( mode='sandbox', level='2', @@ -136,7 +142,7 @@ def is_suitable_process(process: TowerfallProcess): return True selected_process = next((p for p in self.processes if is_suitable_process(p)), None) - # If no process was found, start a new one + # If no process can be reused, start a new one if not selected_process: logging.info(f'Starting new process {self.towerfall_path_exe}') pargs = [self.towerfall_path_exe, '--noconfig'] @@ -144,11 +150,15 @@ def is_suitable_process(process: TowerfallProcess): pargs.append('--fastrun') if nographics: pargs.append('--nographics') - process = Popen(pargs, cwd=self.towerfall_path) - port = self._get_port(process.pid) - selected_process = TowerfallProcess(process.pid, port, fastrun, nographics) - self.processes.append(selected_process) - self._save_state() + # Multiple TowerFall.exe can't be started at the same time, due to conflict accessing Content folder. + with NamedMutex(f'TowerfallProcessProvider_{self.name}'): + process = Popen(pargs, cwd=self.towerfall_path) + port = self._get_port(process.pid) + selected_process = TowerfallProcess(process.pid, port, fastrun, nographics) + self.processes.append(selected_process) + self._save_state() + time.sleep(2) # Give some time for game to load. There is currently no way to tell if the game loaded. + try: selected_process.send_config(config, verbose=verbose) @@ -162,8 +172,34 @@ def is_suitable_process(process: TowerfallProcess): def release_process(self, process: TowerfallProcess): self._processes_in_use.remove(process.pid) + def join_new(self, fastrun=True, nographics=False, config = None, timeout=2, verbose=0) -> Tuple[Connection, TowerfallProcess]: + connection = None + process = None + logging.info('Create a new process and join') + while not connection or not process: + try: + process = self.get_process(fastrun, nographics, config, verbose, reuse=False) + connection = process.join(timeout, verbose) + except Exception as ex: + logging.error(f'Failed to create and join new process: {ex}') + if process: + self.kill_process(process.pid) + if connection: + connection.close() + + return connection, process + + def kill_process(self, pid): + try: + os.kill(pid, signal.SIGTERM) + except Exception as ex: + logging.error(f'Failed to kill process {pid}: {ex}') + finally: + self.processes.remove(next(p for p in self.processes if p.pid == pid)) + self._save_state() + def close(self): - logging.info('Closing all processes...') + logging.info(f'Closing all processes in context {self.name}...') for process in self.processes: try: os.kill(process.pid, signal.SIGTERM) @@ -171,6 +207,20 @@ def close(self): logging.error(f'Failed to kill process {process.pid}: {ex}') continue + @classmethod + def close_all(cls): + logging.info('Closing all TowerFall.exe processes...') + for process in psutil.process_iter(attrs=['pid', 'name']): + # logging.info(f'Checking process {process.pid} {process.name()}') + if process.name() != 'TowerFall.exe': + continue + try: + logging.info(f'Killing process {process.pid}...') + os.kill(process.pid, signal.SIGTERM) + except Exception as ex: + logging.error(f'Failed to kill process {process.pid}: {ex}') + continue + def _get_port(self, pid: int) -> int: port_path = os.path.join(self.towerfall_path, 'ports', str(pid)) tries = 0 @@ -182,8 +232,9 @@ def _get_port(self, pid: int) -> int: return int(file.readline()) def _save_state(self): - with open(self.state_path, 'w') as file: - file.write(json.dumps([p.to_dict() for p in self.processes], indent=2)) + if self.name: + with open(self.state_path, 'w') as file: + file.write(json.dumps([p.to_dict() for p in self.processes], indent=2)) def _match_config(self, config1, config2): return False diff --git a/train commands b/train commands new file mode 100644 index 0000000..7e7ac52 --- /dev/null +++ b/train commands @@ -0,0 +1,7 @@ +python train_kill_enemy_with_entity_env.py ^ + total_timesteps=2000 rollout.steps=1024 ^ + rollout.num_envs=1 ^ + rollout.processes=1 ^ + optim.bs=256 ^ + optim.lr=0.0001 ^ + --checkpoint-dir=checkpoints \ No newline at end of file diff --git a/train_kill_enemies.ron b/train_kill_enemies.ron new file mode 100644 index 0000000..93b41f2 --- /dev/null +++ b/train_kill_enemies.ron @@ -0,0 +1,64 @@ +TrainConfig( + version: 4, + env: EnvConfig( + kwargs: "{}", + id: "MoveToOrigin", + validate: true, + ), + net: RogueNetConfig( + embd_pdrop: 0.0, + resid_pdrop: 0.0, + attn_pdrop: 0.0, + n_layer: 2, + n_head: 2, + d_model: 32, + pooling: None, + relpos_encoding: None, + d_qk: 16, + translation: None, + ), + optim: OptimizerConfig( + lr: 0.0001, + lr_warmup_steps: None, + bs: 256, + weight_decay: 0.0, + micro_bs: None, + anneal_lr: true, + update_epochs: 3, + max_grad_norm: 2.0, + ), + ppo: PPOConfig( + gae: true, + gamma: 0.99, + gae_lambda: 0.95, + norm_adv: true, + clip_coef: 0.2, + clip_vloss: true, + ent_coef: 0.1, + vf_coef: 0.5, + target_kl: None, + anneal_entropy: true, + ), + rollout: RolloutConfig( + steps: 1024, + num_envs: 1, + processes: 1, + ), + eval: None, + vf_net: None, + name: "config", + seed: 1, + total_timesteps: 2000, + max_train_time: None, + torch_deterministic: true, + cuda: true, + track: false, + wandb_project_name: "enn-ppo", + wandb_entity: "entity-neural-network", + capture_samples: None, + capture_logits: false, + capture_samples_subsample: 1, + trial: None, + data_dir: ".", + cuda_empty_cache: false, +) \ No newline at end of file diff --git a/train_kill_enemy_with_entity_env.py b/train_kill_enemy_with_entity_env.py index 0b4af73..2d80ad1 100644 --- a/train_kill_enemy_with_entity_env.py +++ b/train_kill_enemy_with_entity_env.py @@ -1,3 +1,4 @@ +import logging from enn_trainer import TrainConfig, State, init_train_state, train import hyperstate from common import logging_options @@ -14,8 +15,8 @@ def main(state_manager: hyperstate.StateManager) -> None: try: train(state_manager=state_manager, env=TowerfallEntityEnvImpl) finally: - towerfall_provider = TowerfallProcessProvider('entity-env-trainer') - towerfall_provider.close() + logging.info('Closing all Towerfall processes') + TowerfallProcessProvider.close_all() if __name__ == "__main__":