Skip to content

Commit

Permalink
[RLlib] Fix 3 test cases that broke in move to revert PPO to old API …
Browse files Browse the repository at this point in the history
…stack. (ray-project#40788)
  • Loading branch information
sven1977 authored Oct 30, 2023
1 parent afdcdd2 commit 83785ab
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
5 changes: 4 additions & 1 deletion rllib/examples/action_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ def get_cli_args():
)
# We need to disable preprocessing of observations, because preprocessing
# would flatten the observation dict of the environment.
.experimental(_disable_preprocessor_api=True)
.experimental(
_enable_new_api_stack=True,
_disable_preprocessor_api=True,
)
.framework(args.framework)
.resources(
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
Expand Down
3 changes: 2 additions & 1 deletion rllib/examples/env/action_mask_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def step(self, action):
# Check whether action is valid.
if not self.valid_actions[action]:
raise ValueError(
f"Invalid action sent to env! " f"valid_actions={self.valid_actions}"
f"Invalid action ({action}) sent to env! "
f"valid_actions={self.valid_actions}"
)
obs, rew, done, truncated, info = super().step(action)
self._fix_action_mask(obs)
Expand Down
2 changes: 1 addition & 1 deletion rllib/examples/learner/train_w_bc_finetune_w_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def train_ppo_agent_from_checkpointed_module(

config = (
PPOConfig()
.training()
.experimental(_enable_new_api_stack=True)
.rl_module(rl_module_spec=module_spec_from_ckpt)
.environment(GYM_ENV_NAME)
.debugging(seed=0)
Expand Down

0 comments on commit 83785ab

Please sign in to comment.