Skip to content

Commit

Permalink
Handling multi-dimensional action spaces (#971)
Browse files Browse the repository at this point in the history
* Handle non-1D action shapes

* Revert changes of observation (out of the scope of this PR)

* Apply changes to DictReplayBuffer

* Update tests

* Rollout buffer n-D action space handling

* Remove the error raised for non-1D action spaces

* ActorCriticPolicy return action with the proper shape

* remove useless reshape

* Update changelog

* Add tests

Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
qgallouedec and araffin authored Aug 6, 2022
1 parent 6ce33f5 commit c4f54fc
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Bug Fixes:
^^^^^^^^^^
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
- Fixed division by zero error when computing FPS when only a small amount of time has elapsed on operating systems with low-precision timers.
- Added multidimensional action space support (@qgallouedec)

Deprecations:
^^^^^^^^^^^^^
Expand Down
9 changes: 5 additions & 4 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,7 @@ def add(
next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)

# Same, for actions
if isinstance(self.action_space, spaces.Discrete):
action = action.reshape((self.n_envs, self.action_dim))
action = action.reshape((self.n_envs, self.action_dim))

# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs).copy()
Expand Down Expand Up @@ -433,6 +432,9 @@ def add(
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs,) + self.obs_shape)

# Same reshape, for actions
action = action.reshape((self.n_envs, self.action_dim))

self.observations[self.pos] = np.array(obs).copy()
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
Expand Down Expand Up @@ -586,8 +588,7 @@ def add(
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()

# Same reshape, for actions
if isinstance(self.action_space, spaces.Discrete):
action = action.reshape((self.n_envs, self.action_dim))
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
Expand Down
1 change: 0 additions & 1 deletion stable_baselines3/common/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,6 @@ def make_proba_distribution(
dist_kwargs = {}

if isinstance(action_space, spaces.Box):
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
Expand Down
5 changes: 3 additions & 2 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ def predict(

with th.no_grad():
actions = self._predict(observation, deterministic=deterministic)
# Convert to numpy
actions = actions.cpu().numpy()
# Convert to numpy, and reshape to the original action shape
actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)

if isinstance(self.action_space, gym.spaces.Box):
if self.squash_output:
Expand Down Expand Up @@ -592,6 +592,7 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1,) + self.action_space.shape)
return actions, values, log_prob

def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
Expand Down
36 changes: 33 additions & 3 deletions tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ def step(self, action):
return self.observation_space.sample(), 0.0, False, {}


class DummyMultidimensionalAction(gym.Env):
    """Minimal test environment whose action space is a 2-D ``Box``.

    The observation space is a ``Box`` of shape ``(2,)`` and the action space
    is a ``Box`` of shape ``(2, 2)``, both bounded in ``[-1, 1]`` with dtype
    ``float32``. The environment has no dynamics: observations are random
    samples and the action passed to ``step`` is ignored.
    """

    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        # Deliberately not a flat vector: this is the case under test.
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)

    def reset(self):
        """Return a random sample drawn from the observation space."""
        return self.observation_space.sample()

    def step(self, action):
        """Ignore ``action`` and return a fixed-form 4-tuple.

        The tuple is ``(random observation, 0.0, False, {})`` -- presumably
        ``(obs, reward, done, info)`` as in the classic Gym step API.
        """
        return self.observation_space.sample(), 0.0, False, {}


@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
def test_identity_spaces(model_class, env):
Expand All @@ -53,22 +66,39 @@ def test_identity_spaces(model_class, env):


# Checks, for every (algorithm, environment) pair, that the model either can be
# constructed (and, for the multi-dimensional action env, trained for a few steps)
# or is rejected with an AssertionError at construction time.
# NOTE(review): this text is a rendered diff without +/- markers (the page header
# reports 3 deletions for this file), so three statements below appear twice:
# presumably the first of each pair is the removed line and the second is its
# replacement -- confirm against the actual commit before running this code.
@pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
# NOTE(review): presumably the pre-change decorator, superseded by the next line.
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
def test_action_spaces(model_class, env):
# Extra constructor arguments, filled in per algorithm family below.
kwargs = {}
if model_class in [SAC, DDPG, TD3]:
# NOTE(review): presumably the pre-change assignment, overwritten by the next line.
supported_action_space = env == "Pendulum-v1"
# These algorithms are expected to accept Pendulum and the multi-dimensional dummy env only.
supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction)
kwargs["learning_starts"] = 2
kwargs["train_freq"] = 32
elif model_class == DQN:
# DQN is expected to accept CartPole only.
supported_action_space = env == "CartPole-v1"
elif model_class in [A2C, PPO]:
# A2C and PPO are expected to accept every environment in the parametrization.
supported_action_space = True
kwargs["n_steps"] = 64

if supported_action_space:
# NOTE(review): presumably the pre-change construction (no kwargs), superseded by the next line.
model_class("MlpPolicy", env)
model = model_class("MlpPolicy", env, **kwargs)
# Only the multi-dimensional action env is actually trained, for 64 timesteps.
if isinstance(env, DummyMultidimensionalAction):
model.learn(64)
else:
# Unsupported pairs must fail with an AssertionError when the model is built.
with pytest.raises(AssertionError):
model_class("MlpPolicy", env)


def test_sde_multi_dim():
    """Smoke-test SAC with ``use_sde=True`` on the multi-dimensional action env.

    The test passes if building the model and running ``learn(20)`` raise no
    exception; nothing about the learned policy is asserted.
    """
    model = SAC(
        "MlpPolicy",
        DummyMultidimensionalAction(),
        learning_starts=10,
        use_sde=True,
        sde_sample_freq=2,
        use_sde_at_warmup=True,
    )
    # 20 steps with learning_starts=10 covers both the warm-up and the training phase.
    model.learn(20)


@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", ["Taxi-v3"])
def test_discrete_obs_space(model_class, env):
Expand Down

0 comments on commit c4f54fc

Please sign in to comment.