Skip to content

Commit

Permalink
Changing step to return truncate=True when curr_path_length equals max_path_length
Browse files Browse the repository at this point in the history
  • Loading branch information
reginald-mclean committed Oct 5, 2023
1 parent 5fb485e commit e76738d
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,9 @@ def sawyer_observation_space(self):
def step(self, action):
assert len(action) == 4, f"Actions should be size 4, got {len(action)}"
self.set_xyz_action(action[:3])
if self.curr_path_length == self.max_path_length:
raise ValueError("You must reset the env manually once truncate==True")
self.do_simulation([action[-1], -action[-1]], n_frames=self.frame_skip)
if self.curr_path_length > self.max_path_length:
raise ValueError(
"Maximum path length allowed by the benchmark has been exceeded"
)
self.curr_path_length += 1

# Running the simulator can sometimes mess up site positions, so
Expand Down Expand Up @@ -496,12 +494,19 @@ def step(self, action):
dtype=np.float64,
)
reward, info = self.evaluate_state(self._last_stable_obs, action)
# step will never return a terminal if there is a success
# step will never return a terminate==True if there is a success
# but we can return truncate=True if the current path length == max path length


truncate = False
print(self.curr_path_length)
if self.curr_path_length == self.max_path_length:
truncate = True
return (
np.array(self._last_stable_obs, dtype=np.float64),
reward,
False,
False,
truncate,
info,
)

Expand Down

0 comments on commit e76738d

Please sign in to comment.