Commit

refactor and working mypy
Bam4d committed Oct 11, 2023
1 parent 6e50a81 commit f1fde98
Showing 7 changed files with 447 additions and 433 deletions.
2 changes: 1 addition & 1 deletion python/griddly/gym.py
@@ -364,7 +364,7 @@ def enable_history(self, enable: bool = True) -> None:
         self.game.enable_history(enable)

     def step(  # type: ignore
-        self, action: Action
+        self, action: Union[Action, List[Action]]
     ) -> Tuple[
         Union[List[Observation], Observation],
         Union[List[int], int],
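For context, not part of the diff: the widened annotation lets mypy accept both call styles of step, a single Action in single-player games and a List[Action] with one entry per player in multi-player games. A minimal sketch under those assumptions; the GDY file name and action values are hypothetical, and the return shape assumes the gymnasium-style five-tuple:

    from griddly.gym import GymWrapper

    env = GymWrapper("my_game.yaml")  # hypothetical GDY description
    env.reset()

    # Single-player: step takes one Action.
    obs, reward, terminated, truncated, info = env.step(1)

    # Multi-player: step takes a List[Action], one per player; observations and
    # rewards come back as lists, matching the Union[List[Observation], Observation]
    # and Union[List[int], int] return annotations above.
    # obs, rewards, terminated, truncated, info = env.step([1, 3])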
98 changes: 98 additions & 0 deletions python/griddly/util/rllib/environment/base.py
@@ -0,0 +1,98 @@
import os
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union, Set

import numpy as np
import numpy.typing as npt
from ray.rllib import MultiAgentEnv
from ray.rllib.utils.typing import MultiAgentDict

from griddly.gym import GymWrapper
from griddly.spaces.action_space import MultiAgentActionSpace
from griddly.spaces.observation_space import MultiAgentObservationSpace
from griddly.typing import Action, ActionSpace, Observation, ObservationSpace
from griddly.util.rllib.environment.observer_episode_recorder import (
    ObserverEpisodeRecorder,
)


class _RLlibEnvCache:
    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.action_space: Optional[Union[ActionSpace, MultiAgentActionSpace]] = None
        self.observation_space: Optional[
            Union[ObservationSpace, MultiAgentObservationSpace]
        ] = None


class _RLlibEnv(ABC):
    def __init__(self, env_config: Dict[str, Any]) -> None:
        self._rllib_cache = _RLlibEnvCache()

        self._env = GymWrapper(**env_config, reset=False)

        self.env_config = env_config

        self.env_steps = 0
        self._agent_recorders: Optional[
            Union[ObserverEpisodeRecorder, List[ObserverEpisodeRecorder]]
        ] = None
        self._global_recorder: Optional[ObserverEpisodeRecorder] = None

        self._env_idx: Optional[int] = None
        self._worker_idx: Optional[int] = None

        self.video_initialized = False

        self.record_video_config = env_config.get("record_video_config", None)

        self.videos: List[Dict[str, Any]] = []

        if self.record_video_config is not None:
            self.video_frequency = self.record_video_config.get("frequency", 1000)
            self.fps = self.record_video_config.get("fps", 10)
            self.video_directory = os.path.realpath(
                self.record_video_config.get("directory", ".")
            )
            self.include_global_video = self.record_video_config.get(
                "include_global", True
            )
            self.include_agent_videos = self.record_video_config.get(
                "include_agents", False
            )
            os.makedirs(self.video_directory, exist_ok=True)

        self.record_actions = env_config.get("record_actions", False)

        self.generate_valid_action_trees = env_config.get(
            "generate_valid_action_trees", False
        )
        self._random_level_on_reset = env_config.get("random_level_on_reset", False)
        level_generator_rllib_config = env_config.get("level_generator", None)

        self._level_generator = None
        if level_generator_rllib_config is not None:
            level_generator_class = level_generator_rllib_config["class"]
            level_generator_config = level_generator_rllib_config["config"]
            self._level_generator = level_generator_class(level_generator_config)

        self._env.enable_history(self.record_actions)

    @property
    def width(self) -> int:
        assert self._env.observation_space.shape is not None
        return self._env.observation_space.shape[0]

    @property
    def height(self) -> int:
        assert self._env.observation_space.shape is not None
        return self._env.observation_space.shape[1]

    def _get_valid_action_trees(self) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        valid_action_trees = self._env.game.build_valid_action_trees()
        if self._env.player_count == 1:
            return valid_action_trees[0]
        return valid_action_trees
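For context, not part of the commit: everything _RLlibEnv.__init__ configures is read out of the single env_config dict, which is also forwarded wholesale to GymWrapper(**env_config, reset=False), so GymWrapper is assumed to tolerate the extra keys. A hedged sketch of a config exercising the keys read above; the GDY path and the level generator class are hypothetical:

    env_config = {
        "yaml_file": "my_game.yaml",  # hypothetical GDY description, consumed by GymWrapper

        # Optional video recording; defaults match the .get(...) fallbacks above.
        "record_video_config": {
            "frequency": 1000,        # episodes between recordings
            "fps": 10,
            "directory": "./videos",  # created via os.makedirs(..., exist_ok=True)
            "include_global": True,   # global-observer video
            "include_agents": False,  # per-agent videos
        },

        "record_actions": True,               # feeds self._env.enable_history(...)
        "generate_valid_action_trees": True,  # see _get_valid_action_trees
        "random_level_on_reset": False,

        # Optional level generator, instantiated as class(config).
        "level_generator": {
            "class": MyLevelGenerator,  # hypothetical generator class
            "config": {"width": 10, "height": 10},
        },
    }

Note that _get_valid_action_trees unwraps the single-player case: it returns one tree dict when player_count == 1 and a list of per-player trees otherwise.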
