diff --git a/python/examples/AStar Search/main.py b/python/examples/AStar Search/main.py index 2963483a..ed9ae5d5 100644 --- a/python/examples/AStar Search/main.py +++ b/python/examples/AStar Search/main.py @@ -1,6 +1,6 @@ from griddly import gd, gym from griddly.util.render_tools import RenderToVideo -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper if __name__ == "__main__": # Uncommment to see normal actions (not rotated) being used diff --git a/python/examples/Custom Shaders/Global Lighting/main.py b/python/examples/Custom Shaders/Global Lighting/main.py index af080b57..a2300bc9 100644 --- a/python/examples/Custom Shaders/Global Lighting/main.py +++ b/python/examples/Custom Shaders/Global Lighting/main.py @@ -1,6 +1,6 @@ from griddly import gd, gym from griddly.util.render_tools import RenderToFile, RenderToVideo -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper if __name__ == "__main__": env = gym( diff --git a/python/examples/Custom Shaders/Health Bars/main.py b/python/examples/Custom Shaders/Health Bars/main.py index 2fb62dbf..885ab045 100644 --- a/python/examples/Custom Shaders/Health Bars/main.py +++ b/python/examples/Custom Shaders/Health Bars/main.py @@ -1,6 +1,6 @@ from griddly import gd, gym from griddly.util.render_tools import RenderToFile, RenderToVideo -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper if __name__ == "__main__": env = gym( diff --git a/python/examples/Custom Shaders/Object Lighting/main.py b/python/examples/Custom Shaders/Object Lighting/main.py index 41f78b2a..ec43a4dd 100644 --- a/python/examples/Custom Shaders/Object Lighting/main.py +++ b/python/examples/Custom Shaders/Object Lighting/main.py @@ -1,6 +1,6 @@ from griddly import gd, gym from griddly.util.render_tools import RenderToFile, RenderToVideo -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper if __name__ == "__main__": env = gym( diff --git a/python/examples/Level Design/main.py b/python/examples/Level Design/main.py index 492beea7..c90ab9d9 100644 --- a/python/examples/Level Design/main.py +++ b/python/examples/Level Design/main.py @@ -1,6 +1,6 @@ from griddly import gd, gym from griddly.util.render_tools import RenderToFile -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper if __name__ == "__main__": env = gym( diff --git a/python/examples/Projectiles/main.py b/python/examples/Projectiles/main.py index e4972e23..048c1cea 100644 --- a/python/examples/Projectiles/main.py +++ b/python/examples/Projectiles/main.py @@ -1,6 +1,6 @@ from griddly import gd, gym from griddly.util.render_tools import RenderToVideo -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper if __name__ == "__main__": env = gym( diff --git a/python/examples/Proximity/main.py b/python/examples/Proximity/main.py index b6d8e8a7..cc8ee2ca 100644 --- a/python/examples/Proximity/main.py +++ b/python/examples/Proximity/main.py @@ -1,6 +1,6 @@ from griddly import gd, gym from griddly.util.render_tools import RenderToVideo -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper if __name__ == "__main__": env = gym( diff --git a/python/examples/Stochasticity/main.py b/python/examples/Stochasticity/main.py index ea521f7c..ed7d8096 100644 --- a/python/examples/Stochasticity/main.py +++ b/python/examples/Stochasticity/main.py @@ -1,6 +1,6 @@ from griddly import gd, gym from griddly.util.render_tools import RenderToVideo -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper if __name__ == "__main__": env = gym( diff --git a/python/griddly/__init__.py b/python/griddly/__init__.py index 2c4ee7e5..9c76a660 100644 --- a/python/griddly/__init__.py +++ b/python/griddly/__init__.py @@ -7,47 +7,6 @@ from griddly.gym import GymWrapperFactory -class GriddlyLoader: - def __init__(self) -> None: - module_path = os.path.dirname(os.path.realpath(__file__)) - self._image_path = os.path.join(module_path, "resources", "images") - self._shader_path = os.path.join(module_path, "resources", "shaders") - self._gdy_path = os.path.join(module_path, "resources", "games") - - self._gdy_reader = gd.GDYLoader( - self._gdy_path, self._image_path, self._shader_path - ) - - def get_full_path(self, gdy_path: str) -> str: - # Assume the file is relative first and if not, try to find it in the pre-defined games - fullpath = ( - gdy_path - if os.path.exists(gdy_path) - else os.path.join(self._gdy_path, gdy_path) - ) - # (for debugging only) look in parent directory resources because we might not have built the latest version - fullpath = ( - fullpath - if os.path.exists(fullpath) - else os.path.realpath( - os.path.join( - self._gdy_path + "../../../../../resources/games", gdy_path - ) - ) - ) - return fullpath - - def load(self, gdy_path: str) -> gd.GDY: - return self._gdy_reader.load(self.get_full_path(gdy_path)) - - def load_string(self, yaml_string: str) -> gd.GDY: - return self._gdy_reader.load_string(yaml_string) - - def load_gdy(self, gdy_path: str) -> Dict[str, Any]: - with open(self.get_full_path(gdy_path)) as gdy_file: - return yaml.load(gdy_file, Loader=yaml.SafeLoader) # type: ignore - - def preload_default_envs() -> None: module_path = os.path.dirname(os.path.realpath(__file__)) game_path = os.path.join(module_path, "resources", "games") diff --git a/python/griddly/gym.py b/python/griddly/gym.py index cfa039ef..a77ebc1b 100644 --- a/python/griddly/gym.py +++ b/python/griddly/gym.py @@ -8,7 +8,7 @@ from gymnasium.envs.registration import register from gymnasium.spaces import Discrete, MultiDiscrete -from griddly import GriddlyLoader +from griddly.loader import GriddlyLoader from griddly import gd as gd from griddly.spaces.action_space import MultiAgentActionSpace from griddly.spaces.observation_space import ( @@ -377,58 +377,67 @@ def step( # type: ignore """ player_id = 0 - reward: Union[List[int], int] - action_data = np.array(action, dtype=np.int32).reshape(1, -1) + if self.player_count == 1: + action = np.array(action, dtype=np.int32).reshape(1, -1, len(self.action_space_parts)) - if len(action_data) != self.player_count: - raise ValueError( - f"The supplied action is in the wrong format for this environment.\n\n" - f"A valid example: {self.action_space.sample()}" - ) + max_num_actions = 0 + for a in action: + if len(action) > max_num_actions: + max_num_actions = len(action) - # Simple agents executing single actions or multiple actions in a single time step - if self.player_count == 1: - reward, done, truncated, info = self._players[player_id].step_multi( - action_data, True - ) + action_data = np.zeros((self.player_count, max_num_actions, len(self.action_space_parts)), dtype=np.int32) - else: - processed_actions = [] - multi_action = False - - # Replace any None actions with a zero action - for a in action_data: - processed_action = ( - a - if a is not None - else np.zeros((len(self.action_space_parts)), dtype=np.int32) - ) - processed_actions.append(processed_action) - if len(processed_action.shape) > 1 and processed_action.shape[0] > 1: - multi_action = True - - if not self.has_avatar and multi_action: - # Multiple agents that can perform multiple actions in parallel - # Used in RTS games - reward = [] - for p in range(self.player_count): - player_action = processed_actions[p].reshape( - -1, len(self.action_space_parts) - ) - final = p == self.player_count - 1 - rew, done, truncated, info = self._players[p].step_multi( - player_action, final - ) - reward.append(rew) + for p in range(self.player_count): + for i, a in enumerate(action[p]): + action_data[p, i] = a - # Multiple agents executing actions in parallel - # Used in multi-agent environments - else: - action_data = np.array(processed_actions, dtype=np.int32) - action_data = action_data.reshape(self.player_count, -1) - reward, done, truncated, info = self.game.step_parallel(action_data) + + reward, done, truncated, info = self.game.step_parallel(action_data) + + # Simple agents executing single actions or multiple actions in a single time step + # if self.player_count == 1: + + # action_data = np.array(action, dtype=np.int32).reshape(-1, len(self.action_space_parts)) + + # reward, done, truncated, info = self._players[player_id].step_multi( + # action_data, True + # ) + + # else: + + # processed_actions = [] + # multi_action = False + + # # Replace any None actions with a zero action + # for a in action: + # processed_action = ( + # np.array(a, dtype=np.int32).reshape(-1, len(self.action_space_parts)) + # if a is not None + # else np.zeros((1, len(self.action_space_parts)), dtype=np.int32) + # ) + # processed_actions.append(processed_action) + # if len(processed_action.shape) > 1 and processed_action.shape[0] > 1: + # multi_action = True + + # if not self.has_avatar and multi_action: + # # Multiple agents that can perform multiple actions in parallel + # # Used in RTS games + # reward = [] + # for p in range(self.player_count): + # final = p == self.player_count - 1 + # rew, done, truncated, info = self._players[p].step_multi( + # action, final + # ) + # reward.append(rew) + + # # Multiple agents executing actions in parallel + # # Used in multi-agent environments + # else: + # action_data = np.array(processed_actions, dtype=np.int32) + # action_data = action_data.reshape(self.player_count, -1) + # reward, done, truncated, info = self.game.step_parallel(action_data) # In the case where the environment is cloned, but no step has happened to replace the last obs, # we can do that here diff --git a/python/griddly/loader.py b/python/griddly/loader.py new file mode 100644 index 00000000..c89585e0 --- /dev/null +++ b/python/griddly/loader.py @@ -0,0 +1,46 @@ +import os +from typing import Any, Dict + +import yaml + +from griddly import gd + +class GriddlyLoader: + def __init__(self) -> None: + module_path = os.path.dirname(os.path.realpath(__file__)) + self._image_path = os.path.join(module_path, "resources", "images") + self._shader_path = os.path.join(module_path, "resources", "shaders") + self._gdy_path = os.path.join(module_path, "resources", "games") + + self._gdy_reader = gd.GDYLoader( + self._gdy_path, self._image_path, self._shader_path + ) + + def get_full_path(self, gdy_path: str) -> str: + # Assume the file is relative first and if not, try to find it in the pre-defined games + fullpath = ( + gdy_path + if os.path.exists(gdy_path) + else os.path.join(self._gdy_path, gdy_path) + ) + # (for debugging only) look in parent directory resources because we might not have built the latest version + fullpath = ( + fullpath + if os.path.exists(fullpath) + else os.path.realpath( + os.path.join( + self._gdy_path + "../../../../../resources/games", gdy_path + ) + ) + ) + return fullpath + + def load(self, gdy_path: str) -> gd.GDY: + return self._gdy_reader.load(self.get_full_path(gdy_path)) + + def load_string(self, yaml_string: str) -> gd.GDY: + return self._gdy_reader.load_string(yaml_string) + + def load_gdy(self, gdy_path: str) -> Dict[str, Any]: + with open(self.get_full_path(gdy_path)) as gdy_file: + return yaml.load(gdy_file, Loader=yaml.SafeLoader) # type: ignore \ No newline at end of file diff --git a/python/griddly/spaces/action_space.py b/python/griddly/spaces/action_space.py index b2477566..65b4dd58 100644 --- a/python/griddly/spaces/action_space.py +++ b/python/griddly/spaces/action_space.py @@ -1,10 +1,14 @@ -from typing import Any, List, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Optional, Union import numpy as np from gymnasium.spaces import Discrete, MultiDiscrete, Space -from griddly.gym import GymWrapper -from griddly.typing import Action +from griddly.typing import Action, ActionSpace + +if TYPE_CHECKING: + from griddly.gym import GymWrapper class MultiAgentActionSpace(Space[List[Action]], list): @@ -39,22 +43,13 @@ class ValidatedActionSpace(Space[Union[Action, List[Action]]]): def __init__( self, - action_space: Union[Discrete, MultiAgentActionSpace], + action_space: Space[Union[Action, List[Action]]], masking_wrapper: GymWrapper, ) -> None: self._masking_wrapper = masking_wrapper - shape = None - dtype = None - - if isinstance(action_space, Discrete) or isinstance( - action_space, MultiDiscrete - ): - shape = action_space.shape - dtype = action_space.dtype - elif isinstance(action_space, MultiAgentActionSpace): - shape = action_space[0].shape - dtype = action_space[0].dtype + shape = action_space.shape + dtype = action_space.dtype self.action_space = action_space diff --git a/python/griddly/typing.py b/python/griddly/typing.py index 7e4a6577..6357d7da 100644 --- a/python/griddly/typing.py +++ b/python/griddly/typing.py @@ -8,3 +8,5 @@ ObservationSpace = Space[Observation] ActionSpace = Space[Action] + + diff --git a/python/griddly/util/breakdown.py b/python/griddly/util/breakdown.py index 1eba29a3..38c5ab28 100644 --- a/python/griddly/util/breakdown.py +++ b/python/griddly/util/breakdown.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import Any, Dict, Union, List +from typing import Any, Dict, Union import numpy as np import numpy.typing as npt import yaml -from griddly import GriddlyLoader, gd +from griddly.loader import GriddlyLoader +from griddly import gd from griddly.util.vector_visualization import Vector2RGB diff --git a/python/griddly/util/rllib/environment/base.py b/python/griddly/util/rllib/environment/base.py index ea1d1961..cf7fd2a2 100644 --- a/python/griddly/util/rllib/environment/base.py +++ b/python/griddly/util/rllib/environment/base.py @@ -1,17 +1,11 @@ import os -from abc import ABC, abstractmethod -from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, Union, Set - -import numpy as np -import numpy.typing as npt -from ray.rllib import MultiAgentEnv -from ray.rllib.utils.typing import MultiAgentDict +from abc import ABC +from typing import Any, Dict, List, Optional, Union from griddly.gym import GymWrapper from griddly.spaces.action_space import MultiAgentActionSpace from griddly.spaces.observation_space import MultiAgentObservationSpace -from griddly.typing import Action, ActionSpace, Observation, ObservationSpace +from griddly.typing import ActionSpace, ObservationSpace from griddly.util.rllib.environment.observer_episode_recorder import ( ObserverEpisodeRecorder, ) diff --git a/python/griddly/util/rllib/environment/single_agent.py b/python/griddly/util/rllib/environment/single_agent.py index 8bd81ebf..2a396ea6 100644 --- a/python/griddly/util/rllib/environment/single_agent.py +++ b/python/griddly/util/rllib/environment/single_agent.py @@ -3,10 +3,7 @@ import numpy as np import numpy.typing as npt -from griddly.gym import GymWrapper -from griddly.spaces.action_space import MultiAgentActionSpace -from griddly.spaces.observation_space import MultiAgentObservationSpace -from griddly.typing import Action, ActionSpace, Observation, ObservationSpace +from griddly.typing import Action, Observation from griddly.util.rllib.environment.base import _RLlibEnv from griddly.util.rllib.environment.observer_episode_recorder import \ ObserverEpisodeRecorder diff --git a/python/griddly/wrappers/render_wrapper.py b/python/griddly/wrappers/render_wrapper.py index 1a88f0b6..94f814e7 100644 --- a/python/griddly/wrappers/render_wrapper.py +++ b/python/griddly/wrappers/render_wrapper.py @@ -1,12 +1,18 @@ -from typing import Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Union import gymnasium as gym import numpy.typing as npt -from griddly.gym import GymWrapper +if TYPE_CHECKING: + from griddly.gym import GymWrapper class RenderWrapper(gym.Wrapper): + + env: GymWrapper + def __init__( self, env: GymWrapper, observer: Union[str, int] = 0, render_mode: str = "human" ) -> None: @@ -40,8 +46,6 @@ def __init__( self._observer = observer self._render_mode = render_mode - assert isinstance(self.env, GymWrapper) - if observer == "global": self.observation_space = env.global_observation_space elif isinstance(observer, int): @@ -55,7 +59,6 @@ def __init__( ) def render(self) -> Union[str, npt.NDArray]: # type: ignore - assert isinstance(self.env, GymWrapper) return self.env.render_observer(self._observer, self._render_mode) @property diff --git a/python/griddly/wrappers/valid_action_space_wrapper.py b/python/griddly/wrappers/valid_action_space_wrapper.py index 9a8cba28..9fd5bd2c 100644 --- a/python/griddly/wrappers/valid_action_space_wrapper.py +++ b/python/griddly/wrappers/valid_action_space_wrapper.py @@ -28,6 +28,8 @@ class ValidActionSpaceWrapper(gym.Wrapper): policy gradient methods. """ + env: GymWrapper + def __init__(self, env: GymWrapper) -> None: if env.action_space is None or env.observation_space is None: raise RuntimeError( @@ -36,11 +38,7 @@ def __init__(self, env: GymWrapper) -> None: super().__init__(env) - assert isinstance( - self.env, GymWrapper - ), "Invalid environment type. Can only wrap GymWrapper" - - self.action_space = self._override_action_space() + self.action_space = ValidatedActionSpace(self.action_space, self.env) def get_unit_location_mask( self, player_id: int, mask_type: str = "full" @@ -56,10 +54,6 @@ def get_unit_location_mask( assert player_id <= self.player_count, "Player does not exist." assert player_id > 0, "Player 0 is reserved for internal actions only." - assert isinstance( - self.env, GymWrapper - ), "Invalid environment type. Can only wrap GymWrapper" - if mask_type == "full": grid_mask = np.zeros((self.grid_width, self.grid_height)) for location, action_names in self.env.game.get_available_actions( @@ -92,10 +86,6 @@ def get_unit_action_mask( :return: """ - assert isinstance( - self.env, GymWrapper - ), "Invalid environment type. Can only wrap GymWrapper" - action_masks = {} for action_name, action_ids in self.env.game.get_available_action_ids( location, action_names @@ -111,17 +101,5 @@ def get_unit_action_mask( return action_masks - def _override_action_space(self) -> ValidatedActionSpace: - assert isinstance(self.action_space, gym.spaces.Discrete) or isinstance( - self.action_space, MultiAgentActionSpace - ), "Invalid action space type. Can only wrap Discrete or MultiAgentActionSpace" - assert isinstance( - self.env, GymWrapper - ), "Invalid environment type. Can only wrap GymWrapper" - return ValidatedActionSpace(self.action_space, self.env) - def clone(self) -> ValidActionSpaceWrapper: - assert isinstance( - self.env, GymWrapper - ), "Invalid environment type. Can only wrap GymWrapper" return ValidActionSpaceWrapper(self.env.clone()) diff --git a/python/pyproject.toml b/python/pyproject.toml index 09109344..e342359b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -23,7 +23,7 @@ show_error_codes = true no_implicit_optional = true warn_return_any = true warn_unused_ignores = true -exclude = ["docs", "tests", "examples", "scratchpad", "tools"] +exclude = ["docs", "tests", "examples", "scratchpad", "tools", "build"] [tool.poetry.dependencies] python = "^3.8, <3.12" diff --git a/python/tests/cat_test.py b/python/tests/cat_test.py index 696b0edd..1fc2f6e0 100644 --- a/python/tests/cat_test.py +++ b/python/tests/cat_test.py @@ -1,31 +1,20 @@ -import gymnasium as gym -import pytest +from griddly import gd +from griddly.gym import GymWrapper -from griddly import GymWrapperFactory, gd - -@pytest.fixture -def test_name(request): - return request.node.name - - -def build_test_env(test_name, yaml_file): - wrapper_factory = GymWrapperFactory() - - wrapper_factory.build_gym_from_yaml( - test_name, - yaml_file, +def build_test_env(yaml_file): + env = GymWrapper( + yaml_file=yaml_file, global_observer_type=gd.ObserverType.VECTOR, player_observer_type=gd.ObserverType.VECTOR, ) - env = gym.make(f"GDY-{test_name}-v0") env.reset() return env -def test_CAT_depth_1(test_name): - env = build_test_env(test_name, "tests/gdy/test_CAT_depth_1.yaml") +def test_CAT_depth_1(): + env = build_test_env("tests/gdy/test_CAT_depth_1.yaml") valid_action_trees = env.game.build_valid_action_trees() @@ -33,8 +22,8 @@ def test_CAT_depth_1(test_name): assert set(valid_action_trees[0].keys()) == {0, 1, 2, 3} -def test_CAT_depth_2(test_name): - env = build_test_env(test_name, "tests/gdy/test_CAT_depth_2.yaml") +def test_CAT_depth_2(): + env = build_test_env("tests/gdy/test_CAT_depth_2.yaml") valid_action_trees = env.game.build_valid_action_trees() @@ -45,8 +34,8 @@ def test_CAT_depth_2(test_name): assert set(valid_action_trees[0][1].keys()) == {0, 4} -def test_CAT_depth_3(test_name): - env = build_test_env(test_name, "tests/gdy/test_CAT_depth_3.yaml") +def test_CAT_depth_3(): + env = build_test_env("tests/gdy/test_CAT_depth_3.yaml") valid_action_trees = env.game.build_valid_action_trees() @@ -56,8 +45,8 @@ def test_CAT_depth_3(test_name): assert set(valid_action_trees[0][1][1].keys()) == {0, 1, 2, 3} -def test_CAT_depth_4(test_name): - env = build_test_env(test_name, "tests/gdy/test_CAT_depth_4.yaml") +def test_CAT_depth_4(): + env = build_test_env("tests/gdy/test_CAT_depth_4.yaml") valid_action_trees = env.game.build_valid_action_trees() @@ -70,8 +59,8 @@ def test_CAT_depth_4(test_name): assert set(valid_action_trees[0][1][1][1].keys()) == {0, 4} -def test_CAT_depth_4_2_players(test_name): - env = build_test_env(test_name, "tests/gdy/test_CAT_depth_4_2_players.yaml") +def test_CAT_depth_4_2_players(): + env = build_test_env("tests/gdy/test_CAT_depth_4_2_players.yaml") valid_action_trees = env.game.build_valid_action_trees() diff --git a/python/tests/egg_test.py b/python/tests/egg_test.py index b9944a80..2d72cfab 100644 --- a/python/tests/egg_test.py +++ b/python/tests/egg_test.py @@ -4,7 +4,7 @@ from griddly import gd from griddly.gym import GymWrapperFactory from griddly.util.environment_generator_generator import EnvironmentGeneratorGenerator -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper @pytest.fixture diff --git a/python/tests/partial_observability_test.py b/python/tests/partial_observability_test.py index 0a6c2811..343db316 100644 --- a/python/tests/partial_observability_test.py +++ b/python/tests/partial_observability_test.py @@ -3,7 +3,7 @@ from griddly import gd from griddly.gym import GymWrapperFactory -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper @pytest.fixture @@ -40,7 +40,7 @@ def test_partial_observability_0_1(test_name): obs, reward, done, truncated, info = env.step([0, 0]) player1_obs = obs[0] - player2_obs = obs[1] + player2_obs = obs[1] assert env.player_observation_space[0].shape == (1, 3, 3) assert env.player_observation_space[1].shape == (1, 3, 3) diff --git a/python/tests/random_seed_test.py b/python/tests/random_seed_test.py index 4a2b0dd9..c78c2861 100644 --- a/python/tests/random_seed_test.py +++ b/python/tests/random_seed_test.py @@ -2,7 +2,7 @@ from griddly import gd from griddly.gym import GymWrapper -from griddly.wrappers import RenderWrapper +from griddly.wrappers.render_wrapper import RenderWrapper def create_env(seed): diff --git a/python/tests/rllib_test.py b/python/tests/rllib_test.py index 25e624f5..64eb0d33 100644 --- a/python/tests/rllib_test.py +++ b/python/tests/rllib_test.py @@ -12,7 +12,8 @@ from griddly import gd from griddly.util.rllib.callbacks import VideoCallbacks -from griddly.util.rllib.environment.base import RLlibEnv, RLlibMultiAgentWrapper +from griddly.util.rllib.environment.single_agent import RLlibEnv +from griddly.util.rllib.environment.multi_agent import RLlibMultiAgentWrapper def count_videos(video_dir): diff --git a/python/tests/valid_action_space_wrapper_test.py b/python/tests/valid_action_space_wrapper_test.py index 1c06c971..c23a972e 100644 --- a/python/tests/valid_action_space_wrapper_test.py +++ b/python/tests/valid_action_space_wrapper_test.py @@ -3,7 +3,7 @@ import pytest from griddly import GymWrapperFactory, gd -from griddly.wrappers import ValidActionSpaceWrapper +from griddly.wrappers.valid_action_space_wrapper import ValidActionSpaceWrapper @pytest.fixture diff --git a/python/tools/package_resources.py b/python/tools/package_resources.py index 2d51c799..5ce7ef50 100644 --- a/python/tools/package_resources.py +++ b/python/tools/package_resources.py @@ -3,9 +3,10 @@ import shutil from pathlib import Path from sys import platform +from typing import List -def get_libs(root_path: Path, config: str = "Debug") -> list[str]: +def get_libs(root_path: Path, config: str = "Debug") -> List[str]: libs_path = Path.joinpath(root_path.parent, f"{config}/bin").resolve() libs_to_copy: list[str] = []