diff --git a/.gitmodules b/.gitmodules index db530b4e9..2d56e6e39 100644 --- a/.gitmodules +++ b/.gitmodules @@ -18,3 +18,7 @@ path = python/examples/experiments/rts-self-play url = https://github.com/Bam4d/rts-self-play ignore = dirty +[submodule "python/examples/experiments/autoregressive-cats"] + path = python/examples/experiments/autoregressive-cats + url = https://github.com/Bam4d/autoregressive-cats + ignore = dirty diff --git a/bindings/wrapper/GriddlyLoaderWrapper.cpp b/bindings/wrapper/GriddlyLoaderWrapper.cpp index 8b6fbadbf..29d8b5032 100644 --- a/bindings/wrapper/GriddlyLoaderWrapper.cpp +++ b/bindings/wrapper/GriddlyLoaderWrapper.cpp @@ -6,7 +6,6 @@ #include "../../src/Griddly/Core/GDY/Objects/ObjectGenerator.hpp" #include "../../src/Griddly/Core/GDY/TerminationGenerator.hpp" #include "../../src/Griddly/Core/Grid.hpp" -#include "../../src/Griddly/Core/Observers/Vulkan/VulkanObserver.hpp" #include "GDYWrapper.cpp" namespace griddly { diff --git a/docs/_static/video/.gitignore b/docs/_static/video/.gitignore new file mode 100644 index 000000000..9bd53b814 --- /dev/null +++ b/docs/_static/video/.gitignore @@ -0,0 +1 @@ +!*.mp4 \ No newline at end of file diff --git a/docs/_static/video/griddly_rts.mp4 b/docs/_static/video/griddly_rts.mp4 new file mode 100644 index 000000000..f5289708d Binary files /dev/null and b/docs/_static/video/griddly_rts.mp4 differ diff --git a/docs/about/halloffame.rst b/docs/about/halloffame.rst index e6c712e60..c3b70bb81 100644 --- a/docs/about/halloffame.rst +++ b/docs/about/halloffame.rst @@ -2,7 +2,7 @@ Hall of Fame ############ -If you create a project that uses Griddly, please let us know and we will link it here. This includes if you use Griddly in any papers, use the griddly engine in another game project and want to share your work. +If you create a project that uses Griddly, please let us know and we will link it here. This includes if you use Griddly in any papers, or use the Griddly engine in another game project and want to share your work. .. note:: You can Be the first! diff --git a/docs/about/introduction.rst b/docs/about/introduction.rst index aef2218f1..fc5c0fa74 100644 --- a/docs/about/introduction.rst +++ b/docs/about/introduction.rst @@ -4,7 +4,7 @@ Introduction ############ -One of the most important things about AI research is data. In many Game Environments the rate of data (rendered frames per second, or state representations per second) is relatively slow meaning very long training times. Researchers can compensate for this problem by parallelising the number of games being played, sometimes on expensive hardward and sometimes on several servers requiring network infrastructure to pass states to the actual learning algorithms. For many researchers and hobbyists who want to learn. This approach is unobtainable and only the research teams with lots of funding and engineers supporting the hardware and infrastrcuture required. +One of the most important things about AI research is data. In many Game Environments the rate of data (rendered frames per second, or state representations per second) is relatively slow, meaning very long training times. Researchers can compensate for this problem by parallelizing the number of games being played, sometimes on expensive hardware and sometimes on several servers requiring network infrastructure to pass states to the actual learning algorithms. For many researchers and hobbyists who want to learn, 
this approach is out of reach; it is only available to research teams with significant funding and the engineers to support the required hardware and infrastructure. Griddly provides a solution to this issue. diff --git a/docs/getting-started/procedural content generation/img/Doggo-level-Sprite2D-0.png b/docs/getting-started/procedural content generation/img/Doggo-level-Sprite2D-0.png new file mode 100644 index 000000000..a75ad2120 Binary files /dev/null and b/docs/getting-started/procedural content generation/img/Doggo-level-Sprite2D-0.png differ diff --git a/docs/getting-started/procedural content generation/img/generated_clusters.png b/docs/getting-started/procedural content generation/img/generated_clusters.png new file mode 100644 index 000000000..c01d0ae29 Binary files /dev/null and b/docs/getting-started/procedural content generation/img/generated_clusters.png differ diff --git a/docs/getting-started/procedural content generation/index.rst b/docs/getting-started/procedural content generation/index.rst new file mode 100644 index 000000000..f4669e043 --- /dev/null +++ b/docs/getting-started/procedural content generation/index.rst @@ -0,0 +1,240 @@ +.. _doc_tutorials_pcg: + +############################# +Procedural Content Generation +############################# + +Reinforcement learning can be prone to over-fitting in environments where the initial conditions are limited and the environment dynamics are deterministic. +Procedural content generation is an important tool in reinforcement learning, as it allows level maps to be created on-the-fly. This gives the agent a much more complex challenge, and stops it from being able to overfit on a small dataset of levels. + + +********** +Level Maps +********** + +Levels in Griddly environments are defined by strings of characters. The ``MapCharacter`` symbols used in a level are defined in each game's GDY file and are also listed in the game's documentation. + +Basic Map +========= + +.. code-block:: text + + W W W W W W + W A . . . W + W . . . . W + W . . . . W + W . . . g W + W W W W W W + +.. figure:: img/Doggo-level-Sprite2D-0.png + :align: center + + How the above Doggo level is rendered. + + +You can see in the map example above that the ``A`` character defines the Dog and the ``g`` character defines the goal. ``W`` defines the walls and ``.`` is reserved for empty space. + +This is a basic example and generating levels for this environment might not be too interesting... + + +************************ +Clusters Level Generator +************************ + +A much more complicated example is to generate new levels for the `Clusters` game. The aim of the Clusters game is for the agent to push coloured blocks together to form "clusters", whilst avoiding spikes. +The game is fully deterministic and there are only 5 levels supplied in the original GDY file. This makes it a perfect candidate for building new levels and testing whether reinforcement learning can still solve these levels! + + +Level Generator Class +===================== + +Here's an example of a level generator for the Clusters game. + +The ``LevelGenerator`` class can be used as a base class. Only the ``generate`` function needs to be implemented. + +.. 
code-block:: python + + class ClustersLevelGenerator(LevelGenerator): + BLUE_BLOCK = 'a' + BLUE_BOX = '1' + RED_BLOCK = 'b' + RED_BOX = '2' + GREEN_BLOCK = 'c' + GREEN_BOX = '3' + + AGENT = 'A' + + WALL = 'w' + SPIKES = 'h' + + def __init__(self, config): + super().__init__(config) + self._width = config.get('width', 10) + self._height = config.get('height', 10) + self._p_red = config.get('p_red', 1.0) + self._p_green = config.get('p_green', 1.0) + self._p_blue = config.get('p_blue', 1.0) + self._m_red = config.get('m_red', 5) + self._m_blue = config.get('m_blue', 5) + self._m_green = config.get('m_green', 5) + self._m_spike = config.get('m_spike', 5) + + def _place_walls(self, map): + + # top/bottom wall + wall_y = np.array([0, self._height - 1]) + map[:, wall_y] = ClustersLevelGenerator.WALL + + # left/right wall + wall_x = np.array([0, self._width - 1]) + map[wall_x, :] = ClustersLevelGenerator.WALL + + return map + + def _place_blocks_and_boxes(self, map, possible_locations, p, block_char, box_char, max_boxes): + if np.random.random() < p: + block_location_idx = np.random.choice(len(possible_locations)) + block_location = possible_locations[block_location_idx] + del possible_locations[block_location_idx] + map[block_location[0], block_location[1]] = block_char + + num_boxes = 1 + np.random.choice(max_boxes - 1) + for k in range(num_boxes): + box_location_idx = np.random.choice(len(possible_locations)) + box_location = possible_locations[box_location_idx] + del possible_locations[box_location_idx] + map[box_location[0], box_location[1]] = box_char + + return map, possible_locations + + def generate(self): + map = np.chararray((self._width, self._height), itemsize=2) + map[:] = '.' + + # Generate walls + map = self._place_walls(map) + + # all possible locations + possible_locations = [] + for w in range(1, self._width - 1): + for h in range(1, self._height - 1): + possible_locations.append([w, h]) + + # Place Red + map, possible_locations = self._place_blocks_and_boxes( + map, + possible_locations, + self._p_red, + ClustersLevelGenerator.RED_BLOCK, + ClustersLevelGenerator.RED_BOX, + self._m_red + ) + + # Place Blue + map, possible_locations = self._place_blocks_and_boxes( + map, + possible_locations, + self._p_blue, + ClustersLevelGenerator.BLUE_BLOCK, + ClustersLevelGenerator.BLUE_BOX, + self._m_blue + ) + + # Place Green + map, possible_locations = self._place_blocks_and_boxes( + map, + possible_locations, + self._p_green, + ClustersLevelGenerator.GREEN_BLOCK, + ClustersLevelGenerator.GREEN_BOX, + self._m_green + ) + + # Place Spikes + num_spikes = np.random.choice(self._m_spike) + for k in range(num_spikes): + spike_location_idx = np.random.choice(len(possible_locations)) + spike_location = possible_locations[spike_location_idx] + del possible_locations[spike_location_idx] + map[spike_location[0], spike_location[1]] = ClustersLevelGenerator.SPIKES + + # Place Agent + agent_location_idx = np.random.choice(len(possible_locations)) + agent_location = possible_locations[agent_location_idx] + map[agent_location[0], agent_location[1]] = ClustersLevelGenerator.AGENT + + level_string = '' + for h in range(0, self._height): + for w in range(0, self._width): + level_string += map[w, h].decode().ljust(4) + level_string += '\n' + + return level_string + +This generates levels like the following: + +.. figure:: img/generated_clusters.png + :align: center + + A 10x10 map generated by the above code. 
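+ +The ``ClustersLevelGenerator`` example above is not fully self-contained; a minimal set of imports for it would be the following, assuming ``numpy`` and the ``LevelGenerator`` base class from the ``griddly.util.rllib.environment.level_generator`` module added in this change: + +.. code-block:: python + + # Imports assumed by the ClustersLevelGenerator example above + import numpy as np + + from griddly.util.rllib.environment.level_generator import LevelGenerator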
+ + + +Using ``LevelGenerator`` +======================== + +In the most simple case, the level generator can be used just before the level resets and the generated string can be passed to ``env.reset(level_string=...)`` + +.. code-block:: python + + if __name__ == '__main__': + + config = { + 'width': 10, + 'height': 10 + } + + renderer = RenderToFile() + + level_generator = ClustersLevelGenerator(config) + + env = gym.make('GDY-Clusters-v0') + env.reset(level_string=level_generator.generate()) + + ... + + + +Using ``LevelGenerators`` with RLLib +==================================== + +The ``LevelGenerator`` base class is compatible with RLLib and can be used and configured through the standard RLLib configuration. + +For example, the level generator and its parameters can be set up in the ``env_config`` in the following way: + +.. code-block:: python + + 'config': { + + ... + + 'env_config': { + 'generate_valid_action_trees': True, + 'level_generator': { + 'class': ClustersLevelGenerator, + 'config': { + 'width': 6, + 'height': 6, + 'p_red': 0.7, + 'p_green': 0.7, + 'p_blue': 0.7, + 'm_red': 4, + 'm_blue': 4, + 'm_green': 4, + 'm_spike': 4 + } + }, + + ... + } + diff --git a/docs/index.rst b/docs/index.rst index 11b25b705..553cacbda 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ Griddly documentation. getting-started/action spaces/index getting-started/observation spaces/index getting-started/visualization/index + getting-started/procedural content generation/index .. toctree:: :maxdepth: 2 diff --git a/python/examples/experiments/autoregressive-cats b/python/examples/experiments/autoregressive-cats new file mode 160000 index 000000000..d2143f8f8 --- /dev/null +++ b/python/examples/experiments/autoregressive-cats @@ -0,0 +1 @@ +Subproject commit d2143f8f844fe30f60dc458b5d7286e36db0d54f diff --git a/python/examples/experiments/conditional-action-trees b/python/examples/experiments/conditional-action-trees index 0fd2cdf11..08e7574cd 160000 --- a/python/examples/experiments/conditional-action-trees +++ b/python/examples/experiments/conditional-action-trees @@ -1 +1 @@ -Subproject commit 0fd2cdf11c2785aa469fc0704be2aff7a3807974 +Subproject commit 08e7574cd95b0a5714017e1ca78a12a32c38f645 diff --git a/python/examples/experiments/rts-self-play b/python/examples/experiments/rts-self-play index a8523959f..c057f0233 160000 --- a/python/examples/experiments/rts-self-play +++ b/python/examples/experiments/rts-self-play @@ -1 +1 @@ -Subproject commit a8523959f6ff2de5b158428a5246b9c6385b3170 +Subproject commit c057f0233241feaddb58f4f1a29dd6b42202b418 diff --git a/python/griddly/GymWrapper.py b/python/griddly/GymWrapper.py index 72fd76f0e..314597dd1 100644 --- a/python/griddly/GymWrapper.py +++ b/python/griddly/GymWrapper.py @@ -110,7 +110,7 @@ def step(self, action): elif len(action) == self.player_count: if np.ndim(action) == 1 or np.ndim(action) == 3: - if isinstance(action[0], list) or isinstance(action[0], np.ndarray): + if isinstance(action[0], list) or isinstance(action[0], np.ndarray) or isinstance(action[0], tuple): # Multiple agents that can perform multiple actions in parallel # Used in RTS games reward = [] diff --git a/python/griddly/util/rllib/callbacks.py b/python/griddly/util/rllib/callbacks.py index 5684d2e73..72db331e6 100644 --- a/python/griddly/util/rllib/callbacks.py +++ b/python/griddly/util/rllib/callbacks.py @@ -8,59 +8,9 @@ from wandb import Video -class MultiCallback(DefaultCallbacks): - - def __init__(self, callback_class_list): - super().__init__() - 
self._callback_class_list = callback_class_list - - self._callback_list = [] - - def __call__(self, *args, **kwargs): - self._callback_list = [callback_class() for callback_class in self._callback_class_list] - - return self - - def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], - episode: MultiAgentEpisode, env_index: Optional[int] = None, **kwargs) -> None: - for callback in self._callback_list: - callback.on_episode_start(worker=worker, base_env=base_env, policies=policies, episode=episode, - env_index=env_index, **kwargs) - - def on_episode_step(self, *, worker: "RolloutWorker", base_env: BaseEnv, episode: MultiAgentEpisode, - env_index: Optional[int] = None, **kwargs) -> None: - for callback in self._callback_list: - callback.on_episode_step(worker=worker, base_env=base_env, episode=episode, env_index=env_index, **kwargs) - - def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], - episode: MultiAgentEpisode, env_index: Optional[int] = None, **kwargs) -> None: - for callback in self._callback_list: - callback.on_episode_end(worker=worker, base_env=base_env, policies=policies, episode=episode, - env_index=env_index, **kwargs) - - def on_postprocess_trajectory(self, *, worker: "RolloutWorker", episode: MultiAgentEpisode, agent_id: AgentID, - policy_id: PolicyID, policies: Dict[PolicyID, Policy], - postprocessed_batch: SampleBatch, original_batches: Dict[AgentID, SampleBatch], - **kwargs) -> None: - for callback in self._callback_list: - callback.on_postprocess_trajectory(worker=worker, episode=episode, agent_id=agent_id, policy_id=policy_id, - policies=policies, postprocessed_batch=postprocessed_batch, - original_batches=original_batches, **kwargs) - - def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs) -> None: - for callback in self._callback_list: - callback.on_sample_end(worker=worker, samples=samples, **kwargs) - - def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs) -> None: - for callback in self._callback_list: - callback.on_learn_on_batch(policy=policy, train_batch=train_batch, result=result, **kwargs) - - def on_train_result(self, *, trainer, result: dict, **kwargs) -> None: - for callback in self._callback_list: - callback.on_train_result(trainer=trainer, result=result, **kwargs) - - -class VideoCallback(DefaultCallbacks): +class GriddlyRLLibCallbacks(DefaultCallbacks): + """Contains helper functions for Griddly callbacks + """ def _get_envs(self, base_env): if isinstance(base_env, _VectorEnvToBaseEnv): @@ -68,6 +18,16 @@ def _get_envs(self, base_env): else: return base_env.envs + def _get_player_ids(self, base_env, env_index): + envs = self._get_envs(base_env) + player_count = envs[env_index].player_count + if player_count == 1: + return ['agent0'] + else: + return [p for p in range(1, player_count+1)] + +class VideoCallbacks(GriddlyRLLibCallbacks): + def on_episode_start(self, *, worker: "RolloutWorker", @@ -100,19 +60,13 @@ def on_episode_end(self, episode.media[f'level_{level}_1'] = Video(path) -class ActionTrackerCallback(DefaultCallbacks): +class ActionTrackerCallbacks(GriddlyRLLibCallbacks): def __init__(self): super().__init__() self._action_frequency_trackers = {} - def _get_envs(self, base_env): - if isinstance(base_env, _VectorEnvToBaseEnv): - return base_env.vector_env.envs - else: - return base_env.envs - def on_episode_start(self, *, worker: "RolloutWorker", @@ -121,10 +75,8 @@ def 
on_episode_start(self, episode: MultiAgentEpisode, env_index: Optional[int] = None, **kwargs) -> None: - envs = self._get_envs(base_env) - num_players = envs[env_index].player_count self._action_frequency_trackers[episode.episode_id] = [] - for p in range(0, num_players): + for _ in self._get_player_ids(base_env, env_index): self._action_frequency_trackers[episode.episode_id].append(Counter()) def on_episode_step(self, @@ -135,11 +87,8 @@ def on_episode_step(self, env_index: Optional[int] = None, **kwargs) -> None: - envs = self._get_envs(base_env) - num_players = envs[env_index].player_count - - for p in range(0, num_players): - info = episode.last_info_for(p+1) + for p, id in enumerate(self._get_player_ids(base_env, env_index)): + info = episode.last_info_for(id) if 'History' in info: history = info['History'] for event in history: @@ -149,11 +98,30 @@ def on_episode_step(self, def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode, env_index: Optional[int] = None, **kwargs) -> None: - envs = self._get_envs(base_env) - num_players = envs[env_index].player_count - - for p in range(0, num_players): + for p, id in enumerate(self._get_player_ids(base_env, env_index)): for action_name, frequency in self._action_frequency_trackers[episode.episode_id][p].items(): - episode.custom_metrics[f'agent_info/{p+1}/{action_name}'] = frequency + episode.custom_metrics[f'agent_info/{id}/{action_name}'] = frequency del self._action_frequency_trackers[episode.episode_id] + +class WinLoseMetricCallbacks(GriddlyRLLibCallbacks): + + def __init__(self): + super().__init__() + + def on_episode_end(self, + *, + worker: "RolloutWorker", + base_env: BaseEnv, + policies: Dict[PolicyID, Policy], + episode: MultiAgentEpisode, + env_index: Optional[int] = None, + **kwargs) -> None: + + for p, id in enumerate(self._get_player_ids(base_env, env_index)): + info = episode.last_info_for(id) + episode.custom_metrics[f'agent_info/{id}/win'] = 1 if info['PlayerResults'][f'{p+1}'] == 'Win' else 0 + episode.custom_metrics[f'agent_info/{id}/lose'] = 1 if info['PlayerResults'][f'{p+1}'] == 'Lose' else 0 + episode.custom_metrics[f'agent_info/{id}/end'] = 1 if info['PlayerResults'][f'{p+1}'] == 'End' else 0 + + diff --git a/python/griddly/util/rllib/environment/core.py b/python/griddly/util/rllib/environment/core.py index 3eddaf55c..cb8ced4f9 100644 --- a/python/griddly/util/rllib/environment/core.py +++ b/python/griddly/util/rllib/environment/core.py @@ -65,10 +65,15 @@ def __init__(self, env_config): self.generate_valid_action_trees = env_config.get('generate_valid_action_trees', False) self._random_level_on_reset = env_config.get('random_level_on_reset', False) + level_generator_rllib_config = env_config.get('level_generator', None) - super().reset() + self._level_generator = None + if level_generator_rllib_config is not None: + level_generator_class = level_generator_rllib_config['class'] + level_generator_config = level_generator_rllib_config['config'] + self._level_generator = level_generator_class(level_generator_config) - self.set_transform() + self.reset() self.enable_history(self.record_actions) @@ -85,8 +90,8 @@ def _after_step(self, observation, reward, done, info): extra_info = {} # If we are in a multi-agent setting then we handle videos elsewhere - if self.is_video_enabled(): - if self.player_count == 1: + if self.player_count == 1: + if self.is_video_enabled(): videos_list = [] if self.include_agent_videos: video_info = 
self._agent_recorder.step(self.level_id, self.env_steps, done) @@ -127,8 +132,11 @@ def _get_valid_action_trees(self): def reset(self, **kwargs): - if self._random_level_on_reset: + if self._level_generator is not None: + kwargs['level_string'] = self._level_generator.generate() + elif self._random_level_on_reset: kwargs['level_id'] = np.random.choice(self.level_count) + observation = super().reset(**kwargs) self.set_transform() @@ -142,8 +150,7 @@ def step(self, action): extra_info = self._after_step(observation, reward, done, info) - if 'videos' in extra_info: - info['videos'] = extra_info['videos'] + info.update(extra_info) if self.generate_valid_action_trees: self.last_valid_action_trees = self._get_valid_action_trees() @@ -218,7 +225,6 @@ def _resolve_player_done_variable(self): def _after_step(self, obs_map, reward_map, done_map, info_map): extra_info = {} - if self.is_video_enabled(): videos_list = [] if self.include_agent_videos: @@ -236,7 +242,7 @@ def _after_step(self, obs_map, reward_map, done_map, info_map): return extra_info def step(self, action_dict: MultiAgentDict): - actions_array = np.zeros((self.player_count, *self.action_space.shape)) + actions_array = [None] * self.player_count for agent_id, action in action_dict.items(): actions_array[agent_id - 1] = action diff --git a/python/griddly/util/rllib/environment/level_generator.py b/python/griddly/util/rllib/environment/level_generator.py new file mode 100644 index 000000000..c7ae47da5 --- /dev/null +++ b/python/griddly/util/rllib/environment/level_generator.py @@ -0,0 +1,7 @@ +class LevelGenerator: + + def __init__(self, config): + self._config = config + + def generate(self): + raise NotImplementedError() \ No newline at end of file diff --git a/python/griddly/util/rllib/torch/conditional_actions/conditional_action_exploration.py b/python/griddly/util/rllib/torch/conditional_actions/conditional_action_exploration.py index 4e6842d3b..61298d426 100644 --- a/python/griddly/util/rllib/torch/conditional_actions/conditional_action_exploration.py +++ b/python/griddly/util/rllib/torch/conditional_actions/conditional_action_exploration.py @@ -8,8 +8,7 @@ class TorchConditionalMaskingExploration(): - def __init__(self, model, dist_inputs, valid_action_trees, explore=False, invalid_action_masking='conditional', - allow_nop=False): + def __init__(self, model, dist_inputs, valid_action_trees, explore=False): self._valid_action_trees = valid_action_trees self._num_inputs = dist_inputs.shape[0] @@ -23,9 +22,6 @@ def __init__(self, model, dist_inputs, valid_action_trees, explore=False, invali self.model = model - self._invalid_action_masking = invalid_action_masking - self._allow_nop = allow_nop - self.device = dist_inputs.device self._explore = explore diff --git a/python/griddly/util/rllib/torch/conditional_actions/conditional_action_mixin.py b/python/griddly/util/rllib/torch/conditional_actions/conditional_action_mixin.py index 68f71a869..18ac6b45b 100644 --- a/python/griddly/util/rllib/torch/conditional_actions/conditional_action_mixin.py +++ b/python/griddly/util/rllib/torch/conditional_actions/conditional_action_mixin.py @@ -46,8 +46,6 @@ def compute_actions_from_input_dict( seq_lens) generate_valid_action_trees = self.config['env_config'].get('generate_valid_action_trees', False) - invalid_action_masking = self.config["env_config"].get("invalid_action_masking", 'none') - allow_nop = self.config["env_config"].get("allow_nop", False) extra_fetches = {} @@ -65,9 +63,7 @@ def compute_actions_from_input_dict( self.model, dist_inputs, 
valid_action_trees, - explore, - invalid_action_masking, - allow_nop + explore ) actions, masked_logits, logp, mask = exploration.get_actions_and_mask() diff --git a/python/griddly/util/rllib/torch/conditional_actions/conditional_action_policy_trainer.py b/python/griddly/util/rllib/torch/conditional_actions/conditional_action_policy_trainer.py index c3e4f9407..40bac0a26 100644 --- a/python/griddly/util/rllib/torch/conditional_actions/conditional_action_policy_trainer.py +++ b/python/griddly/util/rllib/torch/conditional_actions/conditional_action_policy_trainer.py @@ -3,18 +3,15 @@ import torch from ray.rllib import SampleBatch from ray.rllib.agents.impala import ImpalaTrainer -from ray.rllib.agents.impala.vtrace_torch_policy import build_vtrace_loss from ray.rllib.agents.impala.vtrace_torch_policy import VTraceTorchPolicy, VTraceLoss, make_time_major from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.torch_policy import LearningRateSchedule, EntropyCoeffSchedule -from tensorflow import sequence_mask +from ray.rllib.utils.torch_ops import sequence_mask from griddly.util.rllib.torch.conditional_actions.conditional_action_mixin import ConditionalActionMixin def build_invalid_masking_vtrace_loss(policy, model, dist_class, train_batch): - if not policy.config['env_config'].get('vtrace_masking', False): - return build_vtrace_loss(policy, model, dist_class, train_batch) model_out, _ = model.from_batch(train_batch) @@ -47,7 +44,7 @@ def _make_time_major(*args, **kw): else: mask = torch.ones_like(rewards) - model_out += torch.log(invalid_action_mask) + model_out += torch.maximum(torch.tensor(torch.finfo().min), torch.log(invalid_action_mask)) action_dist = dist_class(model_out, model) if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): @@ -116,5 +113,6 @@ def get_vtrace_policy_class(config): raise NotImplementedError('Tensorflow not supported') -ConditionalActionImpalaTrainer = ImpalaTrainer.with_updates(default_policy=ConditionalActionVTraceTorchPolicy, +ConditionalActionImpalaTrainer = ImpalaTrainer.with_updates(name="ConditionalActionImpalaTrainer", + default_policy=ConditionalActionVTraceTorchPolicy, get_policy_class=get_vtrace_policy_class)