Skip to content

Commit

Permalink
enh: option to only flatten actions
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Nov 4, 2024
1 parent c257550 commit 4626817
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 18 deletions.
48 changes: 45 additions & 3 deletions src/pystk2_gymnasium/stk_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
class PolarObservations(gym.ObservationWrapper):
"""Modifies position to polar positions
Angles are in radian
input: X right, Y up, Z forwards
output: (angle in the ZX plane, angle in the ZY plane, distance)
"""
Expand All @@ -43,16 +45,31 @@ def observation(self, obs):

for key in PolarObservations.KEYS:
v = obs[key]

is_tuple = False
if isinstance(v, tuple):
is_tuple = True
v = np.stack(v)
distance = np.linalg.norm(v, axis=1)
angle_zx = np.arctan2(v[:, 0], v[:, 2])
angle_zy = np.arctan2(v[:, 1], v[:, 2])
v[:, 0], v[:, 1], v[:, 2] = angle_zx, angle_zy, distance

if is_tuple:
obs[key] = tuple(x for x in v)
return obs


class ConstantSizedObservations(gym.ObservationWrapper):
def __init__(
self, env: gym.Env, *, state_items=5, state_karts=5, state_paths=5, **kwargs
self,
env: gym.Env,
*,
state_items=5,
state_karts=5,
state_paths=5,
add_mask=False,
**kwargs,
):
"""A simpler race environment with fixed width data
Expand Down Expand Up @@ -90,7 +107,19 @@ def __init__(
-float("inf"), float("inf"), shape=(self.state_karts, 3)
)

def make_tensor(self, state, name: str):
self.add_mask = add_mask
if add_mask:
space["paths_mask"] = spaces.Box(
0, 1, shape=(self.state_paths,), dtype=np.int8
)
space["items_mask"] = spaces.Box(
0, 1, shape=(self.state_items,), dtype=np.int8
)
space["karts_mask"] = spaces.Box(
0, 1, shape=(self.state_karts,), dtype=np.int8
)

def make_tensor(self, state, name: str, default_value=0):
value = state[name]
space = self.observation_space[name]

Expand All @@ -102,7 +131,9 @@ def make_tensor(self, state, name: str):
delta = space.shape[0] - value.shape[0]
if delta > 0:
shape = [delta] + list(space.shape[1:])
value = np.concatenate([value, np.zeros(shape, dtype=space.dtype)], axis=0)
value = np.concatenate(
[value, np.full(shape, default_value, dtype=space.dtype)], axis=0
)
elif delta < 0:
value = value[:delta]

Expand All @@ -115,6 +146,17 @@ def observation(self, state):
# Shallow copy
state = {**state}

# Add masks
def mask(length: int, size: int):
v = np.zeros((size,), dtype=np.int8)
v[:length] = 1
return v

if self.add_mask:
state["paths_mask"] = mask(len(state["paths_width"]), self.state_paths)
state["items_mask"] = mask(len(state["items_type"]), self.state_items)
state["karts_mask"] = mask(len(state["karts_position"]), self.state_karts)

# Ensures that the size of observations is constant
self.make_tensor(state, "paths_distance")
self.make_tensor(state, "paths_width")
Expand Down
38 changes: 23 additions & 15 deletions src/pystk2_gymnasium/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
This module contains generic wrappers
"""
from copy import copy
from typing import Any, Callable, Dict, List, SupportsFloat, Tuple

import gymnasium as gym
Expand Down Expand Up @@ -99,30 +100,37 @@ def discrete(self, observation):
class FlattenerWrapper(ActionObservationWrapper):
"""Flattens actions and observations."""

def __init__(self, env: gym.Env):
def __init__(self, env: gym.Env, flatten_observations=True):
super().__init__(env)

self.observation_flattener = SpaceFlattener(env.observation_space)
self.observation_space = self.observation_flattener.space
self.flatten_observations = flatten_observations
self.has_action = env.observation_space.get("action", None) is not None

self.action_flattener = SpaceFlattener(env.action_space)
self.action_space = self.action_flattener.space

# Adds action in the space
self.has_action = env.observation_space.get("action", None) is not None
if self.has_action:
if flatten_observations:
self.observation_flattener = SpaceFlattener(env.observation_space)
self.observation_space = self.observation_flattener.space
elif self.has_action:
self.observation_space = copy(env.observation_space)
self.observation_space["action"] = self.action_flattener.space

def observation(self, observation):
new_obs = {
"discrete": np.array(self.observation_flattener.discrete(observation)),
"continuous": np.concatenate(
[
observation[key].flatten()
for key in self.observation_flattener.continuous_keys
]
),
}
if self.flatten_observations:
new_obs = {
"discrete": np.array(self.observation_flattener.discrete(observation)),
"continuous": np.concatenate(
[
observation[key].flatten()
for key in self.observation_flattener.continuous_keys
]
),
}
elif self.has_action:
new_obs = {key: value for key, value in observation.items()}
else:
return observation

if self.has_action:
# Transforms from nested action to a flattened
Expand Down

0 comments on commit 4626817

Please sign in to comment.