diff --git a/src/pystk2_gymnasium/stk_wrappers.py b/src/pystk2_gymnasium/stk_wrappers.py index 54c2158..ce9d580 100644 --- a/src/pystk2_gymnasium/stk_wrappers.py +++ b/src/pystk2_gymnasium/stk_wrappers.py @@ -18,6 +18,8 @@ class PolarObservations(gym.ObservationWrapper): """Modifies position to polar positions + Angles are in radians + input: X right, Y up, Z forwards output: (angle in the ZX plane, angle in the ZY plane, distance) """ @@ -43,16 +45,31 @@ def observation(self, obs): for key in PolarObservations.KEYS: v = obs[key] + + is_tuple = False + if isinstance(v, tuple): + is_tuple = True + v = np.stack(v) distance = np.linalg.norm(v, axis=1) angle_zx = np.arctan2(v[:, 0], v[:, 2]) angle_zy = np.arctan2(v[:, 1], v[:, 2]) v[:, 0], v[:, 1], v[:, 2] = angle_zx, angle_zy, distance + + if is_tuple: + obs[key] = tuple(x for x in v) return obs class ConstantSizedObservations(gym.ObservationWrapper): def __init__( - self, env: gym.Env, *, state_items=5, state_karts=5, state_paths=5, **kwargs + self, + env: gym.Env, + *, + state_items=5, + state_karts=5, + state_paths=5, + add_mask=False, + **kwargs, ): """A simpler race environment with fixed width data @@ -90,7 +107,19 @@ def __init__( -float("inf"), float("inf"), shape=(self.state_karts, 3) ) - def make_tensor(self, state, name: str): + self.add_mask = add_mask + if add_mask: + space["paths_mask"] = spaces.Box( + 0, 1, shape=(self.state_paths,), dtype=np.int8 + ) + space["items_mask"] = spaces.Box( + 0, 1, shape=(self.state_items,), dtype=np.int8 + ) + space["karts_mask"] = spaces.Box( + 0, 1, shape=(self.state_karts,), dtype=np.int8 + ) + + def make_tensor(self, state, name: str, default_value=0): value = state[name] space = self.observation_space[name] @@ -102,7 +131,9 @@ def make_tensor(self, state, name: str): delta = space.shape[0] - value.shape[0] if delta > 0: shape = [delta] + list(space.shape[1:]) - value = np.concatenate([value, np.zeros(shape, dtype=space.dtype)], axis=0) + value = np.concatenate( + 
[value, np.full(shape, default_value, dtype=space.dtype)], axis=0 + ) elif delta < 0: value = value[:delta] @@ -115,6 +146,17 @@ def observation(self, state): # Shallow copy state = {**state} + # Add masks + def mask(length: int, size: int): + v = np.zeros((size,), dtype=np.int8) + v[:length] = 1 + return v + + if self.add_mask: + state["paths_mask"] = mask(len(state["paths_width"]), self.state_paths) + state["items_mask"] = mask(len(state["items_type"]), self.state_items) + state["karts_mask"] = mask(len(state["karts_position"]), self.state_karts) + # Ensures that the size of observations is constant self.make_tensor(state, "paths_distance") self.make_tensor(state, "paths_width") diff --git a/src/pystk2_gymnasium/wrappers.py b/src/pystk2_gymnasium/wrappers.py index 47d471d..868a27e 100644 --- a/src/pystk2_gymnasium/wrappers.py +++ b/src/pystk2_gymnasium/wrappers.py @@ -1,6 +1,7 @@ """ This module contains generic wrappers """ +from copy import copy from typing import Any, Callable, Dict, List, SupportsFloat, Tuple import gymnasium as gym @@ -99,30 +100,37 @@ def discrete(self, observation): class FlattenerWrapper(ActionObservationWrapper): """Flattens actions and observations.""" - def __init__(self, env: gym.Env): + def __init__(self, env: gym.Env, flatten_observations=True): super().__init__(env) - self.observation_flattener = SpaceFlattener(env.observation_space) - self.observation_space = self.observation_flattener.space + self.flatten_observations = flatten_observations + self.has_action = env.observation_space.get("action", None) is not None self.action_flattener = SpaceFlattener(env.action_space) self.action_space = self.action_flattener.space - # Adds action in the space - self.has_action = env.observation_space.get("action", None) is not None - if self.has_action: + if flatten_observations: + self.observation_flattener = SpaceFlattener(env.observation_space) + self.observation_space = self.observation_flattener.space + elif self.has_action: + 
self.observation_space = copy(env.observation_space) self.observation_space["action"] = self.action_flattener.space def observation(self, observation): - new_obs = { - "discrete": np.array(self.observation_flattener.discrete(observation)), - "continuous": np.concatenate( - [ - observation[key].flatten() - for key in self.observation_flattener.continuous_keys - ] - ), - } + if self.flatten_observations: + new_obs = { + "discrete": np.array(self.observation_flattener.discrete(observation)), + "continuous": np.concatenate( + [ + observation[key].flatten() + for key in self.observation_flattener.continuous_keys + ] + ), + } + elif self.has_action: + new_obs = {key: value for key, value in observation.items()} + else: + return observation if self.has_action: # Transforms from nested action to a flattened