diff --git a/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py b/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py index ee8f9114e..a0e9c1f5d 100644 --- a/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py +++ b/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py @@ -458,11 +458,9 @@ def sawyer_observation_space(self): def step(self, action): assert len(action) == 4, f"Actions should be size 4, got {len(action)}" self.set_xyz_action(action[:3]) + if self.curr_path_length >= self.max_path_length: + raise ValueError("You must reset the env manually once truncate==True") self.do_simulation([action[-1], -action[-1]], n_frames=self.frame_skip) - if self.curr_path_length > self.max_path_length: - raise ValueError( - "Maximum path length allowed by the benchmark has been exceeded" - ) self.curr_path_length += 1 # Running the simulator can sometimes mess up site positions, so @@ -496,12 +494,16 @@ def step(self, action): dtype=np.float64, ) reward, info = self.evaluate_state(self._last_stable_obs, action) - # step will never return a terminal if there is a success + # step will never return a terminate==True if there is a success + # but we can return truncate=True if the current path length == max path length + truncate = False + if self.curr_path_length == self.max_path_length: + truncate = True return ( np.array(self._last_stable_obs, dtype=np.float64), reward, False, - False, + truncate, info, )