Skip to content

Commit

Permalink
Move PrioritizedReplayBuffer to common and add tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
araffin committed Sep 29, 2023
1 parent fb33732 commit f984e5c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
Expand All @@ -19,7 +18,7 @@ class SumTree:
"""

def __init__(self, buffer_size: int) -> None:
self.nodes = np.zeros((2 * buffer_size - 1))
self.nodes = np.zeros(2 * buffer_size - 1)
self.data = np.zeros(buffer_size)
self.size = buffer_size
self.count = 0
Expand Down
8 changes: 6 additions & 2 deletions tests/test_buffers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,6 +7,7 @@
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
Expand Down Expand Up @@ -108,7 +109,9 @@ def test_replay_buffer_normalization(replay_buffer_cls):
assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)


@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
@pytest.mark.parametrize(
"replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer, PrioritizedReplayBuffer]
)
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_device_buffer(replay_buffer_cls, device):
if device == "cuda" and not th.cuda.is_available():
Expand All @@ -119,6 +122,7 @@ def test_device_buffer(replay_buffer_cls, device):
DictRolloutBuffer: DummyDictEnv,
ReplayBuffer: DummyEnv,
DictReplayBuffer: DummyDictEnv,
PrioritizedReplayBuffer: DummyEnv,
}[replay_buffer_cls]
env = make_vec_env(env)

Expand All @@ -139,7 +143,7 @@ def test_device_buffer(replay_buffer_cls, device):
# Get data from the buffer
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
data = buffer.get(50)
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer, PrioritizedReplayBuffer]:
data = buffer.sample(50)

# Check that all data are on the desired device
Expand Down
5 changes: 4 additions & 1 deletion tests/test_run.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,6 +5,7 @@
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer

normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))

Expand Down Expand Up @@ -100,7 +101,8 @@ def test_n_critics(n_critics):
model.learn(total_timesteps=200)


def test_dqn():
@pytest.mark.parametrize("replay_buffer_class", [None, PrioritizedReplayBuffer])
def test_dqn(replay_buffer_class):
model = DQN(
"MlpPolicy",
"CartPole-v1",
Expand All @@ -109,6 +111,7 @@ def test_dqn():
buffer_size=500,
learning_rate=3e-4,
verbose=1,
replay_buffer_class=replay_buffer_class,
)
model.learn(total_timesteps=200)

Expand Down

0 comments on commit f984e5c

Please sign in to comment.