diff --git a/python/examples/gym/gym_human_player.py b/python/examples/gym/gym_human_player.py index 2d0bad1e1..0588ab47a 100644 --- a/python/examples/gym/gym_human_player.py +++ b/python/examples/gym/gym_human_player.py @@ -34,16 +34,16 @@ def _callback(prev_obs, obs, action, rew, env_done, info): environment_name = 'TestEnv' # yaml_path = 'Single-Player/GVGAI/bait.yaml' - #yaml_path = 'Single-Player/GVGAI/butterflies.yaml' + # yaml_path = 'Single-Player/GVGAI/butterflies.yaml' # yaml_path = 'Single-Player/GVGAI/random_butterflies.yaml' # yaml_path = 'Single-Player/GVGAI/bait_keys.yaml' # yaml_path = 'Single-Player/Mini-Grid/minigrid-drunkdwarf.yaml' - # yaml_path = 'Single-Player/Mini-Grid/minigrid-spiders.yaml' + yaml_path = 'Single-Player/Mini-Grid/minigrid-spiders.yaml' # yaml_path = 'Single-Player/GVGAI/spider-nest.yaml' # yaml_path = 'Single-Player/GVGAI/cookmepasta.yaml' # yaml_path = 'Single-Player/GVGAI/clusters.yaml' # yaml_path = 'Single-Player/GVGAI/zenpuzzle.yaml' - yaml_path = 'Single-Player/GVGAI/sokoban.yaml' + # yaml_path = 'Single-Player/GVGAI/sokoban.yaml' # yaml_path = 'Single-Player/GVGAI/sokoban2.yaml' # yaml_path = 'Single-Player/GVGAI/sokoban2_partially_observable.yaml' # yaml_path = 'Single-Player/GVGAI/cookmepasta_partially_observable.yaml' diff --git a/python/examples/snippet.py b/python/examples/snippet.py index e4c18ee0a..83bda594a 100644 --- a/python/examples/snippet.py +++ b/python/examples/snippet.py @@ -1,18 +1,76 @@ +from timeit import default_timer as timer +import numpy as np import gym -import griddly -from griddly import gd + +from griddly import GymWrapperFactory, gd +from griddly.RenderTools import VideoRecorder from griddly.util.wrappers import InvalidMaskingRTSWrapper if __name__ == '__main__': + wrapper = GymWrapperFactory() + + wrapper.build_gym_from_yaml("GriddlyRTS-Adv", + 'RTS/Stratega/kill-the-king.yaml', + global_observer_type=gd.ObserverType.VECTOR, + player_observer_type=gd.ObserverType.VECTOR, + level=0) + + env_original = gym.make(f'GDY-GriddlyRTS-Adv-v0') + # env_original = gym.make(f'GDY-GriddlyRTS-Adv-v0') + + env_original.reset() + + env = InvalidMaskingRTSWrapper(env_original) + + start = timer() + + frames = 0 + + fps_samples = [] + + player1_recorder = VideoRecorder() + player1_visualization = env.render(observer=0, mode='rgb_array') + player1_recorder.start("player1_video_test.mp4", player1_visualization.shape) + + player2_recorder = VideoRecorder() + player2_visualization = env.render(observer=1, mode='rgb_array') + player2_recorder.start("player2_video_test.mp4", player2_visualization.shape) + + global_recorder = VideoRecorder() + global_visualization = env.render(observer='global', mode='rgb_array') + global_recorder.start("global_video_test.mp4", global_visualization.shape) + + for s in range(10000): + + frames += 1 + + action = env.action_space.sample() + + obs, reward, done, info = env.step(action) + + global_observation = env.render(mode='rgb_array', observer='global') + player1_observation = env.render(observer=0, mode='rgb_array') + player2_observation = env.render(observer=1, mode='rgb_array') + + global_recorder.add_frame(global_observation) + player1_recorder.add_frame(player1_observation) + player2_recorder.add_frame(player2_observation) - env = gym.make('GDY-Heal-Or-Die-v0', level=1, global_observer_type=gd.ObserverType.ISOMETRIC) - env.reset() - env = InvalidMaskingRTSWrapper(env) + if done: + #state = env.get_state() + #print(state) + print(info) - # Replace with your own control algorithm! - for s in range(1000): - obs, reward, done, info = env.step(env.action_space.sample()) - for p in range(env.player_count): - env.render(observer=p) # Renders the environment from the perspective of a single player + if frames % 1000 == 0: + end = timer() + fps = (frames / (end - start)) + fps_samples.append(fps) + print(f'fps: {fps}') + frames = 0 + start = timer() - env.render(observer='global') # Renders the entire environment \ No newline at end of file + # Have to close the video recorders + player1_recorder.close() + player2_recorder.close() + global_recorder.close() + print(f'mean fps: {np.mean(fps_samples)}') diff --git a/python/griddly/GymWrapper.py b/python/griddly/GymWrapper.py index 4639610f2..c910ec295 100644 --- a/python/griddly/GymWrapper.py +++ b/python/griddly/GymWrapper.py @@ -129,7 +129,9 @@ def step(self, action): f'A valid example: {self.action_space.sample()}') for p in range(self.player_count): - self._player_last_observation[p] = np.array(self._players[p].observe(), copy=True) + # Copy only if the environment is done (it will reset itself) + # This is because the underlying data will be released + self._player_last_observation[p] = np.array(self._players[p].observe(), copy=False) obs = self._player_last_observation[0] if self.player_count == 1 else self._player_last_observation @@ -157,10 +159,11 @@ def reset(self, level_id=None, level_string=None, global_observations=False): return self._player_last_observation[0] if self.player_count == 1 else self._player_last_observation def initialize_spaces(self): + self._player_last_observation = [] for p in range(self.player_count): - self._player_last_observation.append(np.array(self._players[p].observe(), copy=True)) + self._player_last_observation.append(np.array(self._players[p].observe(), copy=False)) - self._global_last_observation = np.array(self.game.observe(), copy=True) + self._global_last_observation = np.array(self.game.observe(), copy=False) self.player_observation_shape = self._player_last_observation[0].shape self.global_observation_shape = self._global_last_observation.shape @@ -183,7 +186,7 @@ def initialize_spaces(self): def render(self, mode='human', observer=0): if observer == 'global': - observation = np.array(self.game.observe(), copy=True) + observation = np.array(self.game.observe(), copy=False) if self._global_observer_type == gd.ObserverType.VECTOR: observation = self._vector2rgb.convert(observation) else: