Skip to content

Commit

Permalink
BitFlippingEnv argument check and docs clarification (#1698)
Browse files Browse the repository at this point in the history
* made change, not tested yet

* add back _obs_space with note on purpose

* match formatting

* update documentation
  • Loading branch information
kylesayrs authored Sep 27, 2023
1 parent 2ca94cb commit fab6cb3
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 51 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Bug Fixes:
- Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()``
- Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
- Fixed check_env for Sequence observation space (@corentinlger)
- Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down
121 changes: 70 additions & 51 deletions stable_baselines3/common/envs/bit_flipping_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@ class BitFlippingEnv(Env):
Simple bit flipping env, useful to test HER.
The goal is to flip all the bits to get a vector of ones.
In the continuous variant, if the ith action component has a value > 0,
then the ith bit will be flipped.
then the ith bit will be flipped. Uses a ``MultiBinary`` observation space
by default.
:param n_bits: Number of bits to flip
:param continuous: Whether to use the continuous actions version or not,
by default, it uses the discrete one
:param max_steps: Max number of steps, by default, equal to n_bits
:param discrete_obs_space: Whether to use the discrete observation
version or not, by default, it uses the ``MultiBinary`` one
:param image_obs_space: Use image as input instead of the ``MultiBinary`` one.
version or not, ie a one-hot encoding of all possible states
:param image_obs_space: Whether to use an image observation version
or not, ie a greyscale image of the state
:param channel_first: Whether to use channel-first or last image.
"""

Expand All @@ -44,52 +46,11 @@ def __init__(
self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
# The achieved goal is determined by the current state
# here, it is a special where they are equal
if discrete_obs_space:
# In the discrete case, the agent act on the binary
# representation of the observation
self.observation_space = spaces.Dict(
{
"observation": spaces.Discrete(2**n_bits),
"achieved_goal": spaces.Discrete(2**n_bits),
"desired_goal": spaces.Discrete(2**n_bits),
}
)
elif image_obs_space:
# When using image as input,
# one image contains the bits 0 -> 0, 1 -> 255
# and the rest is filled with zeros
self.observation_space = spaces.Dict(
{
"observation": spaces.Box(
low=0,
high=255,
shape=self.image_shape,
dtype=np.uint8,
),
"achieved_goal": spaces.Box(
low=0,
high=255,
shape=self.image_shape,
dtype=np.uint8,
),
"desired_goal": spaces.Box(
low=0,
high=255,
shape=self.image_shape,
dtype=np.uint8,
),
}
)
else:
self.observation_space = spaces.Dict(
{
"observation": spaces.MultiBinary(n_bits),
"achieved_goal": spaces.MultiBinary(n_bits),
"desired_goal": spaces.MultiBinary(n_bits),
}
)

self.obs_space = spaces.MultiBinary(n_bits)
# observation space for observations given to the model
self.observation_space = self._make_observation_space(discrete_obs_space, image_obs_space, n_bits)
# observation space used to update internal state
self._obs_space = spaces.MultiBinary(n_bits)

if continuous:
self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32)
Expand All @@ -105,7 +66,7 @@ def __init__(
self.current_step = 0

def seed(self, seed: int) -> None:
self.obs_space.seed(seed)
self._obs_space.seed(seed)

def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
"""
Expand Down Expand Up @@ -144,6 +105,64 @@ def convert_to_bit_vector(self, state: Union[int, np.ndarray], batch_size: int)
bit_vector = np.array(state).reshape(batch_size, -1)
return bit_vector

def _make_observation_space(self, discrete_obs_space: bool, image_obs_space: bool, n_bits: int) -> spaces.Dict:
"""
Helper to create observation space
:param discrete_obs_space: Whether to use the discrete observation version
:param image_obs_space: Whether to use the image observation version
:param n_bits: The number of bits used to represent the state
:return: the environment observation space
"""
if discrete_obs_space and image_obs_space:
raise ValueError("Cannot use both discrete and image observation spaces")

if discrete_obs_space:
# In the discrete case, the agent act on the binary
# representation of the observation
return spaces.Dict(
{
"observation": spaces.Discrete(2**n_bits),
"achieved_goal": spaces.Discrete(2**n_bits),
"desired_goal": spaces.Discrete(2**n_bits),
}
)

if image_obs_space:
# When using image as input,
# one image contains the bits 0 -> 0, 1 -> 255
# and the rest is filled with zeros
return spaces.Dict(
{
"observation": spaces.Box(
low=0,
high=255,
shape=self.image_shape,
dtype=np.uint8,
),
"achieved_goal": spaces.Box(
low=0,
high=255,
shape=self.image_shape,
dtype=np.uint8,
),
"desired_goal": spaces.Box(
low=0,
high=255,
shape=self.image_shape,
dtype=np.uint8,
),
}
)

return spaces.Dict(
{
"observation": spaces.MultiBinary(n_bits),
"achieved_goal": spaces.MultiBinary(n_bits),
"desired_goal": spaces.MultiBinary(n_bits),
}
)

def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
"""
Helper to create the observation.
Expand All @@ -162,9 +181,9 @@ def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict] = None
) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]:
if seed is not None:
self.obs_space.seed(seed)
self._obs_space.seed(seed)
self.current_step = 0
self.state = self.obs_space.sample()
self.state = self._obs_space.sample()
return self._get_obs(), {}

def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
Expand Down

0 comments on commit fab6cb3

Please sign in to comment.