From b2ebb26f31f8f1c8334c25ec1f61ae27d84a29b0 Mon Sep 17 00:00:00 2001 From: Management Date: Fri, 4 Dec 2020 13:24:42 -0800 Subject: [PATCH 1/3] added check for max_episode_length --- src/garage/tf/algos/npo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/garage/tf/algos/npo.py b/src/garage/tf/algos/npo.py index a02a2df654..0ae7791916 100644 --- a/src/garage/tf/algos/npo.py +++ b/src/garage/tf/algos/npo.py @@ -102,6 +102,8 @@ def __init__(self, self.policy = policy self._scope = scope self.max_episode_length = env_spec.max_episode_length + if self.max_episode_length == None: + raise ValueError("max_episode_length must not be None") self._env_spec = env_spec self._baseline = baseline self._discount = discount From b6808d5b21079ca64620fbfd0b5d935943c4a455 Mon Sep 17 00:00:00 2001 From: Management Date: Fri, 4 Dec 2020 13:37:27 -0800 Subject: [PATCH 2/3] added test --- tests/garage/tf/algos/test_npo.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/garage/tf/algos/test_npo.py b/tests/garage/tf/algos/test_npo.py index abeea28de0..b889f22650 100644 --- a/tests/garage/tf/algos/test_npo.py +++ b/tests/garage/tf/algos/test_npo.py @@ -115,6 +115,22 @@ def test_npo_with_invalid_no_entropy_configuration(self): entropy_method='no_entropy', policy_ent_coeff=0.02, ) + + @pytest.mark.mujoco + def test_npo_with_invalid_max_episode_length(self): + """Test NPO with invalid max_episode_length.""" + with pytest.raises(ValueError): + env = normalize( + GymEnv('InvertedDoublePendulum-v2', max_episode_length=None)) + NPO( + env_spec=env.spec, + policy=self.policy, + baseline=self.baseline, + sampler=self.sampler, + discount=0.99, + gae_lambda=0.98, + policy_ent_coeff=0.0 + ) def teardown_method(self): self.env.close() From 396f3d6fc44960f18c22dec60059d7930af5aaff Mon Sep 17 00:00:00 2001 From: Management Date: Fri, 4 Dec 2020 17:05:28 -0800 Subject: [PATCH 3/3] moved if statement and changed double quotes to single quotes --- src/garage/tf/algos/npo.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/garage/tf/algos/npo.py b/src/garage/tf/algos/npo.py index 0ae7791916..4527f4ce51 100644 --- a/src/garage/tf/algos/npo.py +++ b/src/garage/tf/algos/npo.py @@ -102,8 +102,6 @@ def __init__(self, self.policy = policy self._scope = scope self.max_episode_length = env_spec.max_episode_length - if self.max_episode_length == None: - raise ValueError("max_episode_length must not be None") self._env_spec = env_spec self._baseline = baseline self._discount = discount @@ -131,6 +129,9 @@ def __init__(self, if pg_loss not in ['vanilla', 'surrogate', 'surrogate_clip']: raise ValueError('Invalid pg_loss') + if self.max_episode_length == None: + raise ValueError('max_episode_length must not be None') + self._optimizer = make_optimizer(optimizer, **optimizer_args) self._lr_clip_range = float(lr_clip_range) self._max_kl_step = float(max_kl_step)