diff --git a/stable_baselines3/dqn/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py similarity index 99% rename from stable_baselines3/dqn/prioritized_replay_buffer.py rename to stable_baselines3/common/prioritized_replay_buffer.py index d3f69f16a..852c77a49 100644 --- a/stable_baselines3/dqn/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -1,4 +1,3 @@ -import warnings from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -19,7 +18,7 @@ class SumTree: """ def __init__(self, buffer_size: int) -> None: - self.nodes = np.zeros((2 * buffer_size - 1)) + self.nodes = np.zeros(2 * buffer_size - 1) self.data = np.zeros(buffer_size) self.size = buffer_size self.count = 0 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index e7d4a1c57..84aa474b3 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -7,6 +7,7 @@ from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize @@ -108,7 +109,9 @@ def test_replay_buffer_normalization(replay_buffer_cls): assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1) -@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer]) +@pytest.mark.parametrize( + "replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer, PrioritizedReplayBuffer] +) @pytest.mark.parametrize("device", ["cpu", "cuda", "auto"]) def test_device_buffer(replay_buffer_cls, device): if device == "cuda" and not th.cuda.is_available(): @@ -119,6 +122,7 @@ def test_device_buffer(replay_buffer_cls, device): DictRolloutBuffer: DummyDictEnv, ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv, + PrioritizedReplayBuffer: DummyEnv, }[replay_buffer_cls] env = make_vec_env(env) @@ -139,7 +143,7 @@ def test_device_buffer(replay_buffer_cls, device): # Get data from the buffer if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: data = buffer.get(50) - elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]: + elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer, PrioritizedReplayBuffer]: data = buffer.sample(50) # Check that all data are on the desired device diff --git a/tests/test_run.py b/tests/test_run.py index 31c7b956e..a7a30eed5 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -5,6 +5,7 @@ from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise +from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) @@ -100,7 +101,8 @@ def test_n_critics(n_critics): model.learn(total_timesteps=200) -def test_dqn(): +@pytest.mark.parametrize("replay_buffer_class", [None, PrioritizedReplayBuffer]) +def test_dqn(replay_buffer_class): model = DQN( "MlpPolicy", "CartPole-v1", @@ -109,6 +111,7 @@ def test_dqn(): buffer_size=500, learning_rate=3e-4, verbose=1, + replay_buffer_class=replay_buffer_class, ) model.learn(total_timesteps=200)