Skip to content

Commit

Permalink
More experiments with entitygym
Browse files Browse the repository at this point in the history
  • Loading branch information
vcanaa committed May 8, 2023
1 parent 174657e commit dd8cd2b
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 45 deletions.
21 changes: 16 additions & 5 deletions entity_envs/entity_base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,28 @@
from common.constants import DASH, DOWN, JUMP, LEFT, RIGHT, SHOOT, UP
from common.entity import Entity, to_entities

from envs.connection_provider import TowerfallProcess
from envs.connection_provider import TowerfallProcess, TowerfallProcessProvider

class TowerfallEntityEnv(Environment):
def __init__(self,
towerfall: TowerfallProcess,
towerfall: Optional[TowerfallProcess] = None,
record_path: Optional[str] = None,
verbose: int = 0):
logging.info('Initializing TowerfallEntityEnv')
self.towerfall = towerfall
self.verbose = verbose
self.connection = self.towerfall.join(timeout=5, verbose=self.verbose)
if towerfall:
self.connection = towerfall.join(timeout=5, verbose=self.verbose)
self.towerfall = towerfall
else:
self.connection, self.towerfall = TowerfallProcessProvider().join_new(
fastrun=True,
# nographics=True,
config=dict(
mode='sandbox',
level='3',
agents=[dict(type='remote', team='blue', archer='green')]),
timeout=5,
verbose=self.verbose)
self.connection.record_path = record_path
self._draw_elems = []
self.is_init_sent = False
Expand All @@ -27,7 +38,7 @@ def __init__(self,

def _is_reset_valid(self) -> bool:
'''
Use this to make check if the initiallization is valid.
Use this to check if the initialization is valid.
This is useful to collect information about the environment to programmatically construct a sequence of tasks, then
reset the environment again with the proper reset instructions.
Expand Down
16 changes: 3 additions & 13 deletions entity_envs/entity_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,7 @@ class TowerfallEntityEnvImpl(TowerfallEntityEnv):
def __init__(self,
record_path: Optional[str]=None,
verbose: int = 0):
towerfall_provider = TowerfallProcessProvider('entity-env-trainer')
towerfall = towerfall_provider.get_process(
fastrun=True,
reuse=False,
config=dict(
mode='sandbox',
level='3',
agents=[dict(type='remote', team='blue', archer='green')]
))
super().__init__(towerfall, record_path, verbose)
super().__init__(record_path=record_path, verbose=verbose)
self.enemy_count = 2
self.min_distance = 50
self.max_distance = 100
Expand Down Expand Up @@ -101,9 +92,8 @@ def _update_reward(self, enemies: list[Entity]):
self.reward += delta_arrow * 0.1
self.prev_arrow_count = arrow_count


if self.reward != 0:
logging.info(f'Reward: {self.reward}')
# if self.reward != 0:
# logging.info(f'Reward: {self.reward}')

self.prev_enemy_ids = enemy_ids
if len(self.prev_enemy_ids) == 0:
Expand Down
101 changes: 76 additions & 25 deletions envs/connection_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from common import Connection

from typing import Any, Optional
from typing import Any, Optional, Tuple

from common.namedmutex import NamedMutex


_HOST = '127.0.0.1'
Expand Down Expand Up @@ -88,26 +90,30 @@ class TowerfallProcessProvider:
params name: Name of the connection provider. Used to separate different connection providers states.
'''
def __init__(self, name: str):
self.towerfall_path = 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall'
def __init__(self, name: Optional[str] = None,
# towerfall_path: str = 'C:/Users/vcanaa/towerfall/TowerFall',
towerfall_path: str = 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall'):
self.towerfall_path = towerfall_path
self.towerfall_path_exe = os.path.join(self.towerfall_path, 'TowerFall.exe')
self.connection_path = os.path.join('.connection_provider', name)
os.makedirs(self.connection_path, exist_ok=True)
self.state_path = os.path.join(self.connection_path, 'state.json')
self.name = name
self.processes = []
if os.path.exists(self.state_path):
with open(self.state_path, 'r') as file:
for process_data in json.loads(file.read()):
try:
psutil.Process(process_data['pid'])
self.processes.append(TowerfallProcess(**process_data))
except psutil.NoSuchProcess:
continue
if self.name:
self.connection_path = os.path.join('.connection_provider', self.name)
os.makedirs(self.connection_path, exist_ok=True)
self.state_path = os.path.join(self.connection_path, 'state.json')

if os.path.exists(self.state_path):
with open(self.state_path, 'r') as file:
for process_data in json.loads(file.read()):
try:
psutil.Process(process_data['pid'])
self.processes.append(TowerfallProcess(**process_data))
except psutil.NoSuchProcess:
continue
self._save_state()

self._processes_in_use = set()

self._save_state()

self.default_config = dict(
mode='sandbox',
level='2',
Expand Down Expand Up @@ -136,19 +142,23 @@ def is_suitable_process(process: TowerfallProcess):
return True
selected_process = next((p for p in self.processes if is_suitable_process(p)), None)

# If no process was found, start a new one
# If no process can be reused, start a new one
if not selected_process:
logging.info(f'Starting new process {self.towerfall_path_exe}')
pargs = [self.towerfall_path_exe, '--noconfig']
if fastrun:
pargs.append('--fastrun')
if nographics:
pargs.append('--nographics')
process = Popen(pargs, cwd=self.towerfall_path)
port = self._get_port(process.pid)
selected_process = TowerfallProcess(process.pid, port, fastrun, nographics)
self.processes.append(selected_process)
self._save_state()
# Multiple TowerFall.exe can't be started at the same time, due to conflict accessing Content folder.
with NamedMutex(f'TowerfallProcessProvider_{self.name}'):
process = Popen(pargs, cwd=self.towerfall_path)
port = self._get_port(process.pid)
selected_process = TowerfallProcess(process.pid, port, fastrun, nographics)
self.processes.append(selected_process)
self._save_state()
time.sleep(2) # Give some time for game to load. There is currently no way to tell if the game loaded.


try:
selected_process.send_config(config, verbose=verbose)
Expand All @@ -162,15 +172,55 @@ def is_suitable_process(process: TowerfallProcess):
def release_process(self, process: TowerfallProcess):
self._processes_in_use.remove(process.pid)

def join_new(self, fastrun=True, nographics=False, config = None, timeout=2, verbose=0) -> Tuple[Connection, TowerfallProcess]:
    '''
    Starts a brand new process (never a reused one) and joins it, retrying until both steps succeed.

    params fastrun: Forwarded to get_process.
    params nographics: Forwarded to get_process.
    params config: Forwarded to get_process.
    params timeout: Forwarded to the join call of the new process.
    params verbose: Forwarded to both get_process and the join call.

    returns: The joined connection and the process it belongs to.

    NOTE: there is no retry limit; this keeps trying for as long as starting or joining fails.
    '''
    logging.info('Create a new process and join')
    while True:
        # Start every attempt from a clean slate. Keeping the process of a previous failed attempt
        # around would make a later failure in get_process kill (and untrack) that process twice.
        process = None
        connection = None
        try:
            process = self.get_process(fastrun, nographics, config, verbose, reuse=False)
            connection = process.join(timeout, verbose)
        except Exception as ex:
            logging.error(f'Failed to create and join new process: {ex}')

        if process and connection:
            return connection, process

        if process:
            # The process was started but could not be joined: discard it before retrying
            # so that failed attempts do not pile up running game instances.
            self.kill_process(process.pid)

def kill_process(self, pid):
    '''
    Sends SIGTERM to the process with the given pid and stops tracking it.

    A failure to signal the process is logged and not raised. A pid that is not (or no longer)
    tracked by this provider is tolerated, so calling this twice for the same pid is safe.

    params pid: Id of the process to terminate.
    '''
    try:
        os.kill(pid, signal.SIGTERM)
    except Exception as ex:
        logging.error(f'Failed to kill process {pid}: {ex}')
    finally:
        # Look the pid up with a default: a bare next() would raise StopIteration from this
        # finally block when the pid is unknown, hiding whatever happened above.
        tracked = next((p for p in self.processes if p.pid == pid), None)
        if tracked is not None:
            self.processes.remove(tracked)
        self._save_state()

def close(self):
logging.info('Closing all processes...')
logging.info(f'Closing all processes in context {self.name}...')
for process in self.processes:
try:
os.kill(process.pid, signal.SIGTERM)
except Exception as ex:
logging.error(f'Failed to kill process {process.pid}: {ex}')
continue

@classmethod
def close_all(cls):
    '''
    Sends SIGTERM to every running process named TowerFall.exe, whether or not a provider tracks it.

    Failures to kill an individual process are logged and do not stop the sweep.
    '''
    logging.info('Closing all TowerFall.exe processes...')
    for process in psutil.process_iter(attrs=['pid', 'name']):
        # Use the name snapshot that process_iter stored in process.info. A live process.name()
        # call can raise (process already gone, access denied) and would abort the whole loop.
        if process.info['name'] != 'TowerFall.exe':
            continue
        try:
            logging.info(f'Killing process {process.pid}...')
            os.kill(process.pid, signal.SIGTERM)
        except Exception as ex:
            logging.error(f'Failed to kill process {process.pid}: {ex}')

def _get_port(self, pid: int) -> int:
port_path = os.path.join(self.towerfall_path, 'ports', str(pid))
tries = 0
Expand All @@ -182,8 +232,9 @@ def _get_port(self, pid: int) -> int:
return int(file.readline())

def _save_state(self):
with open(self.state_path, 'w') as file:
file.write(json.dumps([p.to_dict() for p in self.processes], indent=2))
if self.name:
with open(self.state_path, 'w') as file:
file.write(json.dumps([p.to_dict() for p in self.processes], indent=2))

def _match_config(self, config1, config2):
return False
7 changes: 7 additions & 0 deletions train commands
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
python train_kill_enemy_with_entity_env.py ^
total_timesteps=2000 rollout.steps=1024 ^
rollout.num_envs=1 ^
rollout.processes=1 ^
optim.bs=256 ^
optim.lr=0.0001 ^
--checkpoint-dir=checkpoints
64 changes: 64 additions & 0 deletions train_kill_enemies.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
TrainConfig(
version: 4,
env: EnvConfig(
kwargs: "{}",
id: "MoveToOrigin",
validate: true,
),
net: RogueNetConfig(
embd_pdrop: 0.0,
resid_pdrop: 0.0,
attn_pdrop: 0.0,
n_layer: 2,
n_head: 2,
d_model: 32,
pooling: None,
relpos_encoding: None,
d_qk: 16,
translation: None,
),
optim: OptimizerConfig(
lr: 0.0001,
lr_warmup_steps: None,
bs: 256,
weight_decay: 0.0,
micro_bs: None,
anneal_lr: true,
update_epochs: 3,
max_grad_norm: 2.0,
),
ppo: PPOConfig(
gae: true,
gamma: 0.99,
gae_lambda: 0.95,
norm_adv: true,
clip_coef: 0.2,
clip_vloss: true,
ent_coef: 0.1,
vf_coef: 0.5,
target_kl: None,
anneal_entropy: true,
),
rollout: RolloutConfig(
steps: 1024,
num_envs: 1,
processes: 1,
),
eval: None,
vf_net: None,
name: "config",
seed: 1,
total_timesteps: 2000,
max_train_time: None,
torch_deterministic: true,
cuda: true,
track: false,
wandb_project_name: "enn-ppo",
wandb_entity: "entity-neural-network",
capture_samples: None,
capture_logits: false,
capture_samples_subsample: 1,
trial: None,
data_dir: ".",
cuda_empty_cache: false,
)
5 changes: 3 additions & 2 deletions train_kill_enemy_with_entity_env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from enn_trainer import TrainConfig, State, init_train_state, train
import hyperstate
from common import logging_options
Expand All @@ -14,8 +15,8 @@ def main(state_manager: hyperstate.StateManager) -> None:
try:
train(state_manager=state_manager, env=TowerfallEntityEnvImpl)
finally:
towerfall_provider = TowerfallProcessProvider('entity-env-trainer')
towerfall_provider.close()
logging.info('Closing all Towerfall processes')
TowerfallProcessProvider.close_all()


if __name__ == "__main__":
Expand Down

0 comments on commit dd8cd2b

Please sign in to comment.