Skip to content

Commit

Permalink
test: One On- and one Off-Policy algorithm (A2C and SAC respectively)…
Browse files Browse the repository at this point in the history
…, with settings to speed up testing
  • Loading branch information
iwishiwasaneagle committed Oct 4, 2023
1 parent ce42b9a commit b27adf6
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,25 @@ def _on_step(self) -> bool:
return self.callback_false_value


@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
@pytest.mark.parametrize(
"model_class,model_kwargs",
[
(A2C, dict(n_steps=1, stats_window_size=1)),
(
SAC,
dict(
learning_starts=1,
buffer_size=1,
batch_size=1,
),
),
],
)
@pytest.mark.parametrize("callback_false_value", [False, np.bool_(0), th.tensor(0, dtype=th.bool)])
def test_callbacks_can_cancel_runs(model_class, callback_false_value):
def test_callbacks_can_cancel_runs(model_class, model_kwargs, callback_false_value):
assert not callback_false_value # Sanity check to ensure parametrized values are valid
env_id = select_env(model_class)
model = model_class("MlpPolicy", env_id, policy_kwargs=dict(net_arch=[32]))
model = model_class("MlpPolicy", env_id, **model_kwargs, policy_kwargs=dict(net_arch=[2]))
eval_callback = EvalCallback(
gym.make(env_id),
callback_after_eval=AlwaysFailCallback(callback_false_value=callback_false_value),
Expand Down

0 comments on commit b27adf6

Please sign in to comment.