From c4f54fcf047d7bf425fb6b88a3c8ed23fe375f9b Mon Sep 17 00:00:00 2001
From: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date: Sat, 6 Aug 2022 14:19:20 +0200
Subject: [PATCH] Handling multi-dimensional action spaces (#971)

* Handle non 1D action shape
* Revert changes of observation (out of the scope of this PR)
* Apply changes to DictReplayBuffer
* Update tests
* Rollout buffer n-D actions space handling
* Remove error when non 1D action space
* ActorCriticPolicy return action with the proper shape
* remove useless reshape
* Update changelog
* Add tests

Co-authored-by: Antonin RAFFIN
---
 docs/misc/changelog.rst                   |  1 +
 stable_baselines3/common/buffers.py       |  9 +++---
 stable_baselines3/common/distributions.py |  1 -
 stable_baselines3/common/policies.py      |  5 ++--
 tests/test_spaces.py                      | 36 +++++++++++++++++++++--
 5 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index b7f27abf1..e9d4b7838 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -19,6 +19,7 @@ Bug Fixes:
 ^^^^^^^^^^
 - Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
 - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
+- Added multidimensional action space support (@qgallouedec)
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index 5ed9b4c96..0eb26515b 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -247,8 +247,7 @@ def add(
             next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)
 
         # Same, for actions
-        if isinstance(self.action_space, spaces.Discrete):
-            action = action.reshape((self.n_envs, self.action_dim))
+        action = action.reshape((self.n_envs, self.action_dim))
 
         # Copy to avoid modification by reference
         self.observations[self.pos] = np.array(obs).copy()
@@ -433,6 +432,9 @@ def add(
         if isinstance(self.observation_space, spaces.Discrete):
             obs = obs.reshape((self.n_envs,) + self.obs_shape)
 
+        # Same reshape, for actions
+        action = action.reshape((self.n_envs, self.action_dim))
+
         self.observations[self.pos] = np.array(obs).copy()
         self.actions[self.pos] = np.array(action).copy()
         self.rewards[self.pos] = np.array(reward).copy()
@@ -586,8 +588,7 @@ def add(
             self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
 
         # Same reshape, for actions
-        if isinstance(self.action_space, spaces.Discrete):
-            action = action.reshape((self.n_envs, self.action_dim))
+        action = action.reshape((self.n_envs, self.action_dim))
 
         self.actions[self.pos] = np.array(action).copy()
         self.rewards[self.pos] = np.array(reward).copy()
diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py
index 524775150..fc4862511 100644
--- a/stable_baselines3/common/distributions.py
+++ b/stable_baselines3/common/distributions.py
@@ -658,7 +658,6 @@ def make_proba_distribution(
         dist_kwargs = {}
 
     if isinstance(action_space, spaces.Box):
-        assert len(action_space.shape) == 1, "Error: the action space must be a vector"
         cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
         return cls(get_action_dim(action_space), **dist_kwargs)
     elif isinstance(action_space, spaces.Discrete):
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index a88fad602..3809b61f7 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -336,8 +336,8 @@ def predict(
 
         with th.no_grad():
             actions = self._predict(observation, deterministic=deterministic)
-        # Convert to numpy
-        actions = actions.cpu().numpy()
+        # Convert to numpy, and reshape to the original action shape
+        actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
 
         if isinstance(self.action_space, gym.spaces.Box):
             if self.squash_output:
@@ -592,6 +592,7 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
         distribution = self._get_action_dist_from_latent(latent_pi)
         actions = distribution.get_actions(deterministic=deterministic)
         log_prob = distribution.log_prob(actions)
+        actions = actions.reshape((-1,) + self.action_space.shape)
         return actions, values, log_prob
 
     def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
diff --git a/tests/test_spaces.py b/tests/test_spaces.py
index b75404288..0696492d6 100644
--- a/tests/test_spaces.py
+++ b/tests/test_spaces.py
@@ -33,6 +33,19 @@ def step(self, action):
         return self.observation_space.sample(), 0.0, False, {}
 
 
+class DummyMultidimensionalAction(gym.Env):
+    def __init__(self):
+        super().__init__()
+        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
+
+    def reset(self):
+        return self.observation_space.sample()
+
+    def step(self, action):
+        return self.observation_space.sample(), 0.0, False, {}
+
+
 @pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
 @pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
 def test_identity_spaces(model_class, env):
@@ -53,22 +66,39 @@ def test_identity_spaces(model_class, env):
 
 
 @pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
-@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
+@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
 def test_action_spaces(model_class, env):
+    kwargs = {}
     if model_class in [SAC, DDPG, TD3]:
-        supported_action_space = env == "Pendulum-v1"
+        supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction)
+        kwargs["learning_starts"] = 2
+        kwargs["train_freq"] = 32
     elif model_class == DQN:
         supported_action_space = env == "CartPole-v1"
     elif model_class in [A2C, PPO]:
         supported_action_space = True
+        kwargs["n_steps"] = 64
 
     if supported_action_space:
-        model_class("MlpPolicy", env)
+        model = model_class("MlpPolicy", env, **kwargs)
+        if isinstance(env, DummyMultidimensionalAction):
+            model.learn(64)
     else:
         with pytest.raises(AssertionError):
            model_class("MlpPolicy", env)
 
 
+def test_sde_multi_dim():
+    SAC(
+        "MlpPolicy",
+        DummyMultidimensionalAction(),
+        learning_starts=10,
+        use_sde=True,
+        sde_sample_freq=2,
+        use_sde_at_warmup=True,
+    ).learn(20)
+
+
 @pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
 @pytest.mark.parametrize("env", ["Taxi-v3"])
 def test_discrete_obs_space(model_class, env):
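
Reviewer note (not part of the patch): a minimal end-to-end sketch of what this change enables, assuming stable-baselines3 with this patch applied. MultiDimActionEnv and the hyperparameters below are illustrative only; the env mirrors the DummyMultidimensionalAction test fixture above.

import gym
import numpy as np

from stable_baselines3 import PPO


class MultiDimActionEnv(gym.Env):
    # Hypothetical toy env mirroring the DummyMultidimensionalAction fixture:
    # actions are (2, 2) arrays instead of flat vectors.
    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        # With this patch, the agent hands the env an action of the declared shape
        assert action.shape == (2, 2)
        return self.observation_space.sample(), 0.0, False, {}


# Before this patch, make_proba_distribution asserted that Box action spaces
# must be 1D; now rollout collection, the buffers, and predict() all handle
# the (2, 2) shape (internally the action is flattened to action_dim = 4).
model = PPO("MlpPolicy", MultiDimActionEnv(), n_steps=64)
model.learn(64)

action, _ = model.predict(np.zeros(2, dtype=np.float32))
assert action.shape == (2, 2)  # predict() reshapes back to action_space.shape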