From 0d17e18cb4e51207e95de4aae70c17dcc51ed84d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 10 Jun 2024 11:02:38 +0200 Subject: [PATCH] is_success consistency --- panda_gym/envs/tasks/flip.py | 2 +- panda_gym/envs/tasks/pick_and_place.py | 2 +- panda_gym/envs/tasks/push.py | 2 +- panda_gym/envs/tasks/reach.py | 2 +- panda_gym/envs/tasks/slide.py | 2 +- panda_gym/envs/tasks/stack.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/panda_gym/envs/tasks/flip.py b/panda_gym/envs/tasks/flip.py index 1df64ceb..70723815 100644 --- a/panda_gym/envs/tasks/flip.py +++ b/panda_gym/envs/tasks/flip.py @@ -78,7 +78,7 @@ def _sample_object(self) -> Tuple[np.ndarray, np.ndarray]: object_rotation = np.zeros(3) return object_position, object_rotation - def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray: + def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray: d = angle_distance(achieved_goal, desired_goal) return np.array(d < self.distance_threshold, dtype=bool) diff --git a/panda_gym/envs/tasks/pick_and_place.py b/panda_gym/envs/tasks/pick_and_place.py index 47ef2c31..4ac55f1c 100644 --- a/panda_gym/envs/tasks/pick_and_place.py +++ b/panda_gym/envs/tasks/pick_and_place.py @@ -83,7 +83,7 @@ def _sample_object(self) -> np.ndarray: object_position += noise return object_position - def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray: + def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray: d = distance(achieved_goal, desired_goal) return np.array(d < self.distance_threshold, dtype=bool) diff --git a/panda_gym/envs/tasks/push.py b/panda_gym/envs/tasks/push.py index 364fd9d2..e3b53af5 100644 --- a/panda_gym/envs/tasks/push.py +++ b/panda_gym/envs/tasks/push.py @@ -85,7 +85,7 @@ def _sample_object(self) -> np.ndarray: object_position += noise return object_position - def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray: + def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray: d = distance(achieved_goal, desired_goal) return np.array(d < self.distance_threshold, dtype=bool) diff --git a/panda_gym/envs/tasks/reach.py b/panda_gym/envs/tasks/reach.py index 9949dcf2..abca7d47 100644 --- a/panda_gym/envs/tasks/reach.py +++ b/panda_gym/envs/tasks/reach.py @@ -52,7 +52,7 @@ def _sample_goal(self) -> np.ndarray: goal = self.np_random.uniform(self.goal_range_low, self.goal_range_high) return goal - def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray: + def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray: d = distance(achieved_goal, desired_goal) return np.array(d < self.distance_threshold, dtype=bool) diff --git a/panda_gym/envs/tasks/slide.py b/panda_gym/envs/tasks/slide.py index 04911360..047fb215 100644 --- a/panda_gym/envs/tasks/slide.py +++ b/panda_gym/envs/tasks/slide.py @@ -89,7 +89,7 @@ def _sample_object(self) -> np.ndarray: object_position += noise return object_position - def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray: + def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray: d = distance(achieved_goal, desired_goal) return np.array(d < self.distance_threshold, dtype=bool) diff --git a/panda_gym/envs/tasks/stack.py b/panda_gym/envs/tasks/stack.py index cde63cc1..919adabc 100644 --- a/panda_gym/envs/tasks/stack.py +++ b/panda_gym/envs/tasks/stack.py @@ -117,7 +117,7 @@ def _sample_objects(self) -> Tuple[np.ndarray, np.ndarray]: # if distance(object1_position, object2_position) > 0.1: return object1_position, object2_position - def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray: + def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray: # must be vectorized !! d = distance(achieved_goal, desired_goal) return np.array((d < self.distance_threshold), dtype=bool)